merge from master

commit 2381ee56d9
@@ -42,6 +42,10 @@
        *   Removed deprecated `Interpreter::UseNNAPI(bool)` C++ API.
            *   Use `NnApiDelegate()` and related delegate configuration methods
                directly.
*   TF Core:
    *   Corrected higher-order gradients of control flow constructs (`tf.cond`,
        `tf.while_loop`, and compositions like `tf.foldl`) computed with
        `tf.GradientTape` inside a `tf.function`.

## Thanks to our Contributors

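The NNAPI release note above says to configure the delegate directly instead of the removed `Interpreter::UseNNAPI(bool)` toggle. Below is a minimal sketch of what that can look like with the TFLite C++ API; it is not part of this commit. The model path, error handling, and delegate lifetime handling are placeholder assumptions, and the note's `NnApiDelegate()` is assumed here to correspond to the `StatefulNnApiDelegate` class.

// Sketch only: attach the NNAPI delegate explicitly rather than via the
// removed Interpreter::UseNNAPI(bool) switch. Assumes a flatbuffer model file.
#include <memory>

#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

std::unique_ptr<tflite::Interpreter> BuildInterpreterWithNnapi(
    const char* model_path) {
  auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
  if (!model) return nullptr;

  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
    return nullptr;
  }

  // Configure and attach the delegate directly. It must outlive the
  // interpreter; a real application would manage that lifetime (it is
  // intentionally leaked here to keep the sketch short).
  tflite::StatefulNnApiDelegate::Options options;
  auto* nnapi_delegate = new tflite::StatefulNnApiDelegate(options);
  if (interpreter->ModifyGraphWithDelegate(nnapi_delegate) != kTfLiteOk) {
    // Delegation failed; the interpreter still runs on CPU.
  }
  return interpreter;
}
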
@@ -769,7 +769,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
    TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
    EXPECT_NE(TF_OK, TF_GetCode(status));
    EXPECT_EQ(nullptr, t);
    const char* msg = "Matrix size-incompatible: In[0]: [2,2], In[1]: [3,2]";
    const char* msg = "In[0] mismatch In[1] shape: 2 vs. 3: [2,2] [3,2]";
    EXPECT_TRUE(strstr(TF_Message(status), msg) != nullptr)
        << TF_Message(status);
    // Since error is not cleared, the following copy with correct device will

@@ -583,7 +583,11 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
    XlaCompiler::Argument& arg = out[input_num];
    if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
      // Handles compile-time constants.
      TF_RET_CHECK(input->dtype() != DT_RESOURCE);

      // TODO(b/157241314): Support constants located in resource variables.
      TF_RET_CHECK(input->dtype() != DT_RESOURCE)
          << "tf2xla bridge does not support must-be-constants located in "
             "resource variables; try moving them to a tensor";
      arg.kind = XlaCompiler::Argument::kConstant;
      arg.type = input->dtype();
      arg.shape = input->shape();

@@ -517,6 +517,15 @@ cc_library(
    ],
)

cc_library(
    name = "map_chlo_to_hlo_op",
    hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"],
    deps = [
        ":hlo",
        "@llvm-project//mlir:IR",
    ],
)

cc_library(
    name = "map_hlo_to_lhlo_op",
    hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"],
@@ -606,9 +615,11 @@ cc_library(
    ],
    deps = [
        ":hlo",
        ":map_chlo_to_hlo_op",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:Shape",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Transforms",
@@ -893,6 +904,7 @@ cc_library(
    deps = [
        ":chlo_legalize_to_hlo_inc_gen",
        ":hlo",
        ":map_chlo_to_hlo_op",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:Shape",

@@ -0,0 +1,97 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_MHLO_OP_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_MHLO_OP_H_

#include <type_traits>

#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
namespace chlo {

struct HloComplexAdaptor {
  static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
                                  Value broadcasted_lhs, Value broadcasted_rhs,
                                  OpBuilder &builder) {
    return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
                                           broadcasted_lhs, broadcasted_rhs);
  }
};
template <typename FromOpTy, typename ToOpTy>
struct HloBinaryElementwiseAdaptor {
  static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
                         Value broadcasted_lhs, Value broadcasted_rhs,
                         OpBuilder &builder) {
    return builder.create<ToOpTy>(from_op.getLoc(), result_type,
                                  broadcasted_lhs, broadcasted_rhs);
  }
};
struct HloCompareAdaptor {
  static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
                                  Value broadcasted_lhs, Value broadcasted_rhs,
                                  OpBuilder &builder) {
    return builder.create<mhlo::CompareOp>(
        from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
        from_op.comparison_direction(), from_op.compare_typeAttr());
  }
};

// Populate a pattern for each Broadcasting CHlo op. This requires the pattern
// to take a ChloOpTy, MhloOpTy, and an Adaptor as templated values.
template <template <typename, typename, typename> class Pattern,
          typename... ConstructorArgs>
void PopulateForBroadcastingBinaryOp(MLIRContext *context,
                                     OwningRewritePatternList *patterns,
                                     ConstructorArgs &&...args) {
#define POPULATE_BCAST(ChloOp, HloOp)                                      \
  patterns->insert<                                                        \
      Pattern<ChloOp, HloOp, HloBinaryElementwiseAdaptor<ChloOp, HloOp>>>( \
      context, args...);

  POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
  POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
  POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
  POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
  POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
  POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
  POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
  POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
  POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
  POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
  POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
  POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp);
  POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
  POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
  POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);

  // Broadcasting ops requiring special construction.
  patterns
      ->insert<Pattern<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>>(
          context, args...);
  patterns
      ->insert<Pattern<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>>(
          context, args...);

#undef POPULATE_BCAST
}

}  // namespace chlo
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_MHLO_OP_H_
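A brief usage sketch of the helper declared in the header above, mirroring the call sites that appear later in this commit. The `ExampleBroadcastingPattern` stand-in, the `example` namespace, and `PopulateExamplePatterns` are hypothetical; only `PopulateForBroadcastingBinaryOp` and the adaptor structs come from the header itself.

// Sketch only: a stand-in pattern with the template shape
// PopulateForBroadcastingBinaryOp expects (ChloOpTy, HloOpTy, Adaptor). It
// simply forwards to Adaptor::CreateOp and is not the real lowering logic.
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir/IR/PatternMatch.h"

namespace example {

template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ExampleBroadcastingPattern : public mlir::OpRewritePattern<ChloOpTy> {
  using mlir::OpRewritePattern<ChloOpTy>::OpRewritePattern;
  mlir::LogicalResult matchAndRewrite(
      ChloOpTy op, mlir::PatternRewriter &rewriter) const override {
    // Build the corresponding non-broadcasting MHLO op via the adaptor.
    rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(),
                                              op.lhs(), op.rhs(), rewriter)});
    return mlir::success();
  }
};

// One call instantiates the pattern for every broadcasting CHLO binary op,
// pairing each with its MHLO counterpart and adaptor; trailing arguments
// (here a pattern benefit of 1) are forwarded to the pattern constructors.
void PopulateExamplePatterns(mlir::MLIRContext *context,
                             mlir::OwningRewritePatternList *patterns) {
  mlir::chlo::PopulateForBroadcastingBinaryOp<ExampleBroadcastingPattern>(
      context, patterns, /*benefit=*/1);
}

}  // namespace example
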
@@ -17,6 +17,7 @@ limitations under the License.

#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir-hlo/utils/broadcast_utils.h"
#include "mlir/Dialect/SCF/SCF.h"
@@ -69,13 +70,18 @@ struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
// Converts binary ops that statically are determined to not broadcast directly
// to the corresponding mhlo non-broadcasting op.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
  using OpRewritePattern<ChloOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(ChloOpTy op,
                                PatternRewriter &rewriter) const override {
struct ConvertTrivialNonBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only rewrite for statically determinable non-broadcasting cases.
    auto lhs_type = op.lhs().getType().template dyn_cast<RankedTensorType>();
    auto rhs_type = op.rhs().getType().template dyn_cast<RankedTensorType>();
    typename ChloOpTy::Adaptor transformed(operands);
    auto lhs_type =
        transformed.lhs().getType().template dyn_cast<RankedTensorType>();
    auto rhs_type =
        transformed.rhs().getType().template dyn_cast<RankedTensorType>();
    if (!lhs_type || !rhs_type) return failure();

    // Requires rank broadcast.
@@ -93,8 +99,9 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
      }
    }

    rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(),
                                              op.lhs(), op.rhs(), rewriter)});
    rewriter.replaceOp(
        op, {Adaptor::CreateOp(op, op.getResult().getType(), operands[0],
                               operands[1], rewriter)});
    return success();
  }
};
@@ -113,13 +120,15 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
// `shape.broadcast` op, which only supports prefix-padding.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertRankedDynamicBroadcastBinaryOp
    : public OpRewritePattern<ChloOpTy> {
  using OpRewritePattern<ChloOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(ChloOpTy op,
                                PatternRewriter &rewriter) const override {
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only support ranked operands.
    Value lhs = op.lhs();
    Value rhs = op.rhs();
    typename ChloOpTy::Adaptor transformed(operands);
    Value lhs = transformed.lhs();
    Value rhs = transformed.rhs();
    auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
    auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
    auto result_type =
@@ -193,324 +202,6 @@ struct ConvertRankedDynamicBroadcastBinaryOp
  }
};

// Converts a broadcasting binary operation with a scalar operand and an
// unranked operand to a ranked broadcasting operation by dynamically reshaping
// the unranked operand to a 1D tensor. This will always be safe because
// broadcasting from a scalar to another shape always works.
template <typename ChloOpTy, typename HloOpTy>
struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
    : public OpRewritePattern<ChloOpTy> {
  using OpRewritePattern<ChloOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(ChloOpTy op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    Value lhs = op.lhs();
    Value rhs = op.rhs();

    auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
    auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();

    auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
    auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();

    bool lhs_is_scalar = lhs_ranked_type &&
                         lhs_ranked_type.getShape().empty() &&
                         rhs_unranked_type;
    bool rhs_is_scalar = rhs_ranked_type &&
                         rhs_ranked_type.getShape().empty() &&
                         lhs_unranked_type;

    // Only support the case where exactly one operand is scalar and the other
    // is unranked. Other patterns in this file will create more efficient
    // lowerings for cases where both ranks are known or will handle the more
    // generic case of both inputs being unranked.
    if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();

    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();

    // Reshape the non-scalar value into a dynamically sized, rank-1 tensor
    Value shape =
        rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
    Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
    Value size_tensor =
        rewriter.create<TensorFromElementsOp>(loc, num_elements);
    Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
        loc, RankedTensorType::get({-1}, result_type.getElementType()),
        lhs_is_scalar ? rhs : lhs, size_tensor);

    // Create a new ranked Chlo op that will be further lowered by other
    // patterns into Mhlo.
    SmallVector<Value, 2> operands{lhs_is_scalar ? lhs : reshaped,
                                   rhs_is_scalar ? rhs : reshaped};
    Value computed = rewriter.create<ChloOpTy>(
        loc, SmallVector<Type, 1>{reshaped.getType()}, operands, op.getAttrs());

    // Reshape the result back into an unranked tensor.
    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
                                                        computed, shape);

    return success();
  }
};

// Handles lowering of the following pattern to patterns that will be further
// matched by other patterns until they result in LHLO:
//   %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
//
// The sequence of specializations this handles is:
//   - Either operand being scalar
//   - Operands having equal shapes
//   - The resulting value being any of ranks [2,6]
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertUnrankedDynamicBroadcastBinaryOp
    : public OpRewritePattern<ChloOpTy> {
  using OpRewritePattern<ChloOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(ChloOpTy op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    Value lhs = op.lhs();
    Value rhs = op.rhs();
    auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
    auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();

    // Only support unranked operands. If either operand is ranked, another
    // pattern will handle the lowering.
    if (!lhs_type || !rhs_type) return failure();

    // If lhs is scalar
    auto if_op = rewriter.create<scf::IfOp>(
        loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
    OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder();
    Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
        loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
    Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
        op.getAttrs());
    if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result);

    // If lhs is NOT scalar
    //
    // See if rhs is scalar
    OpBuilder else_lhs_scalar_builder = if_op.getElseBodyBuilder();
    auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
        loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
        true);
    else_lhs_scalar_builder.create<scf::YieldOp>(loc,
                                                 if_rhs_scalar_op.getResult(0));
    OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder();
    Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
        loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
    Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
        op.getAttrs());
    if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result);

    // If NEITHER shape is scalar
    //
    // See if shapes are equal.
    OpBuilder else_no_scalars_builder = if_rhs_scalar_op.getElseBodyBuilder();
    Value shape_of_lhs =
        else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
    Value shape_of_rhs =
        else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
    Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
        loc, shape_of_lhs, shape_of_rhs);

    auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
        loc, result_type, equal_shapes, true);
    else_no_scalars_builder.create<scf::YieldOp>(loc,
                                                 if_eq_shapes_op.getResult(0));

    OpBuilder if_eq_shapes_builder = if_eq_shapes_op.getThenBodyBuilder();
    Value non_broadcast_op =
        Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
    if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);

    // If shapes are not scalar, nor equal
    //
    // See if values are of a rank that we support.
    OpBuilder if_neq_shapes_builder = if_eq_shapes_op.getElseBodyBuilder();
    if_neq_shapes_builder.create<scf::YieldOp>(
        loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));

    rewriter.replaceOp(op, {if_op.getResult(0)});
    return success();
  }

 private:
  // Returns the dynamic result of checking whether the given value is a scalar
  // tensor.
  Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
    auto loc = op.getLoc();

    Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
    Value rank_tensor = rewriter.create<shape::RankOp>(
        loc, rewriter.getIndexType(), shape_of_tensor);
    return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
                                   rank_tensor,
                                   rewriter.create<ConstantIndexOp>(loc, 0));
  }

  // Create the if statement and code for a broadcasting op with a result of a
  // given rank.
  scf::IfOp createRankSpecializedBroadcastAndOp(OpBuilder &builder, ChloOpTy op,
                                                Value lhs, Value rhs,
                                                Value actual_rank,
                                                int targeted_rank) const {
    auto loc = op.getLoc();

    // Create the if block to place the current specialized logic in.
    Value greater_rank_is_n = builder.create<CmpIOp>(
        loc, CmpIPredicate::eq, actual_rank,
        builder.create<ConstantIndexOp>(loc, targeted_rank));
    auto if_op =
        builder.create<scf::IfOp>(loc, lhs.getType(), greater_rank_is_n, true);
    OpBuilder if_builder = if_op.getThenBodyBuilder();

    // Handle shape broadcasting and inference.
    Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
    Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
        {RankedTensorType::kDynamicSize}, builder.getIndexType());
    auto known_rank_extent_tensor_type =
        RankedTensorType::get({targeted_rank}, builder.getIndexType());
    auto reshaped_type = RankedTensorType::get(
        llvm::SmallVector<int64_t, 6>(targeted_rank,
                                      RankedTensorType::kDynamicSize),
        lhs.getType().template dyn_cast<TensorType>().getElementType());
    Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
        loc, known_rank_extent_tensor_type,
        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
                                        ranked_shape));
    Value extended_lhs = if_builder.create<shape::BroadcastOp>(
        loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
        nullptr);
    Value extended_lhs_casted = if_builder.create<TensorCastOp>(
        loc, known_rank_extent_tensor_type, extended_lhs);
    Value extended_rhs = if_builder.create<shape::BroadcastOp>(
        loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
        nullptr);
    Value extended_rhs_casted = if_builder.create<TensorCastOp>(
        loc, known_rank_extent_tensor_type, extended_rhs);

    // 1. Reshape operands to the given rank (with the same number of elements)
    // 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
    //    can be broadcasted and do the actual broadcasting)
    // 3. Type erase the output back to unranked
    Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
        loc, reshaped_type, lhs, extended_lhs_casted);
    Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
        loc, reshaped_type, rhs, extended_rhs_casted);
    Value result = if_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{reshaped_type},
        ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
    Value reshaped_result = if_builder.create<TensorCastOp>(
        loc, UnrankedTensorType::get(reshaped_type.getElementType()), result);
    if_builder.create<scf::YieldOp>(loc, reshaped_result);

    // Return the if_op, so the result can be used and the else block can be
    // used for the next rank specialized step.
    return if_op;
  }

  // Iterates over the desired ranks to be specialized and generates the code
  // snippet for each case.
  Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
                             Value rhs) const {
    constexpr int max_rank_specialization = 7;
    auto loc = op.getLoc();

    // Find the larger rank of the 2 operands.
    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
                                                    rewriter.getIndexType());
    Value lhs_shape =
        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
    Value rhs_shape =
        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
    Value lhs_rank =
        rewriter.create<RankOp>(loc, rewriter.getIndexType(), lhs_shape);
    Value rhs_rank =
        rewriter.create<RankOp>(loc, rewriter.getIndexType(), rhs_shape);
    Value greater_rank_lhs =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
    Value greater_rank =
        rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);

    // Generate a list of nested if/else statements to handle rank
    // specializations from 2-6.
    scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs,
                                                          rhs, greater_rank, 2);

    // Put each subsequent rank specialization inside the else statement of the
    // previous one.
    OpBuilder else_builder = if_op.getElseBodyBuilder();
    for (int i = 3; i < max_rank_specialization; i++) {
      auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs,
                                                          rhs, greater_rank, i);

      else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
      else_builder = inner_if.getElseBodyBuilder();
    }

    // Fire an assertion if none of the rank specializations applied (one of the
    // ranks was greater than 6).
    else_builder.create<AssertOp>(
        loc, else_builder.create<ConstantIntOp>(loc, 0, 1),
        "Input for dynamic binary op lowering was of a rank greater than 6");
    else_builder.create<scf::YieldOp>(loc, lhs);

    // Return the result of the outermost if statement.
    return if_op.getResult(0);
  }
};

template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
void PopulateForBinaryOp(MLIRContext *context,
                         OwningRewritePatternList *patterns) {
  patterns
      ->insert<ConvertTrivialNonBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
          context, 10);
  patterns->insert<
      ConvertRankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
      context, 5);
  patterns->insert<
      ConvertUnrankedScalarDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy>,
      ConvertUnrankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
      context);
}

template <typename FromOpTy, typename ToOpTy>
struct HloBinaryElementwiseAdaptor {
  static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
                         Value broadcasted_lhs, Value broadcasted_rhs,
                         OpBuilder &builder) {
    return builder.create<ToOpTy>(from_op.getLoc(), result_type,
                                  broadcasted_lhs, broadcasted_rhs);
  }
};

struct HloComplexAdaptor {
  static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
                                  Value broadcasted_lhs, Value broadcasted_rhs,
                                  OpBuilder &builder) {
    return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
                                           broadcasted_lhs, broadcasted_rhs);
  }
};

struct HloCompareAdaptor {
  static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
                                  Value broadcasted_lhs, Value broadcasted_rhs,
                                  OpBuilder &builder) {
    return builder.create<mhlo::CompareOp>(
        from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
        from_op.comparison_direction(), from_op.compare_typeAttr());
  }
};

#include "generated_chlo_legalize_to_hlo.inc"
}  // namespace

@@ -521,32 +212,10 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
  // Instantiate conversion templates for conforming binary elementwise ops
  // that do not have different dtypes between operands and results and do
  // not have special attributes that need to be preserved.
#define POPULATE_BCAST(ChloOp, HloOp)                                      \
  PopulateForBinaryOp<ChloOp, HloOp,                                       \
                      HloBinaryElementwiseAdaptor<ChloOp, HloOp>>(context, \
                                                                  patterns);

  POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
  POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
  POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
  POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
  POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
  POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
  POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
  POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
  POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
  POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
  POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
  POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp);
  POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
  POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
  POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);

  // Broadcasting ops requiring special construction.
  PopulateForBinaryOp<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>(
      context, patterns);
  PopulateForBinaryOp<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>(
      context, patterns);
  PopulateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>(
      context, patterns, 10);
  PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
      context, patterns, 5);

  // Other patterns.
  patterns->insert<ConvertConstantLikeOp>(context);

@@ -16,7 +16,9 @@ limitations under the License.

#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Function.h"
@@ -126,6 +128,291 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
  }
};

// Converts a broadcasting binary operation with a scalar operand and an
// unranked operand to a ranked broadcasting operation by dynamically reshaping
// the unranked operand to a 1D tensor. This will always be safe because
// broadcasting from a scalar to another shape always works.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    typename ChloOpTy::Adaptor transformed(operands);
    Value lhs = transformed.lhs();
    Value rhs = transformed.rhs();

    auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
    auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();

    auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
    auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();

    bool lhs_is_scalar = lhs_ranked_type &&
                         lhs_ranked_type.getShape().empty() &&
                         rhs_unranked_type;
    bool rhs_is_scalar = rhs_ranked_type &&
                         rhs_ranked_type.getShape().empty() &&
                         lhs_unranked_type;

    // Only support the case where exactly one operand is scalar and the other
    // is unranked. Other patterns in chlo-to-hlo legalization will create more
    // efficient lowerings for cases where both ranks are known or will handle
    // the more generic case of both inputs being unranked.
    if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();

    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();

    // Reshape the non-scalar value into a dynamically sized, rank-1 tensor
    Value shape =
        rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
    Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
    Value size_tensor =
        rewriter.create<TensorFromElementsOp>(loc, num_elements);
    Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
        loc, RankedTensorType::get({-1}, result_type.getElementType()),
        lhs_is_scalar ? rhs : lhs, size_tensor);

    // Create a new ranked Chlo op that will be further lowered by other
    // patterns into Mhlo.
    SmallVector<Value, 2> new_operands{lhs_is_scalar ? lhs : reshaped,
                                       rhs_is_scalar ? rhs : reshaped};
    Value computed =
        rewriter.create<ChloOpTy>(loc, SmallVector<Type, 1>{reshaped.getType()},
                                  new_operands, op.getAttrs());

    // Reshape the result back into an unranked tensor.
    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
                                                        computed, shape);

    return success();
  }
};

// Handles lowering of the following pattern to patterns that will be further
// matched by other patterns until they result in LHLO:
//   %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
//
// The sequence of specializations this handles is:
//   - Either operand being scalar
//   - Operands having equal shapes
//   - The resulting value being any of ranks [2,6]
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertUnrankedDynamicBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    typename ChloOpTy::Adaptor transformed(operands);
    Value lhs = transformed.lhs();
    Value rhs = transformed.rhs();
    auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
    auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();

    // Only support unranked operands. If either operand is ranked, another
    // pattern will handle the lowering.
    if (!lhs_type || !rhs_type) return failure();

    // If lhs is scalar
    auto if_op = rewriter.create<scf::IfOp>(
        loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
    OpBuilder if_lhs_scalar_builder =
        if_op.getThenBodyBuilder(rewriter.getListener());
    Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
        loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
    Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
        op.getAttrs());
    if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result);

    // If lhs is NOT scalar
    //
    // See if rhs is scalar
    OpBuilder else_lhs_scalar_builder =
        if_op.getElseBodyBuilder(rewriter.getListener());
    auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
        loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
        true);
    else_lhs_scalar_builder.create<scf::YieldOp>(loc,
                                                 if_rhs_scalar_op.getResult(0));
    OpBuilder if_rhs_scalar_builder =
        if_rhs_scalar_op.getThenBodyBuilder(rewriter.getListener());
    Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
        loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
    Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
        op.getAttrs());
    if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result);

    // If NEITHER shape is scalar
    //
    // See if shapes are equal.
    OpBuilder else_no_scalars_builder =
        if_rhs_scalar_op.getElseBodyBuilder(rewriter.getListener());
    Value shape_of_lhs =
        else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
    Value shape_of_rhs =
        else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
    Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
        loc, shape_of_lhs, shape_of_rhs);

    auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
        loc, result_type, equal_shapes, true);
    else_no_scalars_builder.create<scf::YieldOp>(loc,
                                                 if_eq_shapes_op.getResult(0));

    OpBuilder if_eq_shapes_builder =
        if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener());
    Value non_broadcast_op =
        Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
    if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);

    // If shapes are not scalar, nor equal
    //
    // See if values are of a rank that we support.
    OpBuilder if_neq_shapes_builder =
        if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
    if_neq_shapes_builder.create<scf::YieldOp>(
        loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));

    rewriter.replaceOp(op, {if_op.getResult(0)});
    return success();
  }

 private:
  // Returns the dynamic result of checking whether the given value is a scalar
  // tensor.
  Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
    auto loc = op.getLoc();

    Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
    Value rank_tensor = rewriter.create<shape::RankOp>(
        loc, rewriter.getIndexType(), shape_of_tensor);
    return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
                                   rank_tensor,
                                   rewriter.create<ConstantIndexOp>(loc, 0));
  }

  // Create the if statement and code for a broadcasting op with a result of a
  // given rank.
  scf::IfOp createRankSpecializedBroadcastAndOp(OpBuilder &builder, ChloOpTy op,
                                                Value lhs, Value rhs,
                                                Value actual_rank,
                                                int targeted_rank) const {
    auto loc = op.getLoc();

    // Create the if block to place the current specialized logic in.
    Value greater_rank_is_n = builder.create<CmpIOp>(
        loc, CmpIPredicate::eq, actual_rank,
        builder.create<ConstantIndexOp>(loc, targeted_rank));
    auto if_op =
        builder.create<scf::IfOp>(loc, lhs.getType(), greater_rank_is_n, true);
    OpBuilder if_builder = if_op.getThenBodyBuilder(builder.getListener());

    // Handle shape broadcasting and inference.
    Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
    Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
        {RankedTensorType::kDynamicSize}, builder.getIndexType());
    auto known_rank_extent_tensor_type =
        RankedTensorType::get({targeted_rank}, builder.getIndexType());
    auto reshaped_type = RankedTensorType::get(
        llvm::SmallVector<int64_t, 6>(targeted_rank,
                                      RankedTensorType::kDynamicSize),
        lhs.getType().template dyn_cast<TensorType>().getElementType());
    Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
        loc, known_rank_extent_tensor_type,
        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
                                        ranked_shape));
    Value extended_lhs = if_builder.create<shape::BroadcastOp>(
        loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
        nullptr);
    Value extended_lhs_casted = if_builder.create<TensorCastOp>(
        loc, known_rank_extent_tensor_type, extended_lhs);
    Value extended_rhs = if_builder.create<shape::BroadcastOp>(
        loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
        nullptr);
    Value extended_rhs_casted = if_builder.create<TensorCastOp>(
        loc, known_rank_extent_tensor_type, extended_rhs);

    // 1. Reshape operands to the given rank (with the same number of elements)
    // 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
    //    can be broadcasted and do the actual broadcasting)
    // 3. Type erase the output back to unranked
    Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
        loc, reshaped_type, lhs, extended_lhs_casted);
    Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
        loc, reshaped_type, rhs, extended_rhs_casted);
    Value result = if_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{reshaped_type},
        ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
    Value reshaped_result = if_builder.create<TensorCastOp>(
        loc, UnrankedTensorType::get(reshaped_type.getElementType()), result);
    if_builder.create<scf::YieldOp>(loc, reshaped_result);

    // Return the if_op, so the result can be used and the else block can be
    // used for the next rank specialized step.
    return if_op;
  }

  // Iterates over the desired ranks to be specialized and generates the code
  // snippet for each case.
  Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
                             Value rhs) const {
    constexpr int max_rank_specialization = 7;
    auto loc = op.getLoc();

    // Find the larger rank of the 2 operands.
    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
                                                    rewriter.getIndexType());
    Value lhs_shape =
        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
    Value rhs_shape =
        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
    Value lhs_rank =
        rewriter.create<RankOp>(loc, rewriter.getIndexType(), lhs_shape);
    Value rhs_rank =
        rewriter.create<RankOp>(loc, rewriter.getIndexType(), rhs_shape);
    Value greater_rank_lhs =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
    Value greater_rank =
        rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);

    // Generate a list of nested if/else statements to handle rank
    // specializations from 2-6.
    scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs,
                                                          rhs, greater_rank, 2);

    // Put each subsequent rank specialization inside the else statement of the
    // previous one.
    OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
    for (int i = 3; i < max_rank_specialization; i++) {
      auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs,
                                                          rhs, greater_rank, i);

      else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
      else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
    }

    // Fire an assertion if none of the rank specializations applied (one of the
    // ranks was greater than 6).
    else_builder.create<AssertOp>(
        loc, else_builder.create<ConstantIntOp>(loc, 0, 1),
        "Input for dynamic binary op lowering was of a rank greater than 6");
    else_builder.create<scf::YieldOp>(loc, lhs);

    // Return the result of the outermost if statement.
    return if_op.getResult(0);
  }
};

struct TransformUnrankedHloPass
    : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
@@ -137,7 +424,7 @@ struct TransformUnrankedHloPass
    MLIRContext &ctx = getContext();
    ConversionTarget target(ctx);
    target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
                           shape::ShapeDialect>();
                           shape::ShapeDialect, scf::SCFDialect>();
    target.addLegalOp<FuncOp>();
#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
@@ -148,6 +435,12 @@ struct TransformUnrankedHloPass
#undef ADD_LEGAL_CHLO
    AddLegalOpOnRankedTensor<mhlo::CompareOp>(&target);
    AddLegalOpOnRankedTensor<mhlo::SelectOp>(&target);
    target.addDynamicallyLegalDialect<chlo::HloClientDialect>(
        [](Operation *op) {
          return !llvm::any_of(op->getOperandTypes(), [](Type type) {
            return type.isa<UnrankedTensorType>();
          });
        });

    // Populate rewrite patterns.
    OwningRewritePatternList patterns;
@@ -180,6 +473,10 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
#undef MAP_BINARY
#undef MAP_CHLO_UNARY
#undef COMMA
  chlo::PopulateForBroadcastingBinaryOp<
      ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);
  chlo::PopulateForBroadcastingBinaryOp<
      ConvertUnrankedScalarDynamicBroadcastBinaryOp>(context, patterns);
}

std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {

@ -237,209 +237,3 @@ func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x
  %0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
  return %0 : tensor<4xi1>
}

// -----
func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>)
                                         -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL:   func @addScalarUnranked(
 | 
			
		||||
// CHECK-SAME:                            %[[ARG_0:.*]]: tensor<f32>,
 | 
			
		||||
// CHECK-SAME:                            %[[ARG_1:.*]]: tensor<*xf32>
 | 
			
		||||
// CHECK-SAME:                            ) -> tensor<*xf32> {
 | 
			
		||||
//                  First handle the dynamic reshaping of the unranked operand
 | 
			
		||||
//                  to a 1D tensor.
 | 
			
		||||
// CHECK:           %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
 | 
			
		||||
// CHECK:           %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index
 | 
			
		||||
// CHECK:           %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
 | 
			
		||||
// CHECK:           %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 | 
			
		||||
//                  The assuming region is part of the second stage of lowering
 | 
			
		||||
//                  with ranked broadcasting logic.
 | 
			
		||||
// CHECK:           %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<f32>
 | 
			
		||||
// CHECK:           %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
 | 
			
		||||
// CHECK:           %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_0]], %[[SHAPE_RESHAPED]]
 | 
			
		||||
// CHECK:           %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
 | 
			
		||||
// CHECK:             %[[SCALAR_SHAPE:.*]] = shape.const_shape []
 | 
			
		||||
// CHECK:             %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]]
 | 
			
		||||
// CHECK:             %[[SHAPE_TENSOR:.*]] = tensor_cast %[[BROADCASTED_SHAPE]] : tensor<?xindex> to tensor<1xindex>
 | 
			
		||||
// CHECK:             %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
 | 
			
		||||
// CHECK:             %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
 | 
			
		||||
// CHECK:             %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
 | 
			
		||||
// CHECK:             shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32>
 | 
			
		||||
// CHECK:           }
 | 
			
		||||
//                  As part of the unranked logic, the result is reshaped back
 | 
			
		||||
//                  to an unranked tensor.
 | 
			
		||||
// CHECK:           %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 | 
			
		||||
// CHECK:           return %[[RESHAPED_RESULT]] : tensor<*xf32>
 | 
			
		||||
// CHECK:         }
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
 | 
			
		||||
  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>)
 | 
			
		||||
                                         -> tensor<*xf32>
 | 
			
		||||
  return %0 : tensor<*xf32>
 | 
			
		||||
}
 | 
			
		||||
// CHECK-LABEL:   func @addUnrankedScalar(
 | 
			
		||||
// CHECK-SAME:                            %[[ARG_0:.*]]: tensor<*xf32>,
 | 
			
		||||
// CHECK-SAME:                            %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> {
 | 
			
		||||
//                  First handle the dynamic reshaping of the unranked operand
 | 
			
		||||
//                  to a 1D tensor.
 | 
			
		||||
// CHECK:           %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
 | 
			
		||||
// CHECK:           %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index
 | 
			
		||||
// CHECK:           %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
 | 
			
		||||
// CHECK:           %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 | 
			
		||||
//                  The assuming region is part of the second stage of lowering
 | 
			
		||||
//                  with ranked broadcasting logic.
 | 
			
		||||
// CHECK:           %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
 | 
			
		||||
// CHECK:           %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<f32>
 | 
			
		||||
// CHECK:           %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]]
 | 
			
		||||
// CHECK:           %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
 | 
			
		||||
// CHECK:             %[[ASTENSOR:.*]] = tensor_cast %[[SHAPE_RESHAPED]]
 | 
			
		||||
// CHECK:             %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[ASTENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
 | 
			
		||||
// CHECK:             %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[ASTENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
 | 
			
		||||
// CHECK:             %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
 | 
			
		||||
// CHECK:             shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32>
 | 
			
		||||
// CHECK:           }
 | 
			
		||||
//                  As part of the unranked logic, the result is reshaped back
 | 
			
		||||
//                  to an unranked tensor.
 | 
			
		||||
// CHECK:           %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_0]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 | 
			
		||||
// CHECK:           return %[[RESHAPED_RESULT]] : tensor<*xf32>
 | 
			
		||||
// CHECK:         }
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
func @addUnrankedUnranked(
 | 
			
		||||
      %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
 | 
			
		||||
  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>)
 | 
			
		||||
                                         -> tensor<*xf32>
 | 
			
		||||
  return %0 : tensor<*xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL:   func @addUnrankedUnranked(
 | 
			
		||||
// CHECK-SAME:          %[[LHS:.*]]: tensor<*xf32>,
 | 
			
		||||
// CHECK-SAME:          %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
 | 
			
		||||
// CHECK:           %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
 | 
			
		||||
// CHECK:           %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
 | 
			
		||||
// CHECK:           %[[C0:.*]] = constant 0 : index
 | 
			
		||||
// CHECK:           %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index
 | 
			
		||||
//                  Handle scalar LHS case
 | 
			
		||||
// CHECK:           %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
 | 
			
		||||
// CHECK:             %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32>
 | 
			
		||||
// CHECK:             %[[VAL_10:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RHS]] : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
 | 
			
		||||
// CHECK:             scf.yield %[[VAL_10]] : tensor<*xf32>
 | 
			
		||||
// CHECK:           } else {
 | 
			
		||||
// CHECK:             %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
 | 
			
		||||
// CHECK:             %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
 | 
			
		||||
// CHECK:             %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index
 | 
			
		||||
  //                  Handle scalar RHS case
 | 
			
		||||
// CHECK:             %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
 | 
			
		||||
// CHECK:               %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32>
 | 
			
		||||
// CHECK:               %[[VAL_16:.*]] = chlo.broadcast_add %[[LHS]], %[[SCALAR_RHS]] : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
 | 
			
		||||
// CHECK:               scf.yield %[[VAL_16]] : tensor<*xf32>
 | 
			
		||||
// CHECK:             } else {
 | 
			
		||||
// CHECK:               %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
 | 
			
		||||
  //                    Handle scalar RHS case
 | 
			
		||||
// CHECK:               %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
 | 
			
		||||
// CHECK:                 %[[VAL_19:.*]] = mhlo.add %[[LHS]], %[[RHS]] : tensor<*xf32>
 | 
			
		||||
// CHECK:                 scf.yield %[[VAL_19]] : tensor<*xf32>
 | 
			
		||||
// CHECK:               } else {
 | 
			
		||||
// CHECK:                 %[[LHS_RANK:.*]] = rank %[[LHS_SHAPE]] : tensor<?xindex>
 | 
			
		||||
// CHECK:                 %[[RHS_RANK:.*]] = rank %[[RHS_SHAPE]] : tensor<?xindex>
 | 
			
		||||
// CHECK:                 %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index
 | 
			
		||||
// CHECK:                 %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
 | 
			
		||||
// CHECK:                 %[[C2:.*]] = constant 2 : index
 | 
			
		||||
// CHECK:                 %[[GREATEST_RANK_IS_2:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C2]] : index
 | 
			
		||||
//                        Handle rank 2 specialization
 | 
			
		||||
// CHECK:                 %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
 | 
			
		||||
// CHECK:                   %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
 | 
			
		||||
// CHECK:                   %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK:                   %[[CASTED_LHS_2:.*]] = tensor_cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
 | 
			
		||||
// CHECK:                   %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK:                   %[[CASTED_RHS_2:.*]] = tensor_cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
 | 
			
		||||
// CHECK:                   %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
 | 
			
		||||
// CHECK:                   %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
 | 
			
		||||
// CHECK:                   %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 | 
			
		||||
// CHECK:                   %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
 | 
			
		||||
// CHECK:                   scf.yield %[[RESULT_2]] : tensor<*xf32>
 | 
			
		||||
// CHECK:                 } else {
 | 
			
		||||
// CHECK:                   %[[C3:.*]] = constant 3 : index
 | 
			
		||||
// CHECK:                   %[[GREATEST_RANK_IS_3:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C3]] : index
 | 
			
		||||
//                          Handle rank 3 specialization
 | 
			
		||||
// CHECK:                   %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
 | 
			
		||||
// CHECK:                     %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
 | 
			
		||||
// CHECK:                     %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK:                     %[[CASTED_LHS_3:.*]] = tensor_cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
 | 
			
		||||
// CHECK:                     %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK:                     %[[CASTED_RHS_3:.*]] = tensor_cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
 | 
			
		||||
// CHECK:                     %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
 | 
			
		||||
// CHECK:                     %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
 | 
			
		||||
// CHECK:                     %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
 | 
			
		||||
// CHECK:                     %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
 | 
			
		||||
// CHECK:                     scf.yield %[[RESULT_3]] : tensor<*xf32>
 | 
			
		||||
// CHECK:                   } else {
 | 
			
		||||
// CHECK:                     %[[C4:.*]] = constant 4 : index
 | 
			
		||||
// CHECK:                     %[[GREATEST_RANK_IS_4:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C4]] : index
 | 
			
		||||
//                            Handle rank 4 specialization
 | 
			
		||||
// CHECK:                     %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
 | 
			
		||||
// CHECK:                       %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
 | 
			
		||||
// CHECK:                       %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK:                       %[[CASTED_LHS_4:.*]] = tensor_cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
 | 
			
		||||
// CHECK:                       %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK:                       %[[CASTED_RHS_4:.*]] = tensor_cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
 | 
			
		||||
// CHECK:                       %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
 | 
			
		||||
// CHECK:                       %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
 | 
			
		||||
// CHECK:                       %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
 | 
			
		||||
// CHECK:                       %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
 | 
			
		||||
// CHECK:                       scf.yield %[[RESULT_4]] : tensor<*xf32>
 | 
			
		||||
// CHECK:                     } else {
 | 
			
		||||
// CHECK:                       %[[C5:.*]] = constant 5 : index
 | 
			
		||||
// CHECK:                       %[[GREATEST_RANK_IS_5:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C5]] : index
 | 
			
		||||
//                              Handle rank 5 specialization
 | 
			
		||||
// CHECK:                       %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
 | 
			
		||||
// CHECK:                         %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
 | 
			
		||||
// CHECK:                         %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK:                         %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
 | 
			
		||||
// CHECK:                         %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK:                         %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
 | 
			
		||||
// CHECK:                         %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
 | 
			
		||||
// CHECK:                         %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
 | 
			
		||||
// CHECK:                         %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
 | 
			
		||||
// CHECK:                         %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
 | 
			
		||||
// CHECK:                         scf.yield %[[RESULT_5]] : tensor<*xf32>
 | 
			
		||||
// CHECK:                       } else {
 | 
			
		||||
// CHECK:                         %[[C6:.*]] = constant 6 : index
 | 
			
		||||
// CHECK:                         %[[GREATEST_RANK_IS_6:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index
 | 
			
		||||
//                                Handle rank 6 specialization
 | 
			
		||||
// CHECK:                         %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) {
 | 
			
		||||
// CHECK:                           %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
 | 
			
		||||
// CHECK:                           %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK:                           %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
 | 
			
		||||
// CHECK:                           %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK:                           %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
 | 
			
		||||
// CHECK:                           %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
 | 
			
		||||
// CHECK:                           %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
 | 
			
		||||
// CHECK:                           %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
 | 
			
		||||
// CHECK:                           %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
 | 
			
		||||
// CHECK:                           scf.yield %[[RESULT_6]] : tensor<*xf32>
 | 
			
		||||
// CHECK:                         } else {
 | 
			
		||||
// CHECK:                           %false = constant false
 | 
			
		||||
// CHECK:                           assert %false
 | 
			
		||||
// CHECK:                           scf.yield %[[LHS]] : tensor<*xf32>
 | 
			
		||||
// CHECK:                         }
 | 
			
		||||
// CHECK:                         scf.yield %[[VAL_64:.*]] : tensor<*xf32>
 | 
			
		||||
// CHECK:                       }
 | 
			
		||||
// CHECK:                       scf.yield %[[VAL_65:.*]] : tensor<*xf32>
 | 
			
		||||
// CHECK:                     }
 | 
			
		||||
// CHECK:                     scf.yield %[[VAL_66:.*]] : tensor<*xf32>
 | 
			
		||||
// CHECK:                   }
 | 
			
		||||
// CHECK:                   scf.yield %[[VAL_67:.*]] : tensor<*xf32>
 | 
			
		||||
// CHECK:                 }
 | 
			
		||||
// CHECK:                 scf.yield %[[VAL_68:.*]] : tensor<*xf32>
 | 
			
		||||
// CHECK:               }
 | 
			
		||||
// CHECK:               scf.yield %[[VAL_69:.*]] : tensor<*xf32>
 | 
			
		||||
// CHECK:             }
 | 
			
		||||
// CHECK:             scf.yield %[[VAL_70:.*]] : tensor<*xf32>
 | 
			
		||||
// CHECK:           }
 | 
			
		||||
// CHECK:           return %[[VAL_71:.*]] : tensor<*xf32>
 | 
			
		||||
// CHECK:         }
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt --transform-unranked-hlo --split-input-file %s | FileCheck %s
// RUN: mlir-hlo-opt --transform-unranked-hlo --cse --split-input-file %s | FileCheck %s

// Check the validity of expected IR.
// CHECK-LABEL: @sqr_transform_result
@ -96,3 +96,203 @@ func @tan(%a : tensor<*xf32>) -> tensor<*xf32> {
  %result = chlo.tan %a : tensor<*xf32>
  return %result : tensor<*xf32>
}

// -----

func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
 | 
			
		||||
  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>)
 | 
			
		||||
                                         -> tensor<*xf32>
 | 
			
		||||
  return %0 : tensor<*xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL:   func @addScalarUnranked(
 | 
			
		||||
// CHECK-SAME:                            %[[ARG_0:.*]]: tensor<f32>,
 | 
			
		||||
// CHECK-SAME:                            %[[ARG_1:.*]]: tensor<*xf32>
 | 
			
		||||
// CHECK-SAME:                            ) -> tensor<*xf32> {
 | 
			
		||||
//                  First handle the dynamic reshaping of the unranked operand
 | 
			
		||||
//                  to a 1D tensor.
 | 
			
		||||
// CHECK-NEXT:           %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:           %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index
 | 
			
		||||
// CHECK-NEXT:           %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
 | 
			
		||||
// CHECK-NEXT:           %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 | 
			
		||||
// CHECK-NEXT:           %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[ARG_0]], %[[RESHAPED]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
 | 
			
		||||
//                  As part of the unranked logic, the result is reshaped back
 | 
			
		||||
//                  to an unranked tensor.
 | 
			
		||||
// CHECK-NEXT:           %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[BROADCASTED_RESULT:.*]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:           return %[[RESHAPED_RESULT]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:         }
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
 | 
			
		||||
  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>)
 | 
			
		||||
                                         -> tensor<*xf32>
 | 
			
		||||
  return %0 : tensor<*xf32>
 | 
			
		||||
}
 | 
			
		||||
// CHECK-LABEL:   func @addUnrankedScalar(
 | 
			
		||||
// CHECK-SAME:                            %[[ARG_0:.*]]: tensor<*xf32>,
 | 
			
		||||
// CHECK-SAME:                            %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> {
 | 
			
		||||
//                  First handle the dynamic reshaping of the unranked operand
 | 
			
		||||
//                  to a 1D tensor.
 | 
			
		||||
// CHECK-NEXT:           %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:           %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index
 | 
			
		||||
// CHECK-NEXT:           %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
 | 
			
		||||
// CHECK-NEXT:           %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 | 
			
		||||
//                  The assuming region is part of the second stage of lowering
 | 
			
		||||
//                  with ranked broadcasting logic.
 | 
			
		||||
// CHECK-NEXT:           %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED]], %[[ARG_1]] : (tensor<?xf32>, tensor<f32>)  -> tensor<?xf32>
 | 
			
		||||
//                  As part of the unranked logic, the result is reshaped back
 | 
			
		||||
//                  to an unranked tensor.
 | 
			
		||||
// CHECK-NEXT:           %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[BROADCASTED_RESULT:.*]], %[[SHAPE_0]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:           return %[[RESHAPED_RESULT]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:         }
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
func @addUnrankedUnranked(
 | 
			
		||||
      %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
 | 
			
		||||
  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>)
 | 
			
		||||
                                         -> tensor<*xf32>
 | 
			
		||||
  return %0 : tensor<*xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL:   func @addUnrankedUnranked(
 | 
			
		||||
// CHECK-SAME:          %[[LHS:.*]]: tensor<*xf32>,
 | 
			
		||||
// CHECK-SAME:          %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
 | 
			
		||||
// CHECK-NEXT:           %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
 | 
			
		||||
// CHECK-NEXT:           %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
 | 
			
		||||
// CHECK-NEXT:           %[[C0:.*]] = constant 0 : index
 | 
			
		||||
// CHECK-NEXT:           %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index
 | 
			
		||||
//                       Handle scalar LHS case
 | 
			
		||||
// CHECK-NEXT:           %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
 | 
			
		||||
// CHECK-NEXT:             %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32>
 | 
			
		||||
// CHECK-NEXT:             %[[RHS_SHAPE_1:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
 | 
			
		||||
// CHECK-NEXT:             %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE_1]] : tensor<?xindex> -> index
 | 
			
		||||
// CHECK-NEXT:             %[[NUM_TENS_RHS:.*]] = tensor_from_elements %[[NUM_RHS]] : tensor<1xindex>
 | 
			
		||||
// CHECK-NEXT:             %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 | 
			
		||||
// CHECK-NEXT:             %[[LHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RESHAPED_RHS]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
 | 
			
		||||
// CHECK-NEXT:             %[[RESHAPED_LHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[LHS_SCALAR_RESULT]], %[[RHS_SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:             scf.yield %[[RESHAPED_LHS_SCALAR_RESULT]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:           } else {
 | 
			
		||||
// CHECK-NEXT:             %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
 | 
			
		||||
// CHECK-NEXT:             %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
 | 
			
		||||
// CHECK-NEXT:             %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index
 | 
			
		||||
//                         Handle scalar RHS case
 | 
			
		||||
// CHECK-NEXT:             %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
 | 
			
		||||
// CHECK-NEXT:               %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32>
 | 
			
		||||
// CHECK-NEXT:               %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
 | 
			
		||||
// CHECK-NEXT:               %[[NUM_TENS_LHS:.*]] = tensor_from_elements %[[NUM_LHS]] : tensor<1xindex>
 | 
			
		||||
// CHECK-NEXT:               %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 | 
			
		||||
// CHECK-NEXT:               %[[RHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[SCALAR_RHS]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
 | 
			
		||||
// CHECK-NEXT:               %[[RESHAPED_RHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RHS_SCALAR_RESULT:.*]], %[[LHS_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:               scf.yield %[[RESHAPED_RHS_SCALAR_RESULT]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:             } else {
 | 
			
		||||
// CHECK-NEXT:               %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
 | 
			
		||||
//                           Handle equal shapes case
 | 
			
		||||
// CHECK-NEXT:               %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
 | 
			
		||||
// CHECK-NEXT:                 %[[ANY_SHAPE:.*]] = shape.any %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK-NEXT:                 %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor<?xindex> -> index
 | 
			
		||||
// CHECK-NEXT:                 %[[ANY_TENSOR:.*]] = tensor_from_elements %[[ANY_NUM]] : tensor<1xindex>
 | 
			
		||||
// CHECK-NEXT:                 %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 | 
			
		||||
// CHECK-NEXT:                 %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 | 
			
		||||
// CHECK-NEXT:                 %[[FLATTENED_RESULT:.*]] = mhlo.add %[[FLATTENED_LHS]], %[[FLATTENED_RHS]] : tensor<?xf32>
 | 
			
		||||
// CHECK-NEXT:                 %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:                 scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:               } else {
 | 
			
		||||
// CHECK-NEXT:                 %[[LHS_RANK:.*]] = rank %[[LHS_SHAPE]] : tensor<?xindex>
 | 
			
		||||
// CHECK-NEXT:                 %[[RHS_RANK:.*]] = rank %[[RHS_SHAPE]] : tensor<?xindex>
 | 
			
		||||
// CHECK-NEXT:                 %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index
 | 
			
		||||
// CHECK-NEXT:                 %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
 | 
			
		||||
// CHECK-NEXT:                 %[[C2:.*]] = constant 2 : index
 | 
			
		||||
// CHECK-NEXT:                 %[[GREATEST_RANK_IS_2:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C2]] : index
 | 
			
		||||
//                             Handle rank 2 specialization
 | 
			
		||||
// CHECK-NEXT:                 %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
 | 
			
		||||
// CHECK-NEXT:                   %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
 | 
			
		||||
// CHECK-NEXT:                   %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK-NEXT:                   %[[CASTED_LHS_2:.*]] = tensor_cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
 | 
			
		||||
// CHECK-NEXT:                   %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK-NEXT:                   %[[CASTED_RHS_2:.*]] = tensor_cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
 | 
			
		||||
// CHECK-NEXT:                   %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
 | 
			
		||||
// CHECK-NEXT:                   %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
 | 
			
		||||
// CHECK-NEXT:                   %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 | 
			
		||||
// CHECK-NEXT:                   %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:                   scf.yield %[[RESULT_2]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:                 } else {
 | 
			
		||||
// CHECK-NEXT:                   %[[C3:.*]] = constant 3 : index
 | 
			
		||||
// CHECK-NEXT:                   %[[GREATEST_RANK_IS_3:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C3]] : index
 | 
			
		||||
//                               Handle rank 3 specialization
 | 
			
		||||
// CHECK-NEXT:                   %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
 | 
			
		||||
// CHECK-NEXT:                     %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
 | 
			
		||||
// CHECK-NEXT:                     %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK-NEXT:                     %[[CASTED_LHS_3:.*]] = tensor_cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
 | 
			
		||||
// CHECK-NEXT:                     %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK-NEXT:                     %[[CASTED_RHS_3:.*]] = tensor_cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
 | 
			
		||||
// CHECK-NEXT:                     %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
 | 
			
		||||
// CHECK-NEXT:                     %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
 | 
			
		||||
// CHECK-NEXT:                     %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
 | 
			
		||||
// CHECK-NEXT:                     %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:                     scf.yield %[[RESULT_3]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:                   } else {
 | 
			
		||||
// CHECK-NEXT:                     %[[C4:.*]] = constant 4 : index
 | 
			
		||||
// CHECK-NEXT:                     %[[GREATEST_RANK_IS_4:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C4]] : index
 | 
			
		||||
//                                 Handle rank 4 specialization
 | 
			
		||||
// CHECK-NEXT:                     %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
 | 
			
		||||
// CHECK-NEXT:                       %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
 | 
			
		||||
// CHECK-NEXT:                       %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK-NEXT:                       %[[CASTED_LHS_4:.*]] = tensor_cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
 | 
			
		||||
// CHECK-NEXT:                       %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK-NEXT:                       %[[CASTED_RHS_4:.*]] = tensor_cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
 | 
			
		||||
// CHECK-NEXT:                       %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
 | 
			
		||||
// CHECK-NEXT:                       %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
 | 
			
		||||
// CHECK-NEXT:                       %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
 | 
			
		||||
// CHECK-NEXT:                       %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:                       scf.yield %[[RESULT_4]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:                     } else {
 | 
			
		||||
// CHECK-NEXT:                       %[[C5:.*]] = constant 5 : index
 | 
			
		||||
// CHECK-NEXT:                       %[[GREATEST_RANK_IS_5:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C5]] : index
 | 
			
		||||
//                                   Handle rank 5 specialization
 | 
			
		||||
// CHECK-NEXT:                       %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
 | 
			
		||||
// CHECK-NEXT:                         %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
 | 
			
		||||
// CHECK-NEXT:                         %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK-NEXT:                         %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
 | 
			
		||||
// CHECK-NEXT:                         %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK-NEXT:                         %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
 | 
			
		||||
// CHECK-NEXT:                         %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
 | 
			
		||||
// CHECK-NEXT:                         %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
 | 
			
		||||
// CHECK-NEXT:                         %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
 | 
			
		||||
// CHECK-NEXT:                         %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:                         scf.yield %[[RESULT_5]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:                       } else {
 | 
			
		||||
// CHECK-NEXT:                         %[[C6:.*]] = constant 6 : index
 | 
			
		||||
// CHECK-NEXT:                         %[[GREATEST_RANK_IS_6:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index
 | 
			
		||||
//                                     Handle rank 6 specialization
 | 
			
		||||
// CHECK-NEXT:                         %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) {
 | 
			
		||||
// CHECK-NEXT:                           %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
 | 
			
		||||
// CHECK-NEXT:                           %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK-NEXT:                           %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
 | 
			
		||||
// CHECK-NEXT:                           %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
 | 
			
		||||
// CHECK-NEXT:                           %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
 | 
			
		||||
// CHECK-NEXT:                           %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
 | 
			
		||||
// CHECK-NEXT:                           %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
 | 
			
		||||
// CHECK-NEXT:                           %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
 | 
			
		||||
// CHECK-NEXT:                           %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:                           scf.yield %[[RESULT_6]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:                         } else {
 | 
			
		||||
// CHECK-NEXT:                           %false = constant false
 | 
			
		||||
// CHECK-NEXT:                           assert %false
 | 
			
		||||
// CHECK-NEXT:                           scf.yield %[[LHS]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:                         }
 | 
			
		||||
// CHECK-NEXT:                         scf.yield %[[VAL_64:.*]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:                       }
 | 
			
		||||
// CHECK-NEXT:                       scf.yield %[[VAL_65:.*]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:                     }
 | 
			
		||||
// CHECK-NEXT:                     scf.yield %[[VAL_66:.*]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:                   }
 | 
			
		||||
// CHECK-NEXT:                   scf.yield %[[VAL_67:.*]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:                 }
 | 
			
		||||
// CHECK-NEXT:                 scf.yield %[[VAL_68:.*]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:               }
 | 
			
		||||
// CHECK-NEXT:               scf.yield %[[VAL_69:.*]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:             }
 | 
			
		||||
// CHECK-NEXT:             scf.yield %[[VAL_70:.*]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:           }
 | 
			
		||||
// CHECK-NEXT:           return %[[VAL_71:.*]] : tensor<*xf32>
 | 
			
		||||
// CHECK-NEXT:         }
 | 
			
		||||
 | 
			
		||||
@ -306,6 +306,14 @@ inline bool IsF32ShapedType(Type t) {
  return false;
}

// Returns true if it is a shaped type of bf16 elements.
inline bool IsBF16ShapedType(Type t) {
  if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
    return shaped_type.getElementType().isBF16();
  }
  return false;
}

// Performs const folding `calculate` with broadcast behavior on the two
// attributes `operand1` and `operand2` and returns the result if possible.
// The two operands are expected to both be scalar values.
@ -498,7 +506,7 @@ Attribute ConstFoldBinaryOp(
/// "tfl.logical_not".
Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
                           llvm::function_ref<APFloat(APFloat)> calculate) {
  assert(IsF32ShapedType(result_type));
  assert(IsF32ShapedType(result_type) || IsBF16ShapedType(result_type));
  auto result_shape_type = result_type.cast<ShapedType>();

  if (auto dense_elements = operand.dyn_cast_or_null<DenseElementsAttr>()) {
@ -1911,13 +1919,20 @@ OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {

OpFoldResult RsqrtOp::fold(ArrayRef<Attribute> operands) {
  Type result_type = getType();
  // Only constant fold for tensor of f32 is implemented.
  if (!IsF32ShapedType(result_type)) return nullptr;
  // Only constant fold for tensor of f32/bf16 is implemented.
  if (!IsF32ShapedType(result_type) && !IsBF16ShapedType(result_type))
    return nullptr;

  auto compute = [](APFloat value) -> APFloat {
    bool loseInfo;
    const llvm::fltSemantics &original_float_semantics = value.getSemantics();
    value.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
                  &loseInfo);
    float f = value.convertToFloat();
    float result = 1.f / std::sqrt(f);
    return APFloat(result);
    APFloat result(1.f / std::sqrt(f));
    result.convert(original_float_semantics, APFloat::rmNearestTiesToEven,
                   &loseInfo);
    return result;
  };
  return ConstFoldUnaryOp(result_type, operands[0], compute);
}

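The fold above now computes 1/sqrt in IEEE single precision and then rounds the result back into the tensor's original semantics, which is why the bf16 test further down folds dense<4.0> to dense<5.000000e-01>. A minimal standalone sketch of that round-trip, using only the llvm::APFloat calls already visible in this hunk (the helper name is hypothetical):

```cpp
#include <cmath>
#include "llvm/ADT/APFloat.h"

// Hypothetical helper: approximate 1/sqrt(x) in f32, then round the result
// back into whatever float semantics (e.g. bf16) the input value used.
static llvm::APFloat RsqrtInOriginalSemantics(llvm::APFloat value) {
  bool lose_info;
  const llvm::fltSemantics &original_semantics = value.getSemantics();
  // Widen (or narrow) to IEEE single so the hardware sqrt can be used.
  value.convert(llvm::APFloat::IEEEsingle(),
                llvm::APFloat::rmNearestTiesToEven, &lose_info);
  llvm::APFloat result(1.f / std::sqrt(value.convertToFloat()));
  // Round back to the original semantics so the folded constant matches the
  // element type of the result tensor.
  result.convert(original_semantics, llvm::APFloat::rmNearestTiesToEven,
                 &lose_info);
  return result;
}
```
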
@ -577,3 +577,13 @@ func @div_dense_different_rank() -> tensor<1x2x2xf32> {
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
// CHECK:  return %[[CST]]
}

// CHECK-LABEL: @rsqrt_bf16
func @rsqrt_bf16() -> tensor<bf16> {
  %cst = constant dense<4.0> : tensor<bf16>
  %0 = "tfl.rsqrt"(%cst) : (tensor<bf16>) -> tensor<bf16>
  return %0 : tensor<bf16>

// CHECK: %[[CST:.*]] = constant dense<5.000000e-01> : tensor<bf16>
// CHECK:  return %[[CST]]
}

@ -1358,3 +1358,27 @@ func @fuseScalarAddIntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3
 | 
			
		||||
  // CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf32>
 | 
			
		||||
  // CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: fuseScalarAddIntoConv2dBf16
 | 
			
		||||
func @fuseScalarAddIntoConv2dBf16(%arg0: tensor<256x32x32x3xbf16>, %arg1: tensor<16x3x3x3xbf16>) -> tensor<256x30x30x16xbf16> {
 | 
			
		||||
  %cst = constant dense<1.5> : tensor<bf16>
 | 
			
		||||
  %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xbf16>
 | 
			
		||||
  %0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xbf16>, tensor<16x3x3x3xbf16>, tensor<16xbf16>) -> tensor<256x30x30x16xbf16>
 | 
			
		||||
  %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xbf16>, tensor<bf16>) -> tensor<256x30x30x16xbf16>
 | 
			
		||||
  return %1 : tensor<256x30x30x16xbf16>
 | 
			
		||||
 | 
			
		||||
  // CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xbf16>
 | 
			
		||||
  // CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: fuseScalarAddIntoConv2dHalf
 | 
			
		||||
func @fuseScalarAddIntoConv2dHalf(%arg0: tensor<256x32x32x3xf16>, %arg1: tensor<16x3x3x3xf16>) -> tensor<256x30x30x16xf16> {
 | 
			
		||||
  %cst = constant dense<1.5> : tensor<f16>
 | 
			
		||||
  %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf16>
 | 
			
		||||
  %0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf16>, tensor<16x3x3x3xf16>, tensor<16xf16>) -> tensor<256x30x30x16xf16>
 | 
			
		||||
  %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf16>, tensor<f16>) -> tensor<256x30x30x16xf16>
 | 
			
		||||
  return %1 : tensor<256x30x30x16xf16>
 | 
			
		||||
 | 
			
		||||
  // CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf16>
 | 
			
		||||
  // CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -211,20 +211,21 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
    pass_manager->addPass(mlir::createSymbolDCEPass());
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
    // This pass should be always at the end of the floating point model
    // conversion. Some TFL ops like unidirectional
    // sequence lstm will have stateful operands and some optimization passes
    // will merge those operands if they have identical values & types. However,
    // it's not desired by TFL. This pass serves as a "fix" pass to split the
    // merged inputs until we have 1st class variable support or reuse
    // tf.variable to model this.
    pass_manager->addPass(mlir::TFL::CreateSplitMergedOperandsPass());

    // Run quantization after all the floating point model conversion is
    // completed.
    if (pass_config.quant_specs.RunPropagationAndRewriteQuantizationPasses()) {
      AddQuantizationPasses(pass_config.quant_specs, pass_manager);
    }

    // This pass should be always at the end of the model
    // conversion (even after quantization). Some TFL ops like unidirectional
    // sequence lstm will have stateful operands and some optimization passes
    // will merge those operands if they have identical values & types. However,
    // it's not desired by TFL. This pass serves as a "fix" pass to split the
    // merged inputs until we have 1st class variable support or reuse
    // tf.variable to model this.
    pass_manager->addPass(mlir::TFL::CreateSplitMergedOperandsPass());
  }
}

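Net effect of this hunk: CreateSplitMergedOperandsPass now runs after quantization instead of before it. A condensed sketch of the resulting tail of the pipeline, assuming pass_manager is an mlir::OpPassManager* as elsewhere in this file; only calls that appear in the hunk are used, and the wrapper function name is illustrative, not part of the change:

```cpp
// Sketch only: assumes the TFL headers declaring AddQuantizationPasses and
// CreateSplitMergedOperandsPass are available in this translation unit.
void AddTailOfTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
                                  mlir::OpPassManager* pass_manager) {
  pass_manager->addPass(mlir::createSymbolDCEPass());
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());

  // Quantization runs first, over the finished floating-point model ...
  if (pass_config.quant_specs.RunPropagationAndRewriteQuantizationPasses()) {
    AddQuantizationPasses(pass_config.quant_specs, pass_manager);
  }

  // ... and splitting merged stateful operands is now the very last step, so
  // neither the float-model optimizations nor quantization can re-merge them.
  pass_manager->addPass(mlir::TFL::CreateSplitMergedOperandsPass());
}
```
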
@ -24,6 +24,11 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
// Checks if the param passed is a F32 ElementsAttr.
def F32ElementsAttr : ElementsAttrBase<
  CPred<"$_self.isa<ElementsAttr>() && $_self.cast<ElementsAttr>().getType().getElementType().isF32()">,
        "32 bit float constant tensor">;

// Checks if the param passed is a float ElementsAttr.
def FloatElementsAttr : ElementsAttrBase<
  CPred<"$_self.isa<ElementsAttr>() && $_self.cast<ElementsAttr>().getType().getElementType().isa<FloatType>()">,
        "float constant tensor">;

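FloatElementsAttr relaxes the old F32ElementsAttr constraint from "element type is exactly f32" to "element type is any FloatType", which is what lets the bf16 and f16 fusion tests above match. A rough C++ reading of the two CPred strings (helper names are mine; the include paths are assumptions for the MLIR revision in this tree):

```cpp
#include "mlir/IR/Attributes.h"      // ElementsAttr (assumed path)
#include "mlir/IR/StandardTypes.h"   // FloatType (assumed path for this era)

// Old constraint: only 32-bit float constant tensors.
static bool MatchesF32ElementsAttr(mlir::Attribute attr) {
  auto elements = attr.dyn_cast_or_null<mlir::ElementsAttr>();
  return elements && elements.getType().getElementType().isF32();
}

// New constraint: any float constant tensor (f16, bf16, f32, f64, ...).
static bool MatchesFloatElementsAttr(mlir::Attribute attr) {
  auto elements = attr.dyn_cast_or_null<mlir::ElementsAttr>();
  return elements &&
         elements.getType().getElementType().isa<mlir::FloatType>();
}
```
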
// Checks if the param passed is of NoneType.
@ -93,9 +98,9 @@ class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint<
 | 
			
		||||
multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
 | 
			
		||||
  def FuseBinaryOpWithConv#binaryOp : Pat<
 | 
			
		||||
    (binaryOp (TFL_Conv2DOp:$output $input, $filter,
 | 
			
		||||
                (ConstantOp F32ElementsAttr:$bias), $h_factor, $w_factor,
 | 
			
		||||
                (ConstantOp FloatElementsAttr:$bias), $h_factor, $w_factor,
 | 
			
		||||
                TFL_AF_None, $padding, $stride_h, $stride_w),
 | 
			
		||||
              (ConstantOp F32ElementsAttr:$value), $act_fn),
 | 
			
		||||
              (ConstantOp FloatElementsAttr:$value), $act_fn),
 | 
			
		||||
    (TFL_Conv2DOp $input, $filter,
 | 
			
		||||
      (binaryOp (ConstantOp $bias),
 | 
			
		||||
         (ConstantOp $value), TFL_AF_None),
 | 
			
		||||
@ -104,10 +109,10 @@ multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
 | 
			
		||||
     (HasOneUse $output)]>;
 | 
			
		||||
  def FuseBinaryOpWithDepthwiseConv#binaryOp : Pat<
 | 
			
		||||
    (binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter,
 | 
			
		||||
                (ConstantOp F32ElementsAttr:$bias),
 | 
			
		||||
                (ConstantOp FloatElementsAttr:$bias),
 | 
			
		||||
                $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
 | 
			
		||||
                $stride_w, $multiplier),
 | 
			
		||||
              (ConstantOp F32ElementsAttr:$value), $act_fn),
 | 
			
		||||
              (ConstantOp FloatElementsAttr:$value), $act_fn),
 | 
			
		||||
    (TFL_DepthwiseConv2DOp $input, $filter,
 | 
			
		||||
      (binaryOp (ConstantOp $bias), (ConstantOp $value), TFL_AF_None),
 | 
			
		||||
      $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w,
 | 
			
		||||
@ -116,9 +121,9 @@ multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
 | 
			
		||||
     (HasOneUse $output)]>;
 | 
			
		||||
   def FuseBinaryOpWithTransposeConv#binaryOp : Pat<
 | 
			
		||||
    (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs,
 | 
			
		||||
                (ConstantOp F32ElementsAttr:$bias), $padding,
 | 
			
		||||
                (ConstantOp FloatElementsAttr:$bias), $padding,
 | 
			
		||||
                $stride_h, $stride_w),
 | 
			
		||||
              (ConstantOp F32ElementsAttr:$value), TFL_AF_None),
 | 
			
		||||
              (ConstantOp FloatElementsAttr:$value), TFL_AF_None),
 | 
			
		||||
    (TFL_TransposeConvOp $output_shape, $weights, $inputs,
 | 
			
		||||
      (binaryOp (ConstantOp $bias),
 | 
			
		||||
         (ConstantOp $value), TFL_AF_None),
 | 
			
		||||
@ -130,7 +135,7 @@ multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
 | 
			
		||||
    (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs,
 | 
			
		||||
                (ConstantOp $bias), $padding,
 | 
			
		||||
                $stride_h, $stride_w),
 | 
			
		||||
              (ConstantOp F32ElementsAttr:$value), TFL_AF_None),
 | 
			
		||||
              (ConstantOp FloatElementsAttr:$value), TFL_AF_None),
 | 
			
		||||
    (TFL_TransposeConvOp $output_shape, $weights, $inputs,
 | 
			
		||||
      (ConstantOp $value),
 | 
			
		||||
      $padding, $stride_h, $stride_w),
 | 
			
		||||
@ -155,11 +160,11 @@ def ExpandTo4DForDepthwiseConv: NativeCodeCall<
 | 
			
		||||
multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
 | 
			
		||||
  def FuseMulOrDivWithDepthwiseConv#BinaryOp : Pat<
 | 
			
		||||
    (BinaryOp (TFL_DepthwiseConv2DOp:$output $input,
 | 
			
		||||
                (ConstantOp F32ElementsAttr:$filter),
 | 
			
		||||
                (ConstantOp F32ElementsAttr:$bias),
 | 
			
		||||
                (ConstantOp FloatElementsAttr:$filter),
 | 
			
		||||
                (ConstantOp FloatElementsAttr:$bias),
 | 
			
		||||
                $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
 | 
			
		||||
                $stride_w, $multiplier),
 | 
			
		||||
              (ConstantOp F32ElementsAttr:$value), $act_fn),
 | 
			
		||||
              (ConstantOp FloatElementsAttr:$value), $act_fn),
 | 
			
		||||
    (TFL_DepthwiseConv2DOp $input,
 | 
			
		||||
      (BinaryOp
 | 
			
		||||
        (ConstantOp $filter),
 | 
			
		||||
@ -175,11 +180,11 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
 | 
			
		||||
     (HasOneUse $output)]>;
 | 
			
		||||
  def FuseMulOrDivWithConv#BinaryOp : Pat<
 | 
			
		||||
    (BinaryOp (TFL_Conv2DOp:$conv_output $input,
 | 
			
		||||
                (ConstantOp F32ElementsAttr:$filter),
 | 
			
		||||
                (ConstantOp F32ElementsAttr:$bias),
 | 
			
		||||
                (ConstantOp FloatElementsAttr:$filter),
 | 
			
		||||
                (ConstantOp FloatElementsAttr:$bias),
 | 
			
		||||
                $h_factor, $w_factor, TFL_AF_None,
 | 
			
		||||
                $padding, $stride_h, $stride_w),
 | 
			
		||||
              (ConstantOp F32ElementsAttr:$value), $act_fn),
 | 
			
		||||
              (ConstantOp FloatElementsAttr:$value), $act_fn),
 | 
			
		||||
    (TFL_Conv2DOp $input,
 | 
			
		||||
      (BinaryOp (ConstantOp $filter),
 | 
			
		||||
        (ConstantOp (ExpandTo4DForConv $value)),
 | 
			
		||||
@ -192,8 +197,8 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
 | 
			
		||||
     (HasOneUse $conv_output)]>;
 | 
			
		||||
  def FuseMulOrDivWithTransposeConv#BinaryOp : Pat<
 | 
			
		||||
    (BinaryOp (TFL_TransposeConvOp:$output $output_shape,
 | 
			
		||||
                (ConstantOp F32ElementsAttr:$weights), $input,
 | 
			
		||||
                (ConstantOp F32ElementsAttr:$bias),
 | 
			
		||||
                (ConstantOp FloatElementsAttr:$weights), $input,
 | 
			
		||||
                (ConstantOp FloatElementsAttr:$bias),
 | 
			
		||||
                $padding, $stride_h, $stride_w),
 | 
			
		||||
              (ConstantOp $value), TFL_AF_None),
 | 
			
		||||
    (TFL_TransposeConvOp $output_shape,
 | 
			
		||||
@ -209,7 +214,7 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
 | 
			
		||||
     (HasOneUse $output)]>;
 | 
			
		||||
  def FuseMulOrDivWithTransposeConvWithNoneBias#BinaryOp : Pat<
 | 
			
		||||
    (BinaryOp (TFL_TransposeConvOp:$output $output_shape,
 | 
			
		||||
                (ConstantOp F32ElementsAttr:$weights), $input,
 | 
			
		||||
                (ConstantOp FloatElementsAttr:$weights), $input,
 | 
			
		||||
                (ConstantOp $bias),
 | 
			
		||||
                $padding, $stride_h, $stride_w),
 | 
			
		||||
              (ConstantOp $value), TFL_AF_None),
 | 
			
		||||
 | 
			
		||||
@ -1651,10 +1651,12 @@ Mutually reduces multiple tensors of identical type and shape.
 | 
			
		||||
    TF_Int32Tensor:$group_size,
 | 
			
		||||
    TF_Int32Tensor:$group_key,
 | 
			
		||||
    TF_Int32Tensor:$instance_key,
 | 
			
		||||
    Variadic<TF_ResourceTensor>:$ordering_token,
 | 
			
		||||
 | 
			
		||||
    TF_AnyStrAttrOf<["Min", "Max", "Mul", "Add"]>:$merge_op,
 | 
			
		||||
    TF_AnyStrAttrOf<["Id", "Div"]>:$final_op,
 | 
			
		||||
    DefaultValuedAttr<StrAttr, "auto">:$communication_hint
 | 
			
		||||
    DefaultValuedAttr<StrAttr, "auto">:$communication_hint,
 | 
			
		||||
    DefaultValuedAttr<F32Attr, "0.0f">:$timeout_seconds
 | 
			
		||||
  );
 | 
			
		||||
 | 
			
		||||
  let results = (outs
 | 
			
		||||
@ -1662,6 +1664,7 @@ Mutually reduces multiple tensors of identical type and shape.
 | 
			
		||||
  );
 | 
			
		||||
 | 
			
		||||
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 | 
			
		||||
  TF_DerivedOperandSizeAttr Nordering_token = TF_DerivedOperandSizeAttr<4>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
def TF_ComplexOp : TF_Op<"Complex", [NoSideEffect, ResultsBroadcastableShape]> {
 | 
			
		||||
 | 
			
		||||
@ -77,66 +77,6 @@ namespace TF {
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
// Returns true if the given function has a single use (within the scope
// of the module containing it and all parent modules).
 | 
			
		||||
bool HasSingleUse(FuncOp func) {
 | 
			
		||||
  // Public function can have any number of external uses.
 | 
			
		||||
  if (func.isPublic()) return false;
 | 
			
		||||
 | 
			
		||||
  // Return false if unexpected IR structure seen.
 | 
			
		||||
  ModuleOp module = func.getParentOfType<ModuleOp>();
 | 
			
		||||
  if (!module) return false;
 | 
			
		||||
 | 
			
		||||
  // Inspect function uses in the containing module and all parent
 | 
			
		||||
  // modules.
 | 
			
		||||
  bool use_seen = false;
 | 
			
		||||
  for (; module; module = func.isPrivate()
 | 
			
		||||
                              ? nullptr
 | 
			
		||||
                              : module.getParentOfType<ModuleOp>()) {
 | 
			
		||||
    auto func_uses_optional =
 | 
			
		||||
        SymbolTable::getSymbolUses(func, &module.getBodyRegion());
 | 
			
		||||
    // Found an unknown use.
 | 
			
		||||
    if (!func_uses_optional) return false;
 | 
			
		||||
 | 
			
		||||
    // If no uses in this scope, continue looking in parent module
 | 
			
		||||
    SymbolTable::UseRange func_uses = func_uses_optional.getValue();
 | 
			
		||||
    if (func_uses.empty()) continue;
 | 
			
		||||
 | 
			
		||||
    // Check if multiple uses at this scope or another use already seen.
 | 
			
		||||
    if (!llvm::hasSingleElement(func_uses) || use_seen) return false;
 | 
			
		||||
 | 
			
		||||
    // This is the first use seen.
 | 
			
		||||
    use_seen = true;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // No multiple uses seen.
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Returns true if the caller ops can be inlined.
 | 
			
		||||
bool HasInlinableUsers(FuncOp func) {
 | 
			
		||||
  // Return false if unexpected IR structure seen.
 | 
			
		||||
  ModuleOp module = func.getParentOfType<ModuleOp>();
 | 
			
		||||
  if (!module) return false;
 | 
			
		||||
 | 
			
		||||
  // Inspect function uses in the containing module and all parent
 | 
			
		||||
  // modules.
 | 
			
		||||
  for (; module; module = func.isPrivate()
 | 
			
		||||
                              ? nullptr
 | 
			
		||||
                              : module.getParentOfType<ModuleOp>()) {
 | 
			
		||||
    auto func_uses_optional =
 | 
			
		||||
        SymbolTable::getSymbolUses(func, &module.getBodyRegion());
 | 
			
		||||
    // Found an unknown use.
 | 
			
		||||
    if (!func_uses_optional) return false;
 | 
			
		||||
 | 
			
		||||
    for (auto &use : func_uses_optional.getValue())
 | 
			
		||||
      if (isa<TPUPartitionedCallOp>(use.getUser())) return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // All caller ops that can be inlined.
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct TFConstantFoldInterface : public DialectFoldInterface {
 | 
			
		||||
  TFConstantFoldInterface(Dialect *dialect) : DialectFoldInterface(dialect) {}
 | 
			
		||||
  LogicalResult fold(Operation *op, ArrayRef<Attribute> operands,
 | 
			
		||||
@ -160,10 +100,12 @@ struct TFInlinerInterface : public DialectInlinerInterface {
 | 
			
		||||
  // Analysis Hooks
 | 
			
		||||
  //===--------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
  // Allow all call operations to be inlined.
 | 
			
		||||
  // Returns if it's legal to inline 'callable' into the 'call', where 'call' is
 | 
			
		||||
  // a TF operation.
 | 
			
		||||
  bool isLegalToInline(Operation *call, Operation *callable,
 | 
			
		||||
                       bool wouldBeCloned) const final {
 | 
			
		||||
    return true;
 | 
			
		||||
    // Check that the TF call operation is one that is legal to inline.
 | 
			
		||||
    return !isa<TPUPartitionedCallOp>(call);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Returns if it's legal to inline the 'src' region into the 'dest' region
 | 
			
		||||
@ -186,10 +128,7 @@ struct TFInlinerInterface : public DialectInlinerInterface {
 | 
			
		||||
    //     post inlining, the function will be dead and eliminated from the IR.
 | 
			
		||||
    //     So there won't be any code duplication.
 | 
			
		||||
    // plus the function caller op can be replaced by inlined ops.
 | 
			
		||||
    FuncOp func = op->getParentOfType<FuncOp>();
 | 
			
		||||
    if (!func) return true;
 | 
			
		||||
    if (!HasInlinableUsers(func)) return false;
 | 
			
		||||
    return TensorFlowDialect::CanDuplicate(op) || HasSingleUse(func);
 | 
			
		||||
    return !wouldBeCloned || TensorFlowDialect::CanDuplicate(op);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  //===--------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
@ -361,8 +361,7 @@ func @send_recv(%arg0: tensor<2x!tf.string>) {
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// Tests functional control flow functions with replica variant ops reachable
// from a replicate region are cloned and remapped. Only the branches reachable
// with replica variant ops are cloned.
// from a replicate region are cloned and remapped.
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @control_flow_with_replicate_variant_ops
 | 
			
		||||
func @control_flow_with_replicate_variant_ops(%arg0: tensor<i1>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<2x!tf.string>) {
 | 
			
		||||
@ -380,30 +379,32 @@ func @control_flow_with_replicate_variant_ops(%arg0: tensor<i1>, %arg1: tensor<f
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CHECK: "tf.If"
 | 
			
		||||
// CHECK-SAME: else_branch = @cond_false
 | 
			
		||||
// CHECK-SAME: else_branch = [[COND_FALSE_REPLICA_0:@[a-z0-9_]+]]
 | 
			
		||||
// CHECK-SAME: then_branch = [[COND_TRUE_REPLICA_0:@[a-z0-9_]+]]
 | 
			
		||||
// CHECK: "tf.If"
 | 
			
		||||
// CHECK-SAME: else_branch = @cond_false
 | 
			
		||||
// CHECK-SAME: else_branch = [[COND_FALSE_REPLICA_1:@[a-z0-9_]+]]
 | 
			
		||||
// CHECK-SAME: then_branch = [[COND_TRUE_REPLICA_1:@[a-z0-9_]+]]
 | 
			
		||||
 | 
			
		||||
func @cond_false(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<2x!tf.string>) -> tensor<f32> {
 | 
			
		||||
  return %arg0 : tensor<f32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CHECK-NOT: func @cond_false.+(
 | 
			
		||||
 | 
			
		||||
func @cond_true(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<2x!tf.string>) -> tensor<f32> {
 | 
			
		||||
  "tf._XlaSendFromHost"(%arg1, %arg2) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_recv_0"} : (tensor<f32>, tensor<2x!tf.string>) -> ()
 | 
			
		||||
  %0 = "tf._XlaRecvAtHost"(%arg2) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_send_0"} : (tensor<2x!tf.string>) -> tensor<f32>
 | 
			
		||||
  return %0 : tensor<f32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CHECK: func [[COND_FALSE_REPLICA_0]]
 | 
			
		||||
 | 
			
		||||
// CHECK: func [[COND_TRUE_REPLICA_0]]
 | 
			
		||||
// CHECK: "tf._XlaSendFromHost"
 | 
			
		||||
// CHECK-SAME: device_ordinal = 1
 | 
			
		||||
// CHECK: "tf._XlaRecvAtHost"
 | 
			
		||||
// CHECK-SAME: device_ordinal = 1
 | 
			
		||||
 | 
			
		||||
// CHECK: func [[COND_FALSE_REPLICA_1]]
 | 
			
		||||
 | 
			
		||||
// CHECK: func [[COND_TRUE_REPLICA_1]]
 | 
			
		||||
// CHECK: "tf._XlaSendFromHost"
 | 
			
		||||
// CHECK-SAME: device_ordinal = 2
 | 
			
		||||
@ -413,7 +414,7 @@ func @cond_true(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<2x!tf.stri
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// Tests that a function with no replica variant ops reachable from a replicate region
// is not cloned.
// is cloned.
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @no_replicate_variant_ops
 | 
			
		||||
func @no_replicate_variant_ops(%arg0: tensor<f32>, %arg1: tensor<2x!tf.string>) {
 | 
			
		||||
@ -431,11 +432,17 @@ func @no_replicate_variant_ops(%arg0: tensor<f32>, %arg1: tensor<2x!tf.string>)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CHECK: "tf.StatefulPartitionedCall"
 | 
			
		||||
// CHECK-SAME: f = @send_recv
 | 
			
		||||
// CHECK-SAME: f = [[CALLEE_REPLICA_0:@[a-z0-9_]+]]
 | 
			
		||||
// CHECK: "tf.StatefulPartitionedCall"
 | 
			
		||||
// CHECK-SAME: f = [[CALLEE_REPLICA_1:@[a-z0-9_]+]]
 | 
			
		||||
 | 
			
		||||
func @send_recv(%arg0: tensor<2x!tf.string>) {
 | 
			
		||||
  "tf.NoOp"() : () -> ()
 | 
			
		||||
  return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CHECK-NOT: @send_recv.+(
 | 
			
		||||
// CHECK: func [[CALLEE_REPLICA_0]]
 | 
			
		||||
// CHECK: "tf.NoOp"
 | 
			
		||||
 | 
			
		||||
// CHECK: func [[CALLEE_REPLICA_1]]
 | 
			
		||||
// CHECK: "tf.NoOp"
 | 
			
		||||
 | 
			
		||||
@ -53,7 +53,6 @@ void AddGraphExportLoweringPasses(OpPassManager &pm) {
 | 
			
		||||
  add_pass(TFDevice::CreateParallelExecuteToIslandsPass());
 | 
			
		||||
  pm.addPass(TFDevice::CreateLaunchToDeviceAttributePass());
 | 
			
		||||
  pm.addPass(CreateBreakUpIslandsPass());
 | 
			
		||||
  pm.addNestedPass<FuncOp>(CreateTPUDevicePropagationPass());
 | 
			
		||||
  pm.addPass(createSymbolDCEPass());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -120,46 +120,6 @@ llvm::SmallPtrSet<FuncOp, 4> GetReachableFunctionsFromRegion(ModuleOp module,
 | 
			
		||||
  return visited_functions;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Collects all functions and transitive functions reachable from region that
 | 
			
		||||
// contain replicate variant ops.
 | 
			
		||||
llvm::SmallDenseMap<llvm::StringRef, FuncOp> GetReachableFunctionsToClone(
 | 
			
		||||
    ModuleOp module, Region& region,
 | 
			
		||||
    const llvm::Optional<DictionaryAttr>& devices) {
 | 
			
		||||
  llvm::SmallPtrSet<FuncOp, 4> reachable_functions =
 | 
			
		||||
      GetReachableFunctionsFromRegion(module, region);
 | 
			
		||||
 | 
			
		||||
  llvm::SmallDenseMap<llvm::StringRef, FuncOp> functions_to_clone;
 | 
			
		||||
  llvm::SmallVector<FuncOp, 4> functions_to_visit;
 | 
			
		||||
  for (FuncOp func : reachable_functions) {
 | 
			
		||||
    if (!func.getCallableRegion()) continue;
 | 
			
		||||
    if (HasReplicaVariantOps(*func.getCallableRegion(), devices)) {
 | 
			
		||||
      functions_to_clone.insert({func.getName(), func});
 | 
			
		||||
      functions_to_visit.push_back(func);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  while (!functions_to_visit.empty()) {
 | 
			
		||||
    llvm::SmallVector<FuncOp, 4> new_functions_to_visit;
 | 
			
		||||
 | 
			
		||||
    for (FuncOp func_to_visit : functions_to_visit) {
 | 
			
		||||
      auto func_uses = func_to_visit.getSymbolUses(module);
 | 
			
		||||
      if (!func_uses) continue;
 | 
			
		||||
      for (auto use : *func_uses) {
 | 
			
		||||
        auto parent_func = use.getUser()->getParentOfType<FuncOp>();
 | 
			
		||||
        if (!parent_func || !reachable_functions.contains(parent_func) ||
 | 
			
		||||
            !functions_to_clone.insert({parent_func.getName(), parent_func})
 | 
			
		||||
                 .second)
 | 
			
		||||
          continue;
 | 
			
		||||
        new_functions_to_visit.push_back(parent_func);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    functions_to_visit.swap(new_functions_to_visit);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return functions_to_clone;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct FuncOldNameAndClone {
 | 
			
		||||
  StringRef old_name;
 | 
			
		||||
  FuncOp clone;
 | 
			
		||||
@ -276,20 +236,19 @@ LogicalResult ExpandReplicateIntoReplicas(
 | 
			
		||||
  terminator.erase();
 | 
			
		||||
 | 
			
		||||
  auto funcs_to_clone =
 | 
			
		||||
      GetReachableFunctionsToClone(module, replicate_op.body(), devices);
 | 
			
		||||
      GetReachableFunctionsFromRegion(module, replicate_op.body());
 | 
			
		||||
  SymbolTable symbol_table(module);
 | 
			
		||||
 | 
			
		||||
  builder.setInsertionPoint(island_op);
 | 
			
		||||
  BlockAndValueMapping mapping;
 | 
			
		||||
  for (int i : llvm::seq<int>(0, num_replicas)) {
 | 
			
		||||
    // Clone reachable functions with replica variant ops.
 | 
			
		||||
    // Clone reachable functions from region.
 | 
			
		||||
    llvm::SmallVector<FuncOldNameAndClone, 4> cloned_functions;
 | 
			
		||||
    cloned_functions.reserve(funcs_to_clone.size());
 | 
			
		||||
    for (auto& func_to_clone : funcs_to_clone) {
 | 
			
		||||
      auto cloned_function = func_to_clone.getSecond().clone();
 | 
			
		||||
    for (FuncOp func_to_clone : funcs_to_clone) {
 | 
			
		||||
      auto cloned_function = func_to_clone.clone();
 | 
			
		||||
      symbol_table.insert(cloned_function, module.end());
 | 
			
		||||
      cloned_functions.push_back(
 | 
			
		||||
          {func_to_clone.getSecond().getName(), cloned_function});
 | 
			
		||||
      cloned_functions.push_back({func_to_clone.getName(), cloned_function});
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Create new island for replica.
 | 
			
		||||
 | 
			
		||||
@ -31,7 +31,6 @@ int main(int argc, char **argv) {
 | 
			
		||||
  mlir::registerAllPasses();
 | 
			
		||||
  mlir::mhlo::registerAllMhloPasses();
 | 
			
		||||
  mlir::lmhlo::registerAllLmhloPasses();
 | 
			
		||||
  mlir::mhlo::registerAllMhloPasses();
 | 
			
		||||
 | 
			
		||||
  mlir::DialectRegistry registry;
 | 
			
		||||
  mlir::registerAllDialects(registry);
 | 
			
		||||
158  tensorflow/compiler/mlir/tfr/README.md  (Normal file)
@ -0,0 +1,158 @@
# Composable Tensorflow

## Composable Tensorflow

Composable TensorFlow (TF) is the framework for defining portable TF ops with
composition in the authoring language.

The set of standard TF ops is currently open. New ops are defined for special
purposes, but it is hard to make them work end-to-end: the op needs to be
handled separately by several backends (tf2xla bridge, tflite converter, CPU
kernels, etc.). Writing shape functions and gradients for these ops is
extremely difficult. `tf.function` makes some parts of the implementation
simpler, but it introduces runtime overhead and it cannot easily be used to
apply dedicated optimizations to op kernels.

The composable TF framework allows the user to define portable TF ops as
compositions of other TF ops. It translates a Python function used to define
the composition directly into a portable IR at build time, and uses that IR to
expand the composite op in the TF program during compilation / execution. With
this expansion mechanism, new ops are readily available on different platforms
without extra work. Moreover, since the expansion is optional, the backend can
easily treat the composite as a monolithic op when needed, for instance to
apply optimizations or associate it with a custom kernel.

### Benefits

Using the Composable TF API to define a new op and its composition can bring
the following benefits:

* *Automatic backend support*: As long as it is composed of ops supported by
the backend, the new op is automatically supported (as a `tf.function`
alternative);
* *Reduced tracing overhead*: Unlike `tf.function`, the composition function is
compiled at build time, hence TF only needs to trace a single op to build the
`graph`;
* *Easy fused op/kernel optimization*: Even if it has complex semantics, the new
op is presented as a single node in the graph, so optimization passes and
kernels can easily be specialized to this op for better performance;
* *Automatic shape/type inference support*: No shape functions are required for
the new op;
* *Automatic gradient support (WIP)*: The user doesn't need to author a gradient
function for the op for training.

### Use Cases

* (Portability) The user wants to add a new op and run it on different
platforms (CPU, TPU, TFLite, etc.) in a portable way.
 * *Solution*: The user should define the new op as a composition. The ops used
 inside the composition should have support for these platforms. These ops can
 also be composite ops.

* (Performance) The user defines a custom kernel for a regular structure
(e.g. an LSTM), but it is hard to add the logic to fuse the individual ops to
target this kernel in the inference graph.
 * *Solution*: The user should define a new TF op, which corresponds to the
 fused kernel, with composition, and use this op to build the model for both
 training and inference. For the platforms where a fused kernel is not
 available, the execution will use the composition instead.

## Gradient
(TODO)

## Authoring Op Composition in Python

Composable TF provides a single API to define a new op together with its
composition. For example, the following code defines a new
`FusedFullyConnected` op, which fuses `MatMul`, `Add` and an activation
function (specified by an op attribute):

```python
import tensorflow as tf

@Composite(
    'FusedFullyConnected',
    inputs=['input_: T', 'filter_: T', 'bias: T'],
    attrs=['act: {"", "RELU", "RELU6", "TANH"} = ""'],
    derived_attrs=['T: {float, int8}'],
    outputs=['o: T'])
def _composite_fully_connected(input_, filter_, bias, act):
  res = tf.raw_ops.MatMul(
      a=input_, b=filter_, transpose_a=False, transpose_b=True)
  res = tf.raw_ops.Add(x=res, y=bias)
  if act == 'RELU':
    return tf.raw_ops.Relu(features=res)
  elif act == 'RELU6':
    return tf.raw_ops.Relu6(features=res)
  elif act == 'TANH':
    return tf.raw_ops.Tanh(x=res)
  else:
    return res
```
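
The composition body above is ordinary Python over `tf.raw_ops`, so it can be
sanity-checked eagerly before generating the op library. This is a minimal
sketch, not part of the original README, and it assumes the
`@Composite`-decorated function is still callable as plain Python (which may
not hold in every TFR setup); the shapes and values are illustrative only.

```python
# Hedged sketch: exercise the composition body directly in eager mode.
import tensorflow as tf

x = tf.constant([[1.0, 2.0]])                           # [batch=1, in=2]
w = tf.constant([[0.5, 0.5], [1.0, -1.0], [0.0, 2.0]])  # [out=3, in=2]
b = tf.constant([0.1, 0.2, 0.3])                        # [out=3]

y = _composite_fully_connected(x, w, b, act='RELU')
print(y)  # expected [[1.6, 0.0, 4.3]]: MatMul (transposed filter) + Add + Relu
```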
Besides defining new ops, composition can be specified for an existing op
for portability. The following code defines the semantics of `AddNOp`:

```python
@Composite('AddNOp')
def _my_op_c(ins):
  N = len(ins)
  if N == 1:
    return ins[0]
  sum = ins[0]
  for i in range(1, N):
    sum += ins[i]
  return sum
```
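
Under the same hedged assumption that the decorated composition remains a plain
Python callable, the `AddNOp` composition reduces to a running sum and can be
checked eagerly (illustrative sketch only):

```python
# Hedged sketch: run the AddNOp composition body eagerly on a few scalars.
import tensorflow as tf

ins = [tf.constant(1.0), tf.constant(2.0), tf.constant(3.0)]
print(_my_op_c(ins))  # tf.Tensor(6.0, shape=(), dtype=float32)
```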

Utilities have been built to compile the Python composition functions down to
the backend IR. The project also provides a set of graph optimization passes to
expand the composite ops in the graph by using the input backend IR. These
passes have been added to the TF
[common runtime](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/common_runtime)
for graph execution and the
[eager runtime](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/common_runtime/eager)
for eager execution.

## Compiling Op Composition

### Ahead-Of-Time (AOT) mode

Like the op kernels, the op composition can be pre-compiled to the backend IR
so the decomposition can be invoked at runtime. A Python
[define_op_template.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tfr/define_op_template.py)
file is provided as an example to build composite ops in the user's project
directory. All the targets required to build the new ops are created by the
following `gen_op_libraries` rule:

```BUILD
load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries")

gen_op_libraries(
    name = "test_ops",
    src = "define_op_template.py",
    deps = [
        "//third_party/py/tensorflow",
    ],
)
```

More composite op definitions and usages are included in the
[examples](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir/tfr/examples)
directory.

### Just-In-Time (JIT) mode
(TODO)

## Known Limitations

* `while` statements
* The condition of an `if` statement cannot be a tensor

## Team

* Feng Liu
* Dan Moldovan

@ -44,6 +44,10 @@ extern "C" CUmodule mgpuModuleLoad(void *data) {
 | 
			
		||||
  return module;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
extern "C" void mgpuModuleUnload(CUmodule module) {
 | 
			
		||||
  CUDA_REPORT_IF_ERROR(cuModuleUnload(module));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
extern "C" CUfunction mgpuModuleGetFunction(CUmodule module, const char *name) {
 | 
			
		||||
  CUfunction function = nullptr;
 | 
			
		||||
  CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name));
 | 
			
		||||
@ -64,16 +68,15 @@ extern "C" void mgpuLaunchKernel(CUfunction function, intptr_t gridX,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
extern "C" CUstream mgpuStreamCreate() {
 | 
			
		||||
  static CUstream stream = []() {
 | 
			
		||||
    // TODO(b/170649852): This is neither thread-safe nor does it handle
    // creation/destruction of one stream per context.
 | 
			
		||||
    CUstream stream = nullptr;
 | 
			
		||||
    CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
 | 
			
		||||
    return stream;
 | 
			
		||||
  }();
 | 
			
		||||
  CUstream stream = nullptr;
 | 
			
		||||
  CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
 | 
			
		||||
  return stream;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
extern "C" void mgpuStreamDestroy(CUstream stream) {
 | 
			
		||||
  CUDA_REPORT_IF_ERROR(cuStreamDestroy(stream));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
extern "C" void mgpuStreamSynchronize(CUstream stream) {
 | 
			
		||||
  CUDA_REPORT_IF_ERROR(cuStreamSynchronize(stream));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -19,6 +19,7 @@ limitations under the License.
 | 
			
		||||
#include <memory>
 | 
			
		||||
 | 
			
		||||
#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
 | 
			
		||||
#include "mlir/Dialect/SCF/Transforms.h"  // from @llvm-project
 | 
			
		||||
#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
 | 
			
		||||
#include "mlir/Dialect/Shape/Transforms/Passes.h"  // from @llvm-project
 | 
			
		||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 | 
			
		||||
@ -107,6 +108,8 @@ struct BufferizePass : public BufferizePassBase<BufferizePass> {
 | 
			
		||||
    populateStandardBufferizePattern(&context, &converter, &patterns);
 | 
			
		||||
    populateShapeStructuralTypeConversionsAndLegality(&context, converter,
 | 
			
		||||
                                                      patterns, target);
 | 
			
		||||
    scf::populateSCFStructuralTypeConversionsAndLegality(&context, converter,
 | 
			
		||||
                                                         patterns, target);
 | 
			
		||||
    patterns.insert<UnrankedTensorStoreTestOnlyPattern>(&context);
 | 
			
		||||
 | 
			
		||||
    auto module = getOperation();
 | 
			
		||||
 | 
			
		||||
@ -218,7 +218,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
 | 
			
		||||
      auto attr = CreateDenseElementsAttrFromLiteral(literal, *builder_);
 | 
			
		||||
      if (!attr.ok()) return attr.status();
 | 
			
		||||
      mlir::Operation* new_operation =
 | 
			
		||||
          func_builder->create<mlir::ConstantOp>(loc, attr.ValueOrDie());
 | 
			
		||||
          func_builder->create<mlir::mhlo::ConstOp>(loc, attr.ValueOrDie());
 | 
			
		||||
      for (auto attr : attributes) {
 | 
			
		||||
        new_operation->setAttr(attr.first, attr.second);
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
@ -12,7 +12,11 @@ glob_lit_tests(
 | 
			
		||||
    data = [":test_utilities"],
 | 
			
		||||
    driver = "@llvm-project//mlir:run_lit.sh",
 | 
			
		||||
    tags_override = {
 | 
			
		||||
        "hlo_to_lhlo_with_xla/gpu_ops.mlir": tf_cuda_tests_tags() + ["noasan"],  # TODO(b/171751580)
 | 
			
		||||
        "hlo_to_lhlo_with_xla/gpu_ops.mlir": tf_cuda_tests_tags() + [
 | 
			
		||||
            "noasan",
 | 
			
		||||
            "nomsan",
 | 
			
		||||
            "noubsan",
 | 
			
		||||
        ],  # b/171751580
 | 
			
		||||
    },
 | 
			
		||||
    test_file_exts = [
 | 
			
		||||
        "mlir",
 | 
			
		||||
 | 
			
		||||
@ -26,10 +26,10 @@ ENTRY %indexed_conditional () -> f32[] {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @main() -> tensor<f32>
 | 
			
		||||
// CHECK: %[[INDEX:.*]] = constant dense<1> : tensor<i32>
 | 
			
		||||
// CHECK: %[[OPERAND_1:.*]] = constant dense<5.600000e+01> : tensor<f32>
 | 
			
		||||
// CHECK: %[[OPERAND_2:.*]] = constant dense<1.200000e+01> : tensor<f32>
 | 
			
		||||
// CHECK: %[[OPERAND_3:.*]] = constant dense<1.300000e+01> : tensor<f32>
 | 
			
		||||
// CHECK: %[[INDEX:.*]] = mhlo.constant dense<1> : tensor<i32>
 | 
			
		||||
// CHECK: %[[OPERAND_1:.*]] = mhlo.constant dense<5.600000e+01> : tensor<f32>
 | 
			
		||||
// CHECK: %[[OPERAND_2:.*]] = mhlo.constant dense<1.200000e+01> : tensor<f32>
 | 
			
		||||
// CHECK: %[[OPERAND_3:.*]] = mhlo.constant dense<1.300000e+01> : tensor<f32>
 | 
			
		||||
// CHECK: %[[RESULT:.*]] = "mhlo.case"(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]) ( {
 | 
			
		||||
// CHECK:   ^bb0(%[[ARG_1:.*]]: tensor<f32>):
 | 
			
		||||
// CHECK:     %[[RES_1:.*]] = "mhlo.negate"(%[[ARG_1]]) : (tensor<f32>) -> tensor<f32>
 | 
			
		||||
 | 
			
		||||
@ -4,100 +4,102 @@
 | 
			
		||||
 | 
			
		||||
HloModule tfcompile.48
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @main(%arg0: tensor<1x300xf32>, %arg1: tensor<1x300x3x1xf32>) -> tuple<tensor<300x1x5xf32>> {
 | 
			
		||||
// CHECK-LABEL:   func @main(
 | 
			
		||||
// CHECK-SAME:               %[[VAL_0:.*]]: tensor<1x300xf32>,
 | 
			
		||||
// CHECK-SAME:               %[[VAL_1:.*]]: tensor<1x300x3x1xf32>) -> tuple<tensor<300x1x5xf32>> {
 | 
			
		||||
ENTRY %tfcompile.48 {
 | 
			
		||||
  %arg0.1 = f32[1,300] parameter(0)
 | 
			
		||||
  %arg1.2 = f32[1,300,3,1] parameter(1)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %0 = "mhlo.reshape"(%arg0) : (tensor<1x300xf32>) -> tensor<1x300xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_2:.*]] = "mhlo.reshape"(%[[VAL_0]]) : (tensor<1x300xf32>) -> tensor<1x300xf32>
 | 
			
		||||
  %reshape.3 = f32[1,300] reshape(%arg0.1)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %1 = "mhlo.transpose"(%0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_3:.*]] = "mhlo.transpose"(%[[VAL_2]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32>
 | 
			
		||||
  %transpose.27 = f32[300,1] transpose(%reshape.3), dimensions={1,0}
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %2 = "mhlo.reshape"(%1) : (tensor<300x1xf32>) -> tensor<300x1x1xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_4:.*]] = "mhlo.reshape"(%[[VAL_3]]) : (tensor<300x1xf32>) -> tensor<300x1x1xf32>
 | 
			
		||||
  %reshape.28 = f32[300,1,1] reshape(%transpose.27)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %3 = "mhlo.reshape"(%2) : (tensor<300x1x1xf32>) -> tensor<300x1xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_5:.*]] = "mhlo.reshape"(%[[VAL_4]]) : (tensor<300x1x1xf32>) -> tensor<300x1xf32>
 | 
			
		||||
  %reshape.29 = f32[300,1] reshape(%reshape.28)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<300x1xf32>) -> tensor<300x1x5xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_6:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_5]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<300x1xf32>) -> tensor<300x1x5xf32>
 | 
			
		||||
  %broadcast.30 = f32[300,1,5] broadcast(%reshape.29), dimensions={0,1}
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %cst = constant  dense<1.000000e+00> : tensor<f32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
 | 
			
		||||
  %constant.8 = f32[] constant(1)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %5 = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_8:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_7]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
 | 
			
		||||
  %broadcast.9 = f32[300,1,5] broadcast(%constant.8), dimensions={}
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %6 = mhlo.multiply %4, %5 : tensor<300x1x5xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_9:.*]] = mhlo.multiply %[[VAL_6]], %[[VAL_8]] : tensor<300x1x5xf32>
 | 
			
		||||
  %multiply.31 = f32[300,1,5] multiply(%broadcast.30, %broadcast.9)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %cst_0 = constant  dense<0.000000e+00> : tensor<f32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_10:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
 | 
			
		||||
  %constant.32 = f32[] constant(0)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %7 = "mhlo.broadcast_in_dim"(%cst_0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_11:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_10]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
 | 
			
		||||
  %broadcast.33 = f32[300,1,5] broadcast(%constant.32), dimensions={}
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %8 = "mhlo.compare"(%6, %7) {comparison_direction = "GT"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_12:.*]] = "mhlo.compare"(%[[VAL_9]], %[[VAL_11]]) {comparison_direction = "GT"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1>
 | 
			
		||||
  %compare.34 = pred[300,1,5] compare(%multiply.31, %broadcast.33), direction=GT
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %cst_1 = constant  dense<0.000000e+00> : tensor<f32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_13:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
 | 
			
		||||
  %constant.10 = f32[] constant(0)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %9 = "mhlo.broadcast_in_dim"(%cst_1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_14:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_13]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
 | 
			
		||||
  %broadcast.11 = f32[300,1,5] broadcast(%constant.10), dimensions={}
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %cst_2 = constant  dense<0.000000e+00> : tensor<f32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_15:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
 | 
			
		||||
  %constant.40 = f32[] constant(0)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %10 = "mhlo.broadcast_in_dim"(%cst_2) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x5xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_16:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_15]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x5xf32>
 | 
			
		||||
  %broadcast.41 = f32[300,5] broadcast(%constant.40), dimensions={}
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %11 = "mhlo.copy"(%arg1) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_17:.*]] = "mhlo.copy"(%[[VAL_1]]) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
 | 
			
		||||
  %copy.1 = f32[1,300,3,1] copy(%arg1.2)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %12 = "mhlo.reshape"(%11) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_18:.*]] = "mhlo.reshape"(%[[VAL_17]]) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
 | 
			
		||||
  %reshape.4 = f32[1,300,3,1] reshape(%copy.1)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %13 = "mhlo.reshape"(%12) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_19:.*]] = "mhlo.reshape"(%[[VAL_18]]) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32>
 | 
			
		||||
  %reshape.24 = f32[1,300,3] reshape(%reshape.4)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %14 = "mhlo.transpose"(%13) {permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_20:.*]] = "mhlo.transpose"(%[[VAL_19]]) {permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32>
 | 
			
		||||
  %transpose.25 = f32[300,1,3] transpose(%reshape.24), dimensions={1,0,2}
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %15 = "mhlo.reshape"(%14) : (tensor<300x1x3xf32>) -> tensor<300x3xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_21:.*]] = "mhlo.reshape"(%[[VAL_20]]) : (tensor<300x1x3xf32>) -> tensor<300x3xf32>
 | 
			
		||||
  %reshape.26 = f32[300,3] reshape(%transpose.25)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %cst_3 = constant  dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_22:.*]] = mhlo.constant dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32>
 | 
			
		||||
  %constant.35 = f32[3,5] constant({ { -0.106023, 0.121505, 0.800239, -0.768885, 0.0966113 }, { 0.689014, -0.407056, -0.797853, 0.00378925, -0.208881 }, { -0.608529, 0.0276617, 0.268557, 0.577401, -0.428437 } })
 | 
			
		||||
 | 
			
		||||
  // TODO(b/129709049) consider making this default precision config implied.
 | 
			
		||||
  // CHECK-NEXT: %16 = "mhlo.dot"(%15, %cst_3) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_23:.*]] = "mhlo.dot"(%[[VAL_21]], %[[VAL_22]]) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
 | 
			
		||||
  %dot.36 = f32[300,5] dot(%reshape.26, %constant.35), lhs_contracting_dims={1}, rhs_contracting_dims={0}
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %cst_4 = constant  dense<0.000000e+00> : tensor<5xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_24:.*]] = mhlo.constant dense<0.000000e+00> : tensor<5xf32>
 | 
			
		||||
  %constant.37 = f32[5]{0} constant({0, 0, 0, 0, 0})
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %17 = "mhlo.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<300x5xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_25:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_24]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<300x5xf32>
 | 
			
		||||
  %broadcast.38 = f32[300,5] broadcast(%constant.37), dimensions={1}
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %18 = mhlo.add %16, %17 : tensor<300x5xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_26:.*]] = mhlo.add %[[VAL_23]], %[[VAL_25]] : tensor<300x5xf32>
 | 
			
		||||
  %add.39 = f32[300,5] add(%dot.36, %broadcast.38)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %19 = mhlo.maximum %10, %18 : tensor<300x5xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_27:.*]] = mhlo.maximum %[[VAL_16]], %[[VAL_26]] : tensor<300x5xf32>
 | 
			
		||||
  %maximum.42 = f32[300,5] maximum(%broadcast.41, %add.39)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %20 = "mhlo.reshape"(%19) : (tensor<300x5xf32>) -> tensor<300x1x5xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_28:.*]] = "mhlo.reshape"(%[[VAL_27]]) : (tensor<300x5xf32>) -> tensor<300x1x5xf32>
 | 
			
		||||
  %reshape.44 = f32[300,1,5] reshape(%maximum.42)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %21 = "mhlo.select"(%8, %9, %20) : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_29:.*]] = "mhlo.select"(%[[VAL_12]], %[[VAL_14]], %[[VAL_28]]) : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
 | 
			
		||||
  %select.45 = f32[300,1,5] select(%compare.34, %broadcast.11, %reshape.44)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %22 = "mhlo.reshape"(%21) : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_30:.*]] = "mhlo.reshape"(%[[VAL_29]]) : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
 | 
			
		||||
  %reshape.46 = f32[300,1,5] reshape(%select.45)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %23 = "mhlo.tuple"(%22) : (tensor<300x1x5xf32>) -> tuple<tensor<300x1x5xf32>>
 | 
			
		||||
  // CHECK-NEXT: return %23 : tuple<tensor<300x1x5xf32>>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_31:.*]] = "mhlo.tuple"(%[[VAL_30]]) : (tensor<300x1x5xf32>) -> tuple<tensor<300x1x5xf32>>
 | 
			
		||||
  // CHECK-NEXT: return %[[VAL_31]] : tuple<tensor<300x1x5xf32>>
 | 
			
		||||
  ROOT %tuple.47 = (f32[300,1,5]) tuple(%reshape.46)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -20,7 +20,7 @@ HloModule tfcompile.20
 | 
			
		||||
ENTRY %tfcompile.20 {
 | 
			
		||||
  %arg0.1 = f32[] parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
 | 
			
		||||
 | 
			
		||||
  // CHECK: [[C0:%.+]] = constant
 | 
			
		||||
  // CHECK: [[C0:%.+]] = mhlo.constant
 | 
			
		||||
  %constant.3 = f32[] constant(10), metadata={op_type="Less" op_name="Less"}
 | 
			
		||||
 | 
			
		||||
  // CHECK: [[R1:%.+]] = "mhlo.compare"([[A0]], [[C0]])
 | 
			
		||||
 | 
			
		||||
@ -176,48 +176,49 @@ add {
 | 
			
		||||
%test_constant {
 | 
			
		||||
 | 
			
		||||
  // Scalar/0D tensor constant
 | 
			
		||||
  // CHECK-NEXT:  %cst = constant dense<1> : tensor<i64>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<i64>
 | 
			
		||||
  %constant.0 = s64[] constant(1)
 | 
			
		||||
 | 
			
		||||
  // Note that double brackets "[[" have to be escaped as they denote variables
 | 
			
		||||
  // in FileCheck. The only way to do so is to drop into regex with "{{"
 | 
			
		||||
  // CHECK-NEXT:  constant  dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_1:.*]] = mhlo.constant dense<{{\[\[}}{{\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[}}[3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32>
 | 
			
		||||
  %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({{{{1.0}},{{2.0}}},{{{3.0}},{{4.0}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
 | 
			
		||||
 | 
			
		||||
  // CHECK: dense<[1, 2, 4, 8]> : tensor<4xui64>
 | 
			
		||||
  // CHECK: %[[VAL_2:.*]] = mhlo.constant dense<[1, 2, 4, 8]> : tensor<4xui64>
 | 
			
		||||
  %constant.2 = u64[4] constant({ 1, 2, 4, 8 })
 | 
			
		||||
 | 
			
		||||
  // CHECK: dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16>
 | 
			
		||||
  // CHECK: %[[VAL_3:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16>
 | 
			
		||||
  %constant.3 = bf16[4] constant({1, 2, 3, 4})
 | 
			
		||||
 | 
			
		||||
  // CHECK: dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>
 | 
			
		||||
  // CHECK: %[[VAL_4:.*]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>
 | 
			
		||||
  %constant.4 = c64[] constant((1, 0))
 | 
			
		||||
 | 
			
		||||
  // CHECK: dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f64>>
 | 
			
		||||
  // CHECK: %[[VAL_5:.*]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f64>>
 | 
			
		||||
  %constant.5 = c128[] constant((1, 0))
 | 
			
		||||
 | 
			
		||||
  // CHECK: dense<[1.000000e+00, -4.000000e+00, -6.550400e+04, 1.562500e-02]> : tensor<4xf16>
 | 
			
		||||
  // CHECK: %[[VAL_6:.*]] = mhlo.constant dense<[1.000000e+00, -4.000000e+00, -6.550400e+04, 1.562500e-02]> : tensor<4xf16>
 | 
			
		||||
  ROOT %constant.6 = f16[4] constant({1, -4, -65504, 0.015625})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
 | 
			
		||||
// implementations with attributes, etc.
 | 
			
		||||
// CHECK-LABEL:  func @test_conv(%arg0: tensor<256x32x32x6xf32>) -> tuple<tensor<256x30x30x16xf32>>
 | 
			
		||||
// CHECK-LABEL: func @test_conv(
 | 
			
		||||
// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<256x32x32x6xf32>) -> tuple<tensor<256x30x30x16xf32>> attributes {sym_visibility = "private"} {
 | 
			
		||||
%test_conv {
 | 
			
		||||
  %arg0.1 = f32[256,32,32,6]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT:  %0 = "mhlo.copy"(%arg0) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_1:.*]] = "mhlo.copy"(%[[VAL_0]]) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
 | 
			
		||||
  %copy.1 = f32[256,32,32,6]{2,1,3,0} copy(%arg0.1), metadata={op_name="HLO_Args"}
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT:  %1 = "mhlo.reshape"(%0) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_2:.*]] = "mhlo.reshape"(%[[VAL_1]]) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
 | 
			
		||||
  %reshape.2 = f32[256,32,32,6]{2,1,3,0} reshape(%copy.1)
 | 
			
		||||
 | 
			
		||||
  // Note that double brackets "[[" have to be escaped as they denote variables
 | 
			
		||||
  // in FileCheck. The only way to do so is to drop into regex with "{{"
 | 
			
		||||
  // CHECK-NEXT:  %cst = constant  dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_3:.*]] = mhlo.constant dense<{{\[\[}}{{\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[}}[3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32>
 | 
			
		||||
  %constant.3 = f32[2,2,1,1]{3,2,1,0} constant({{{{0.5}}, {{-0.6}}}, {{{0.3}}, {{-0.1}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT:  %2 = "mhlo.convolution"(%1, %cst) {
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_4:.*]] = "mhlo.convolution"(%[[VAL_2]], %[[VAL_3]]) {
 | 
			
		||||
  // CHECK-SAME:     batch_group_count = 1 : i64
 | 
			
		||||
  // CHECK-SAME:     dimension_numbers = {
 | 
			
		||||
  // CHECK-SAME:       input_batch_dimension = 0 : i64
 | 
			
		||||
@ -241,10 +242,10 @@ add {
 | 
			
		||||
 | 
			
		||||
  %convolution.4 = f32[16,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT:  %3 = "mhlo.reshape"(%2) : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_5:.*]] = "mhlo.reshape"(%[[VAL_4]]) : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32>
 | 
			
		||||
  %reshape.5 = f32[256,30,30,16]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"}
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT:  "mhlo.tuple"(%3) : (tensor<256x30x30x16xf32>) -> tuple<tensor<256x30x30x16xf32>>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_6:.*]] = "mhlo.tuple"(%[[VAL_5]]) : (tensor<256x30x30x16xf32>) -> tuple<tensor<256x30x30x16xf32>>
 | 
			
		||||
  ROOT %tuple.6 = (f32[256,30,30,16]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="HLO_Retvals"}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -4,25 +4,25 @@ HloModule tfcompile.1
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: func @main() -> tensor<i1> {
 | 
			
		||||
ENTRY %tfcompile.1 {
 | 
			
		||||
  // CHECK-NEXT: %cst = constant dense<1.000000e+00> : tensor<f32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_0:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
 | 
			
		||||
  %constant.0 = f32[] constant(1)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %cst_0 = constant dense<1.000000e+00> : tensor<f64>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f64>
 | 
			
		||||
  %constant.1 = f64[] constant(1)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %cst_1 = constant dense<1> : tensor<i8>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_2:.*]] = mhlo.constant dense<1> : tensor<i8>
 | 
			
		||||
  %constant.2 = s8[] constant(1)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %cst_2 = constant dense<1> : tensor<i16>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_3:.*]] = mhlo.constant dense<1> : tensor<i16>
 | 
			
		||||
  %constant.3 = s16[] constant(1)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %cst_3 = constant dense<1> : tensor<i32>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_4:.*]] = mhlo.constant dense<1> : tensor<i32>
 | 
			
		||||
  %constant.4 = s32[] constant(1)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %cst_4 = constant dense<1> : tensor<i64>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_5:.*]] = mhlo.constant dense<1> : tensor<i64>
 | 
			
		||||
  %constant.5 = s64[] constant(1)
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: %cst_5 = constant dense<true> : tensor<i1>
 | 
			
		||||
  // CHECK-NEXT: return %cst_5 : tensor<i1>
 | 
			
		||||
  // CHECK-NEXT: %[[VAL_6:.*]] = mhlo.constant dense<true> : tensor<i1>
 | 
			
		||||
  // CHECK-NEXT: return %[[VAL_6]] : tensor<i1>
 | 
			
		||||
  ROOT %constant.6 = pred[] constant(1)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -351,6 +351,7 @@ cc_library(
 | 
			
		||||
        ":xla_op_registry",
 | 
			
		||||
        ":xla_resource",
 | 
			
		||||
        "@com_google_absl//absl/algorithm:container",
 | 
			
		||||
        "@com_google_absl//absl/container:flat_hash_map",
 | 
			
		||||
        "@com_google_absl//absl/memory",
 | 
			
		||||
        "@com_google_absl//absl/types:span",
 | 
			
		||||
        "@com_google_absl//absl/types:variant",
 | 
			
		||||
 | 
			
		||||
@ -91,15 +91,15 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
 | 
			
		||||
      : XlaOpKernel(ctx) {
 | 
			
		||||
    OP_REQUIRES_OK(ctx, ctx->GetAttr("src_format", &src_format_));
 | 
			
		||||
    OP_REQUIRES(
 | 
			
		||||
        ctx, src_format_.size() == 4,
 | 
			
		||||
        errors::InvalidArgument("Data format should have 4 characters"));
 | 
			
		||||
        ctx, src_format_.size() == 4 || src_format_.size() == 5,
 | 
			
		||||
        errors::InvalidArgument("Data format should have 4 or 5 characters"));
 | 
			
		||||
    TensorFormat data_format;
 | 
			
		||||
    OP_REQUIRES(ctx, FormatFromString(src_format_, &data_format),
 | 
			
		||||
                errors::InvalidArgument("Invalid data format"));
 | 
			
		||||
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dst_format", &dst_format_));
 | 
			
		||||
    OP_REQUIRES(
 | 
			
		||||
        ctx, dst_format_.size() == 4,
 | 
			
		||||
        errors::InvalidArgument("Data format should have 4 characters"));
 | 
			
		||||
        ctx, dst_format_.size() == 4 || dst_format_.size() == 5,
 | 
			
		||||
        errors::InvalidArgument("Data format should have 4 or 5 characters"));
 | 
			
		||||
    OP_REQUIRES(ctx, FormatFromString(dst_format_, &data_format),
 | 
			
		||||
                errors::InvalidArgument("Invalid data format"));
 | 
			
		||||
  }
 | 
			
		||||
@ -113,9 +113,10 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
 | 
			
		||||
                    input_tensor_shape.DebugString()));
 | 
			
		||||
    const int dim0 = input_tensor_shape.dim_size(0);
 | 
			
		||||
    OP_REQUIRES(
 | 
			
		||||
        ctx, dim0 == 2 || dim0 == 4,
 | 
			
		||||
        ctx, dim0 == 2 || dim0 == 4 || dim0 == 5,
 | 
			
		||||
        errors::InvalidArgument(
 | 
			
		||||
            "First dimension of input must be of size 4, but got shape ",
 | 
			
		||||
            "First dimension of input must be of size 2, 4 or 5, but got "
 | 
			
		||||
            "shape ",
 | 
			
		||||
            input_tensor_shape.DebugString()));
 | 
			
		||||
    if (input_rank == 2) {
 | 
			
		||||
      OP_REQUIRES(
 | 
			
		||||
 | 
			
		||||
@ -18,6 +18,7 @@ limitations under the License.
 | 
			
		||||
#include <numeric>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "absl/container/flat_hash_map.h"
 | 
			
		||||
#include "absl/memory/memory.h"
 | 
			
		||||
#include "absl/types/variant.h"
 | 
			
		||||
#include "tensorflow/compiler/jit/defs.h"
 | 
			
		||||
@ -675,6 +676,38 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
 | 
			
		||||
  return graph;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Collects all control rets from `orig_control_ret_nodes` that are still valid,
 | 
			
		||||
// keeping the same order.
 | 
			
		||||
std::vector<std::string> GetValidControlRets(
 | 
			
		||||
    absl::Span<Node* const> orig_control_ret_nodes, const Graph& graph) {
 | 
			
		||||
  // Build map from control ret node to index.
 | 
			
		||||
  absl::flat_hash_map<const Node*, int> control_ret_nodes_map;
 | 
			
		||||
  for (int i = 0; i < orig_control_ret_nodes.size(); ++i) {
 | 
			
		||||
    const Node* n = orig_control_ret_nodes[i];
 | 
			
		||||
    control_ret_nodes_map[n] = i;
 | 
			
		||||
  }
 | 
			
		||||
  // Check which control rets are still valid.
 | 
			
		||||
  std::vector<bool> is_valid_control_ret(orig_control_ret_nodes.size(), false);
 | 
			
		||||
  int num_valid_control_rets = 0;
 | 
			
		||||
  for (const Node* n : graph.nodes()) {
 | 
			
		||||
    auto iter = control_ret_nodes_map.find(n);
 | 
			
		||||
    if (iter != control_ret_nodes_map.end()) {
 | 
			
		||||
      ++num_valid_control_rets;
 | 
			
		||||
      is_valid_control_ret[iter->second] = true;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  // Return valid control rets in same order as they appear in
 | 
			
		||||
  // `orig_control_ret_nodes`.
 | 
			
		||||
  std::vector<std::string> valid_control_rets;
 | 
			
		||||
  valid_control_rets.reserve(num_valid_control_rets);
 | 
			
		||||
  for (int i = 0; i < orig_control_ret_nodes.size(); ++i) {
 | 
			
		||||
    if (is_valid_control_ret[i]) {
 | 
			
		||||
      valid_control_rets.push_back(orig_control_ret_nodes[i]->name());
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return valid_control_rets;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status XlaCompiler::CompileFunction(
 | 
			
		||||
    const XlaCompiler::CompileOptions& options,
 | 
			
		||||
    const NameAttrList& fn_name_attrs,
 | 
			
		||||
@ -765,15 +798,15 @@ Status XlaCompiler::CompileFunction(
 | 
			
		||||
      ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
 | 
			
		||||
    VLOG(1) << "Using MLIR bridge";
 | 
			
		||||
    GraphDebugInfo debug_info;
 | 
			
		||||
    std::vector<std::string> control_rets;
 | 
			
		||||
    for (const auto* control_ret_node : fbody->control_ret_nodes) {
 | 
			
		||||
      control_rets.push_back(control_ret_node->name());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    std::vector<std::string> valid_control_rets =
 | 
			
		||||
        GetValidControlRets(fbody->control_ret_nodes, *graph);
 | 
			
		||||
 | 
			
		||||
    TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
 | 
			
		||||
        std::move(*graph), mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
 | 
			
		||||
        control_rets, options_.device_type.type_string(), options.use_tuple_arg,
 | 
			
		||||
        *options_.flib_def, debug_info, options_.shape_representation_fn,
 | 
			
		||||
        result));
 | 
			
		||||
        valid_control_rets, options_.device_type.type_string(),
 | 
			
		||||
        options.use_tuple_arg, *options_.flib_def, debug_info,
 | 
			
		||||
        options_.shape_representation_fn, result));
 | 
			
		||||
  } else {
 | 
			
		||||
    TF_RETURN_IF_ERROR(
 | 
			
		||||
        CompileGraph(options, function_id, std::move(graph), args, result));
 | 
			
		||||
 | 
			
		||||
@ -1994,7 +1994,7 @@ PjRtExecutable::ExecuteOnLocalDevices(
 | 
			
		||||
    if (!statusor.ok()) {
 | 
			
		||||
      return AppendStatus(
 | 
			
		||||
          statusor.status(),
 | 
			
		||||
          absl::StrFormat("while running replica %d and partition %d of a"
 | 
			
		||||
          absl::StrFormat("while running replica %d and partition %d of a "
 | 
			
		||||
                          "replicated computation (other "
 | 
			
		||||
                          "replicas may have failed as well).",
 | 
			
		||||
                          replica, partition));
 | 
			
		||||
 | 
			
		||||
@ -5220,13 +5220,31 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
 | 
			
		||||
  for (int64 spatial_dim = 0;
 | 
			
		||||
       spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) {
 | 
			
		||||
    const int64 kernel_size = window_dims[spatial_dim].size();
 | 
			
		||||
    const int64 dilated_kernel_size =
 | 
			
		||||
        1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation();
 | 
			
		||||
 | 
			
		||||
    const bool can_be_group_or_contraction =
 | 
			
		||||
        !window_dims[spatial_dim].window_reversal() &&
 | 
			
		||||
        window_dims[spatial_dim].padding_low() == 0 &&
 | 
			
		||||
        window_dims[spatial_dim].padding_high() == 0 &&
 | 
			
		||||
        window_dims[spatial_dim].window_dilation() == 1;
 | 
			
		||||
    const bool is_group_dim =
 | 
			
		||||
        can_be_group_or_contraction &&
 | 
			
		||||
        window_dims[spatial_dim].base_dilation() == kernel_size &&
 | 
			
		||||
        window_dims[spatial_dim].stride() == kernel_size - 1;
 | 
			
		||||
    const int64 input_size =
 | 
			
		||||
        input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim));
 | 
			
		||||
    const bool is_pure_contraction_dim =
 | 
			
		||||
        kernel_size == input_size && can_be_group_or_contraction &&
 | 
			
		||||
        window_dims[spatial_dim].base_dilation() == 1 &&
 | 
			
		||||
        window_dims[spatial_dim].stride() == 1;
 | 
			
		||||
    if (is_group_dim || is_pure_contraction_dim) {
 | 
			
		||||
      *(swapped_window.add_dimensions()) = window_dims[spatial_dim];
 | 
			
		||||
      continue;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const int64 dilated_kernel_size =
 | 
			
		||||
        1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation();
 | 
			
		||||
    const int64 dilated_input_size =
 | 
			
		||||
        1 + (input_size - 1) * window_dims[spatial_dim].base_dilation();
 | 
			
		||||
 | 
			
		||||
    // Don't decide to swap if the input size is one, since many convolution
 | 
			
		||||
    // implementations can easily handle that special case efficiently.
 | 
			
		||||
    kernel_product *= kernel_size;
 | 
			
		||||
 | 
			
		||||
@ -6654,6 +6654,32 @@ TEST_F(AlgebraicSimplifierTest, AbsEliminationMultiply) {
 | 
			
		||||
              GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(AlgebraicSimplifierTest, BroadcastCompareSimplification) {
 | 
			
		||||
  std::string module_string = R"(
 | 
			
		||||
    HloModule m
 | 
			
		||||
    test {
 | 
			
		||||
      a = s32[] parameter(0)
 | 
			
		||||
      b = s32[] parameter(1)
 | 
			
		||||
      x = s32[10]{0} parameter(2)
 | 
			
		||||
      broadcast_a = s32[10]{0} broadcast(a), dimensions={}
 | 
			
		||||
      broadcast_b = s32[10]{0} broadcast(b), dimensions={}
 | 
			
		||||
      add = s32[10]{0} add(broadcast_a, x)
 | 
			
		||||
      ROOT cmp = pred[10]{0} compare(add, broadcast_b), direction=EQ
 | 
			
		||||
    }
 | 
			
		||||
  )";
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_string));
 | 
			
		||||
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
 | 
			
		||||
  EXPECT_THAT(m->entry_computation()->root_instruction(),
 | 
			
		||||
              GmockMatch(m::Compare(m::Parameter(2),
 | 
			
		||||
                                    m::Broadcast(m::Subtract(
 | 
			
		||||
                                        m::Parameter(1), m::Parameter(0))))));
 | 
			
		||||
 | 
			
		||||
  // Numerically unstable transformation shouldn't be applied to floating types.
 | 
			
		||||
  std::string module_string_f32 =
 | 
			
		||||
      absl::StrReplaceAll(module_string, {{"s32", "f32"}});
 | 
			
		||||
  ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
 | 
			
		||||
}
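An illustrative note on why the rewrite exercised by the test above is restricted to integer element types (an aside, not text from the commit): the simplifier replaces the comparison of add(broadcast(a), x) with broadcast(b) by a comparison of x with broadcast(subtract(b, a)), which relies on

\[ a + x = b \iff x = b - a , \]

an equivalence that is exact for s32 (including two's-complement wraparound) but not for f32, where rounding, infinities, and NaNs break it; hence the f32 copy of the module is expected to be left unchanged.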
TEST_F(AlgebraicSimplifierTest, AbsEliminationPower2) {
 | 
			
		||||
  const char* kModuleStr = R"(
 | 
			
		||||
    HloModule m
 | 
			
		||||
 | 
			
		||||
@ -63,7 +63,6 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
 | 
			
		||||
@ -2200,20 +2199,21 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
 | 
			
		||||
  if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) {
 | 
			
		||||
    VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace";
 | 
			
		||||
    CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
 | 
			
		||||
    FusedIrEmitter fused_emitter(&elemental_emitter);
 | 
			
		||||
    BindFusionArguments(fusion, &fused_emitter);
 | 
			
		||||
 | 
			
		||||
    TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
 | 
			
		||||
    // Delegate to common implementation of fused in-place dynamic-update-slice.
 | 
			
		||||
    return llvm_ir::EmitFusedDynamicUpdateSliceInPlace(
 | 
			
		||||
        fusion, GetGeneratorForOperandIrArrays(fusion), GetIrArrayFor(fusion),
 | 
			
		||||
        &elemental_emitter, &b_);
 | 
			
		||||
        fusion, GetIrArrayFor(fusion), &fused_emitter, &b_);
 | 
			
		||||
  } else if (fusion->IsLoopFusion()) {
 | 
			
		||||
    VLOG(3) << "HandleFusion kLoop";
 | 
			
		||||
    CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
 | 
			
		||||
    auto operands = GetIrArraysForOperandsOf(fusion);
 | 
			
		||||
    FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion),
 | 
			
		||||
                                 &elemental_emitter);
 | 
			
		||||
    TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
 | 
			
		||||
 | 
			
		||||
    return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator());
 | 
			
		||||
    FusedIrEmitter fused_emitter(&elemental_emitter);
 | 
			
		||||
    BindFusionArguments(fusion, &fused_emitter);
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(
 | 
			
		||||
                                            fusion->fused_expression_root()));
 | 
			
		||||
    return EmitTargetElementLoop(fusion, generator);
 | 
			
		||||
  } else if (fusion->IsOutputFusion()) {
 | 
			
		||||
    VLOG(3) << "HandleFusion kOutput";
 | 
			
		||||
    int64 dot_op_index = root->operand(0)->opcode() == HloOpcode::kDot ? 0 : 1;
 | 
			
		||||
@ -3451,5 +3451,17 @@ llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
 | 
			
		||||
  return EmitBufferPointer(root_buffer, root_inst->shape());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
 | 
			
		||||
                                    FusedIrEmitter* fused_emitter) {
 | 
			
		||||
  for (int i = 0; i < fusion->operand_count(); i++) {
 | 
			
		||||
    const HloInstruction* operand = fusion->operand(i);
 | 
			
		||||
    fused_emitter->BindGenerator(
 | 
			
		||||
        fusion->fused_parameter(i),
 | 
			
		||||
        [this, operand](llvm_ir::IrArray::Index index) {
 | 
			
		||||
          return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
 | 
			
		||||
        });
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace cpu
 | 
			
		||||
}  // namespace xla
 | 
			
		||||
 | 
			
		||||
@ -43,6 +43,7 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
 | 
			
		||||
@ -234,10 +235,9 @@ class IrEmitter : public DfsHloVisitorWithDefault,
 | 
			
		||||
  std::vector<llvm_ir::IrArray> GetIrArraysForOperandsOf(
 | 
			
		||||
      const HloInstruction* hlo);
 | 
			
		||||
 | 
			
		||||
  GeneratorForOperandIrArrays GetGeneratorForOperandIrArrays(
 | 
			
		||||
      HloInstruction* unnested_hlo) {
 | 
			
		||||
    return [=]() { return GetIrArraysForOperandsOf(unnested_hlo); };
 | 
			
		||||
  }
 | 
			
		||||
  // Bind all argument IrArrays of `fusion` to `fused_emitter`.
 | 
			
		||||
  void BindFusionArguments(const HloInstruction* fusion,
 | 
			
		||||
                           FusedIrEmitter* fused_emitter);
 | 
			
		||||
 | 
			
		||||
  // Augments IrArray with aliasing information.
 | 
			
		||||
  void AddAliasingInformationToIrArray(const HloInstruction& hlo,
 | 
			
		||||
 | 
			
		||||
@ -38,14 +38,60 @@ FusionNodeIndexingEvaluation::FusionNodeIndexingEvaluation(
 | 
			
		||||
// a tradeoff between compilation time and runtime here.
 | 
			
		||||
const int64 FusionNodeIndexingEvaluation::kAllowedCodeDuplication = 15;
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
// Returns which ops invalidate the cache of emitted instructions by creating a
 | 
			
		||||
// new BasicBlock and setting the insertion point to the newly created
 | 
			
		||||
// BasicBlock. We can only reuse cached values if they were emitted in the same
 | 
			
		||||
// BasicBlock as the current BasicBlock.
 | 
			
		||||
bool OpInvalidatesCache(const HloInstruction* hlo) {
 | 
			
		||||
  switch (hlo->opcode()) {
 | 
			
		||||
    // This list of ops was created by inspecting the code. There is no
 | 
			
		||||
    // guarantee that it is complete.
 | 
			
		||||
    case HloOpcode::kConcatenate:
 | 
			
		||||
    case HloOpcode::kDot:
 | 
			
		||||
    case HloOpcode::kDynamicUpdateSlice:
 | 
			
		||||
    case HloOpcode::kPad:
 | 
			
		||||
    case HloOpcode::kReduce:
 | 
			
		||||
    case HloOpcode::kReduceWindow:
 | 
			
		||||
      return true;
 | 
			
		||||
    default:
 | 
			
		||||
      return false;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Counts the number of "real" users of 'hlo'. When 'hlo' has a fusion node as
 | 
			
		||||
// user, we consider the users of the fusion parameter corresponding to 'hlo' as
 | 
			
		||||
// the real users.
 | 
			
		||||
int64 UserCount(const HloInstruction* hlo) {
 | 
			
		||||
  int64 cnt = 0;
 | 
			
		||||
  for (HloInstruction* user : hlo->users()) {
 | 
			
		||||
    if (user->opcode() == HloOpcode::kFusion) {
 | 
			
		||||
      // Count the number of users of the parameter corresponding to the fusion
 | 
			
		||||
      // operand.
 | 
			
		||||
      int64 operand_index = user->operand_index(hlo);
 | 
			
		||||
      cnt += user->fused_parameter(operand_index)->user_count();
 | 
			
		||||
    } else {
 | 
			
		||||
      ++cnt;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return cnt;
 | 
			
		||||
}
 | 
			
		||||
}  // namespace
 | 
			
		||||

bool FusionNodeIndexingEvaluation::CodeDuplicationTooHigh(
    const HloInstruction* producer) const {
  return EvaluateEmittedInstructions(producer) > kAllowedCodeDuplication;
  int64 emitted_instructions = EvaluateEmittedInstructions(producer);
  return emitted_instructions > kAllowedCodeDuplication ||
         (OpInvalidatesCache(producer) &&
          (emitted_instructions > 1 || UserCount(producer) > 1));
}

bool FusionNodeIndexingEvaluation::MaxCodeDuplicationTooHigh() const {
  for (const auto& entry : index_usage_count_) {
    if (entry.second > kAllowedCodeDuplication) {
    if (entry.second > kAllowedCodeDuplication ||
        (OpInvalidatesCache(entry.first) &&
         (entry.second > 1 || UserCount(entry.first) > 1))) {
      return true;
    }
  }

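As an aside (not part of the diff): the tightened check above can be read as "reject a producer once its emission count exceeds kAllowedCodeDuplication, and reject it much earlier if it is a cache-invalidating op that would be emitted more than once or feeds more than one user". A minimal standalone model of that condition, with plain values standing in for the HLO-based helpers, is sketched below; the threshold 15 is the kAllowedCodeDuplication constant defined earlier in this file.

#include <iostream>

constexpr int kAllowedCodeDuplication = 15;

// Toy stand-in for CodeDuplicationTooHigh: even a cheap producer is rejected
// when it invalidates the emission cache and would be emitted more than once
// or has more than one user.
bool TooHigh(int emitted_instructions, bool op_invalidates_cache,
             int user_count) {
  return emitted_instructions > kAllowedCodeDuplication ||
         (op_invalidates_cache &&
          (emitted_instructions > 1 || user_count > 1));
}

int main() {
  // A kDot-like producer emitted twice is now rejected although 2 <= 15.
  std::cout << TooHigh(/*emitted_instructions=*/2,
                       /*op_invalidates_cache=*/true, /*user_count=*/1)
            << "\n";  // prints 1
}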
@ -773,11 +773,11 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
 | 
			
		||||
  CHECK_EQ(HloInstruction::FusionKind::kLoop, fusion->fusion_kind());
 | 
			
		||||
  GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
 | 
			
		||||
                                          GetNestedComputer());
 | 
			
		||||
  FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion),
 | 
			
		||||
                               &elemental_emitter);
 | 
			
		||||
  TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
 | 
			
		||||
 | 
			
		||||
  return EmitTargetElementLoop(*fusion, fused_emitter.GetRootGenerator());
 | 
			
		||||
  FusedIrEmitter fused_emitter(&elemental_emitter);
 | 
			
		||||
  BindFusionArguments(fusion, &fused_emitter);
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(
 | 
			
		||||
                                          fusion->fused_expression_root()));
 | 
			
		||||
  return EmitTargetElementLoop(*fusion, generator);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status IrEmitter::HandleCall(HloInstruction* call) {
 | 
			
		||||
@ -876,5 +876,17 @@ std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
 | 
			
		||||
  return output_arrays;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
 | 
			
		||||
                                    FusedIrEmitter* fused_emitter) {
 | 
			
		||||
  for (int i = 0; i < fusion->operand_count(); i++) {
 | 
			
		||||
    const HloInstruction* operand = fusion->operand(i);
 | 
			
		||||
    fused_emitter->BindGenerator(
 | 
			
		||||
        fusion->fused_parameter(i),
 | 
			
		||||
        [this, operand, fusion](llvm_ir::IrArray::Index index) {
 | 
			
		||||
          return GetIrArray(*operand, *fusion).EmitReadArrayElement(index, &b_);
 | 
			
		||||
        });
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace gpu
 | 
			
		||||
}  // namespace xla
 | 
			
		||||
 | 
			
		||||
@ -36,6 +36,7 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
 | 
			
		||||
@ -182,18 +183,9 @@ class IrEmitter : public DfsHloVisitorWithDefault,
 | 
			
		||||
  const HloModuleConfig& hlo_module_config_;
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
  GeneratorForOperandIrArrays GetGeneratorForOperandIrArrays(
 | 
			
		||||
      const HloInstruction* fusion) {
 | 
			
		||||
    return [=]() {
 | 
			
		||||
      std::vector<llvm_ir::IrArray> ir_arrays;
 | 
			
		||||
      ir_arrays.reserve(fusion->operand_count());
 | 
			
		||||
      absl::c_transform(fusion->operands(), std::back_inserter(ir_arrays),
 | 
			
		||||
                        [&](const HloInstruction* operand) {
 | 
			
		||||
                          return GetIrArray(*operand, *fusion);
 | 
			
		||||
                        });
 | 
			
		||||
      return ir_arrays;
 | 
			
		||||
    };
 | 
			
		||||
  }
 | 
			
		||||
  // Bind all argument IrArrays of `fusion` to `fused_emitter`.
 | 
			
		||||
  void BindFusionArguments(const HloInstruction* fusion,
 | 
			
		||||
                           FusedIrEmitter* fused_emitter);
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  // A helper method for EmitAtomicOperationForNestedComputation. Certain
 | 
			
		||||
 | 
			
		||||
@ -960,19 +960,24 @@ Status IrEmitterUnnested::EmitLoopFusionFromMlir(MlirEmitterInput input,
 | 
			
		||||
 | 
			
		||||
  GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
 | 
			
		||||
                                          GetNestedComputer());
 | 
			
		||||
  FusedIrEmitter fused_emitter(&elemental_emitter);
 | 
			
		||||
 | 
			
		||||
  FusedIrEmitter fused_emitter(
 | 
			
		||||
      [&] {
 | 
			
		||||
        auto operand_ir_arrays =
 | 
			
		||||
            absl::MakeSpan(ir_arrays).subspan(0, fusion_operands.size());
 | 
			
		||||
        return std::vector<llvm_ir::IrArray>(operand_ir_arrays.begin(),
 | 
			
		||||
                                             operand_ir_arrays.end());
 | 
			
		||||
      },
 | 
			
		||||
      &elemental_emitter);
 | 
			
		||||
  TF_RETURN_IF_ERROR(
 | 
			
		||||
      fused_computation->root_instruction()->Accept(&fused_emitter));
 | 
			
		||||
  for (int i = 0; i < fusion_operands.size(); i++) {
 | 
			
		||||
    auto operand_ir_arrays =
 | 
			
		||||
        absl::MakeSpan(ir_arrays).subspan(0, fusion_operands.size());
 | 
			
		||||
 | 
			
		||||
    auto* builder = &b_;
 | 
			
		||||
    auto ir_array = operand_ir_arrays[i];
 | 
			
		||||
    fused_emitter.BindGenerator(
 | 
			
		||||
        fused_computation->parameter_instruction(i),
 | 
			
		||||
        [builder, ir_array](llvm_ir::IrArray::Index index) {
 | 
			
		||||
          return ir_array.EmitReadArrayElement(index, builder);
 | 
			
		||||
        });
 | 
			
		||||
  }
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(
 | 
			
		||||
      auto element_generator,
 | 
			
		||||
      fused_emitter.GetGenerator(fused_computation->root_instruction()));
 | 
			
		||||
 | 
			
		||||
  auto element_generator = fused_emitter.GetRootGenerator();
 | 
			
		||||
  Shape element_shape = TypeToShape(fusion_outputs[0].getType());
 | 
			
		||||
  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
 | 
			
		||||
      element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
 | 
			
		||||
@ -1022,14 +1027,14 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
 | 
			
		||||
          GpuElementalIrEmitter operand_elemental_emitter(
 | 
			
		||||
              hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
 | 
			
		||||
              GetNestedComputer());
 | 
			
		||||
          FusedIrEmitter operand_fused_emitter(
 | 
			
		||||
              GetGeneratorForOperandIrArrays(fusion),
 | 
			
		||||
              &operand_elemental_emitter);
 | 
			
		||||
          TF_RETURN_IF_ERROR(
 | 
			
		||||
              root->mutable_operand(0)->Accept(&operand_fused_emitter));
 | 
			
		||||
          FusedIrEmitter operand_fused_emitter(&operand_elemental_emitter);
 | 
			
		||||
          BindFusionArguments(fusion, &operand_fused_emitter);
 | 
			
		||||
          TF_ASSIGN_OR_RETURN(
 | 
			
		||||
              auto generator,
 | 
			
		||||
              operand_fused_emitter.GetGenerator(root->operand(0)));
 | 
			
		||||
 | 
			
		||||
          TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk(
 | 
			
		||||
              *fusion, operand_fused_emitter.GetGenerator(root->operand(0)),
 | 
			
		||||
              *fusion, generator,
 | 
			
		||||
              static_cast<KernelThunk*>(thunks.back().get()),
 | 
			
		||||
              ComputeMaxUnrollFactor(fusion)));
 | 
			
		||||
        }
 | 
			
		||||
@ -1044,10 +1049,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
 | 
			
		||||
          GpuElementalIrEmitter scatter_elemental_emitter(
 | 
			
		||||
              hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
 | 
			
		||||
              GetNestedComputer());
 | 
			
		||||
          FusedIrEmitter scatter_fused_emitter(
 | 
			
		||||
              GetGeneratorForOperandIrArrays(fusion),
 | 
			
		||||
              &scatter_elemental_emitter);
 | 
			
		||||
          TF_RETURN_IF_ERROR(root->Accept(&scatter_fused_emitter));
 | 
			
		||||
          FusedIrEmitter scatter_fused_emitter(&scatter_elemental_emitter);
 | 
			
		||||
          BindFusionArguments(fusion, &scatter_fused_emitter);
 | 
			
		||||
          CHECK_EQ(root->parent()->FusionInstruction(), fusion);
 | 
			
		||||
 | 
			
		||||
          TF_ASSIGN_OR_RETURN(
 | 
			
		||||
@ -1063,10 +1066,12 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
 | 
			
		||||
          desc.unique_indices = root->unique_indices();
 | 
			
		||||
          desc.update_computation = root->called_computations()[0];
 | 
			
		||||
          desc.output = GetIrArray(*fusion, *fusion);
 | 
			
		||||
          desc.scatter_indices_gen =
 | 
			
		||||
              scatter_fused_emitter.GetGenerator(root->operand(1));
 | 
			
		||||
          desc.updates_gen =
 | 
			
		||||
              scatter_fused_emitter.GetGenerator(root->operand(2));
 | 
			
		||||
          TF_ASSIGN_OR_RETURN(
 | 
			
		||||
              desc.scatter_indices_gen,
 | 
			
		||||
              scatter_fused_emitter.GetGenerator(root->operand(1)));
 | 
			
		||||
          TF_ASSIGN_OR_RETURN(
 | 
			
		||||
              desc.updates_gen,
 | 
			
		||||
              scatter_fused_emitter.GetGenerator(root->operand(2)));
 | 
			
		||||
          desc.get_index_type = [&](int64 launch_size) {
 | 
			
		||||
            return GetIndexTypeForKernel(root, launch_size, &b_);
 | 
			
		||||
          };
 | 
			
		||||
@ -1133,9 +1138,11 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
 | 
			
		||||
                           ir_emitter_context_->llvm_module());
 | 
			
		||||
    AddThunkToThunkSequence(std::move(fusion_thunk));
 | 
			
		||||
 | 
			
		||||
    FusedIrEmitter fused_emitter(&elemental_emitter);
 | 
			
		||||
    BindFusionArguments(fusion, &fused_emitter);
 | 
			
		||||
 | 
			
		||||
    return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
 | 
			
		||||
        fusion, GetGeneratorForOperandIrArrays(fusion), output_array,
 | 
			
		||||
        &elemental_emitter, launch_dimensions, &b_);
 | 
			
		||||
        fusion, output_array, &fused_emitter, launch_dimensions, &b_);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop)
 | 
			
		||||
@ -2596,14 +2603,14 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
 | 
			
		||||
                                            ir_emitter_context_->llvm_module(),
 | 
			
		||||
                                            &b_, GetNestedComputer());
 | 
			
		||||
 | 
			
		||||
    FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo),
 | 
			
		||||
                                 &elemental_emitter);
 | 
			
		||||
    TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter));
 | 
			
		||||
    TF_RETURN_IF_ERROR(
 | 
			
		||||
        ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand),
 | 
			
		||||
                            GetIrArray(*hlo, *hlo, index), launch_dimensions,
 | 
			
		||||
                            &b_)
 | 
			
		||||
            .EmitLoop(IrName(hlo)));
 | 
			
		||||
    FusedIrEmitter fused_emitter(&elemental_emitter);
 | 
			
		||||
    BindFusionArguments(hlo, &fused_emitter);
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(auto generator,
 | 
			
		||||
                        fused_emitter.GetGenerator(init_value_operand));
 | 
			
		||||
    TF_RETURN_IF_ERROR(ParallelLoopEmitter(generator,
 | 
			
		||||
                                           GetIrArray(*hlo, *hlo, index),
 | 
			
		||||
                                           launch_dimensions, &b_)
 | 
			
		||||
                           .EmitLoop(IrName(hlo)));
 | 
			
		||||
  } else {
 | 
			
		||||
    // In the unfused case the element is already there, just read from it.
 | 
			
		||||
    TF_RETURN_IF_ERROR(ParallelLoopEmitter(
 | 
			
		||||
@ -3097,15 +3104,35 @@ void IrEmitterUnnested::EmitTileElementForFusion(
 | 
			
		||||
  std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
 | 
			
		||||
  GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
 | 
			
		||||
                                     GetNestedComputer());
 | 
			
		||||
  FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo),
 | 
			
		||||
                               &elem_emitter, x_loc, y_loc,
 | 
			
		||||
                               param_shmem_buffers);
 | 
			
		||||
 | 
			
		||||
  TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter));
 | 
			
		||||
  FusedIrEmitter fused_emitter(&elem_emitter);
 | 
			
		||||
  for (int i = 0; i < hlo->operand_count(); i++) {
 | 
			
		||||
    llvm_ir::ElementGenerator gen;
 | 
			
		||||
    if (llvm::Value* param_tile_buffer = param_shmem_buffers[i]) {
 | 
			
		||||
      gen = [this, param_tile_buffer, x_loc,
 | 
			
		||||
             y_loc](llvm_ir::IrArray::Index index) {
 | 
			
		||||
        // TODO(jlebar): Add AA metadata to this load.  Tile buffers are
 | 
			
		||||
        // global variables, so LLVM's points-to analysis doesn't help us
 | 
			
		||||
        // much.  And we want the AA info to be present before address
 | 
			
		||||
        // spaces are inferred (which is pretty late in the pipeline), so
 | 
			
		||||
        // even if we had address-space-based AA in LLVM, it wouldn't help
 | 
			
		||||
        // us much here.
 | 
			
		||||
        return b_.CreateLoad(
 | 
			
		||||
            b_.CreateGEP(param_tile_buffer,
 | 
			
		||||
                         {index.GetConstantWithIndexType(0), x_loc, y_loc}),
 | 
			
		||||
            "tiled_buffer");
 | 
			
		||||
      };
 | 
			
		||||
    } else {
 | 
			
		||||
      const HloInstruction* operand = hlo->operand(i);
 | 
			
		||||
      gen = [this, operand, hlo](llvm_ir::IrArray::Index index) {
 | 
			
		||||
        return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_);
 | 
			
		||||
      };
 | 
			
		||||
    }
 | 
			
		||||
    fused_emitter.BindGenerator(hlo->fused_parameter(i), std::move(gen));
 | 
			
		||||
  }
 | 
			
		||||
  IrArray::Index untiled_index = GetUnnormalizedIndex(
 | 
			
		||||
      index, output_arrays[0].GetShape(), &b_, mapping_scheme);
 | 
			
		||||
  const llvm_ir::ElementGenerator& output_generator =
 | 
			
		||||
      fused_emitter.GetRootGenerator();
 | 
			
		||||
  llvm_ir::ElementGenerator output_generator =
 | 
			
		||||
      *fused_emitter.GetGenerator(hlo->fused_expression_root());
 | 
			
		||||
  llvm::Value* output_value = output_generator(untiled_index).ValueOrDie();
 | 
			
		||||
  if (hlo->IsMultiOutputFusion()) {
 | 
			
		||||
    DCHECK(output_value->getType()->isStructTy());
 | 
			
		||||
@ -3161,14 +3188,12 @@ void IrEmitterUnnested::EmitPrologueForReduction(
 | 
			
		||||
    llvm::Value* init_ir_value;
 | 
			
		||||
    const HloInstruction* init_value = reduce_inst->operand(1);
 | 
			
		||||
    if (unnested_hlo->opcode() == HloOpcode::kFusion) {
 | 
			
		||||
      FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
 | 
			
		||||
                                   &elemental_emitter);
 | 
			
		||||
      FusedIrEmitter fused_emitter(&elemental_emitter);
 | 
			
		||||
      BindFusionArguments(unnested_hlo, &fused_emitter);
 | 
			
		||||
 | 
			
		||||
      TF_CHECK_OK(init_value->Accept(&fused_emitter));
 | 
			
		||||
      init_ir_value =
 | 
			
		||||
          fused_emitter
 | 
			
		||||
              .GetGenerator(init_value)(IrArray::Index(b_.getInt32Ty()))
 | 
			
		||||
              .ValueOrDie();
 | 
			
		||||
      init_ir_value = (*fused_emitter.GetGenerator(init_value))(
 | 
			
		||||
                          IrArray::Index(b_.getInt32Ty()))
 | 
			
		||||
                          .ValueOrDie();
 | 
			
		||||
    } else {
 | 
			
		||||
      init_ir_value =
 | 
			
		||||
          GetIrArray(*init_value, *unnested_hlo)
 | 
			
		||||
@ -3507,21 +3532,21 @@ void IrEmitterUnnested::EmitTileElementForReduction(
 | 
			
		||||
      extra_output_gens;
 | 
			
		||||
  GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
 | 
			
		||||
                                     GetNestedComputer());
 | 
			
		||||
  FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
 | 
			
		||||
                               &elem_emitter);
 | 
			
		||||
  FusedIrEmitter fused_emitter(&elem_emitter);
 | 
			
		||||
 | 
			
		||||
  // Construct the ElementGenerator for each reduction and extra output in the
 | 
			
		||||
  // group of output instructions.
 | 
			
		||||
  if (unnested_hlo->opcode() == HloOpcode::kFusion) {
 | 
			
		||||
    TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter));
 | 
			
		||||
    BindFusionArguments(unnested_hlo, &fused_emitter);
 | 
			
		||||
 | 
			
		||||
    for (int i = 0, e = output_instructions.size(); i != e; ++i) {
 | 
			
		||||
      const HloInstruction* inst = output_instructions[i];
 | 
			
		||||
      ShapeIndex idx =
 | 
			
		||||
          CreateShapeIndexForOutputInstruction(*unnested_hlo, *inst);
 | 
			
		||||
      if (IsReductionFromOrToContiguousDimensions(*inst)) {
 | 
			
		||||
        input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0)));
 | 
			
		||||
        input_gens.push_back(*fused_emitter.GetGenerator(inst->operand(0)));
 | 
			
		||||
      } else {
 | 
			
		||||
        extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst),
 | 
			
		||||
        extra_output_gens.emplace_back(*fused_emitter.GetGenerator(inst),
 | 
			
		||||
                                       std::move(idx));
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
@ -4506,11 +4531,10 @@ void IrEmitterUnnested::EmitElementForInputFusibleSlices(
 | 
			
		||||
  std::vector<llvm::Value*> input_ir_values;
 | 
			
		||||
  GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
 | 
			
		||||
                                     GetNestedComputer());
 | 
			
		||||
  FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
 | 
			
		||||
                               &elem_emitter);
 | 
			
		||||
  TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter));
 | 
			
		||||
  FusedIrEmitter fused_emitter(&elem_emitter);
 | 
			
		||||
  BindFusionArguments(unnested_hlo, &fused_emitter);
 | 
			
		||||
  for (const HloInstruction* slice : slice_instructions) {
 | 
			
		||||
    auto input_generator = fused_emitter.GetGenerator(slice->operand(0));
 | 
			
		||||
    auto input_generator = *fused_emitter.GetGenerator(slice->operand(0));
 | 
			
		||||
    input_ir_values.push_back(input_generator(index).ValueOrDie());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -17,7 +17,6 @@ limitations under the License.
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
 | 
			
		||||
 | 
			
		||||
@ -190,9 +189,8 @@ Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,
 | 
			
		||||
//
 | 
			
		||||
// Emits a sequential loop if launch_dimensions is null.
 | 
			
		||||
static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
 | 
			
		||||
    HloInstruction* fusion,
 | 
			
		||||
    GeneratorForOperandIrArrays operand_arrays_generator,
 | 
			
		||||
    const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
 | 
			
		||||
    HloInstruction* fusion, const IrArray& fusion_output_array,
 | 
			
		||||
    FusedIrEmitter* fused_emitter,
 | 
			
		||||
    const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) {
 | 
			
		||||
  CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
 | 
			
		||||
  VLOG(2) << "EmitFusedDynamicUpdateSliceInPlace for "
 | 
			
		||||
@ -221,14 +219,14 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
 | 
			
		||||
      LayoutUtil::CopyLayoutBetweenShapes(fusion->shape(), &update_shape));
 | 
			
		||||
 | 
			
		||||
  // Create element generators for update and start_indices.
 | 
			
		||||
  FusedIrEmitter fused_emitter(std::move(operand_arrays_generator),
 | 
			
		||||
                               elemental_emitter);
 | 
			
		||||
  TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter));
 | 
			
		||||
  ElementGenerator update_array_generator = fused_emitter.GetGenerator(update);
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(ElementGenerator update_array_generator,
 | 
			
		||||
                      fused_emitter->GetGenerator(update));
 | 
			
		||||
 | 
			
		||||
  IndexGenerator start_indices_generator = [&](int64 index) {
 | 
			
		||||
    ElementGenerator element_generator =
 | 
			
		||||
        fused_emitter.GetGenerator(dynamic_update_slice->operand(2 + index));
 | 
			
		||||
  IndexGenerator start_indices_generator =
 | 
			
		||||
      [&](int64 index) -> StatusOr<llvm::Value*> {
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(
 | 
			
		||||
        ElementGenerator element_generator,
 | 
			
		||||
        fused_emitter->GetGenerator(dynamic_update_slice->operand(2 + index)));
 | 
			
		||||
    return element_generator(IrArray::Index(b->getInt64Ty()));
 | 
			
		||||
  };
 | 
			
		||||
  bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape());
 | 
			
		||||
@ -237,25 +235,21 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
 | 
			
		||||
      fusion_output_array, launch_dimensions, IrName(fusion), b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status EmitFusedDynamicUpdateSliceInPlace(
 | 
			
		||||
    HloInstruction* fusion,
 | 
			
		||||
    GeneratorForOperandIrArrays operand_arrays_generator,
 | 
			
		||||
    const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
 | 
			
		||||
    llvm::IRBuilder<>* b) {
 | 
			
		||||
Status EmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
 | 
			
		||||
                                          const IrArray& fusion_output_array,
 | 
			
		||||
                                          FusedIrEmitter* fused_emitter,
 | 
			
		||||
                                          llvm::IRBuilder<>* b) {
 | 
			
		||||
  return EmitFusedDynamicUpdateSliceInPlaceImpl(
 | 
			
		||||
      fusion, std::move(operand_arrays_generator), fusion_output_array,
 | 
			
		||||
      elemental_emitter,
 | 
			
		||||
      fusion, fusion_output_array, fused_emitter,
 | 
			
		||||
      /*launch_dimensions=*/nullptr, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status EmitParallelFusedDynamicUpdateSliceInPlace(
 | 
			
		||||
    HloInstruction* fusion,
 | 
			
		||||
    GeneratorForOperandIrArrays operand_arrays_generator,
 | 
			
		||||
    const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
 | 
			
		||||
    HloInstruction* fusion, const IrArray& fusion_output_array,
 | 
			
		||||
    FusedIrEmitter* fused_emitter,
 | 
			
		||||
    const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) {
 | 
			
		||||
  return EmitFusedDynamicUpdateSliceInPlaceImpl(
 | 
			
		||||
      fusion, std::move(operand_arrays_generator), fusion_output_array,
 | 
			
		||||
      elemental_emitter, &launch_dimensions, b);
 | 
			
		||||
      fusion, fusion_output_array, fused_emitter, &launch_dimensions, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace llvm_ir
 | 
			
		||||
 | 
			
		||||
@ -20,6 +20,7 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
 | 
			
		||||
 | 
			
		||||
// Utilities related to emitting LLVM IR for various HLO ops.
 | 
			
		||||
@ -71,18 +72,16 @@ Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,
 | 
			
		||||
// array-to-be-updated and output share the same buffer slice, emits
 | 
			
		||||
// (sequential) code for a fusion node that does the dynamic-update-slice in
 | 
			
		||||
// place.
 | 
			
		||||
Status EmitFusedDynamicUpdateSliceInPlace(
 | 
			
		||||
    HloInstruction* fusion,
 | 
			
		||||
    GeneratorForOperandIrArrays operand_arrays_generator,
 | 
			
		||||
    const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
 | 
			
		||||
    llvm::IRBuilder<>* b);
 | 
			
		||||
Status EmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
 | 
			
		||||
                                          const IrArray& fusion_output_array,
 | 
			
		||||
                                          FusedIrEmitter* fused_emitter,
 | 
			
		||||
                                          llvm::IRBuilder<>* b);
 | 
			
		||||
 | 
			
		||||
// Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with
 | 
			
		||||
// the given launch dimensions.
 | 
			
		||||
Status EmitParallelFusedDynamicUpdateSliceInPlace(
 | 
			
		||||
    HloInstruction* fusion,
 | 
			
		||||
    GeneratorForOperandIrArrays operand_arrays_generator,
 | 
			
		||||
    const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
 | 
			
		||||
    HloInstruction* fusion, const IrArray& fusion_output_array,
 | 
			
		||||
    FusedIrEmitter* fused_emitter,
 | 
			
		||||
    const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b);
 | 
			
		||||
 | 
			
		||||
}  // namespace llvm_ir
 | 
			
		||||
 | 
			
		||||
@ -114,25 +114,9 @@ Status FusedIrEmitter::HandleGetTupleElement(
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status FusedIrEmitter::HandleParameter(const HloInstruction* parameter) {
 | 
			
		||||
  indexed_generators_[parameter] =
 | 
			
		||||
      [=](const IrArray::Index& index) -> llvm::Value* {
 | 
			
		||||
    int64 param_num = parameter->parameter_number();
 | 
			
		||||
    if (param_shmem_buffers_.size() > param_num) {
 | 
			
		||||
      if (llvm::Value* param_tile_buffer = param_shmem_buffers_[param_num]) {
 | 
			
		||||
        // TODO(jlebar): Add AA metadata to this load.  Tile buffers are global
 | 
			
		||||
        // variables, so LLVM's points-to analysis doesn't help us much.  And we
 | 
			
		||||
        // want the AA info to be present before address spaces are inferred
 | 
			
		||||
        // (which is pretty late in the pipeline), so even if we had
 | 
			
		||||
        // address-space-based AA in LLVM, it wouldn't help us much here.
 | 
			
		||||
        return b_->CreateLoad(
 | 
			
		||||
            b_->CreateGEP(param_tile_buffer, {index.GetConstantWithIndexType(0),
 | 
			
		||||
                                              thread_id_x_, thread_id_y_}),
 | 
			
		||||
            "tiled_buffer");
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    return GetIrArrayForFusedParameter(param_num).EmitReadArrayElement(index,
 | 
			
		||||
                                                                       b_);
 | 
			
		||||
  };
 | 
			
		||||
  if (indexed_generators_.find(parameter) == indexed_generators_.end()) {
 | 
			
		||||
    return InvalidArgument("Unbound parameter: %s", parameter->ToString());
 | 
			
		||||
  }
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -157,22 +141,6 @@ Status FusedIrEmitter::HandleTuple(const HloInstruction* tuple) {
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status FusedIrEmitter::FinishVisit(const HloInstruction* root) {
 | 
			
		||||
  fused_root_ = root;
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
FusedIrEmitter::IndexedGenerator FusedIrEmitter::GetRootGenerator() const {
 | 
			
		||||
  CHECK_NE(nullptr, fused_root_)
 | 
			
		||||
      << "GetRootGenerator should be called after Accept.";
 | 
			
		||||
  return indexed_generators_.at(fused_root_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
FusedIrEmitter::IndexedGenerator FusedIrEmitter::GetGenerator(
 | 
			
		||||
    const HloInstruction* instruction) const {
 | 
			
		||||
  return indexed_generators_.at(instruction);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool FusedIrEmitter::IsFusedIrEmitterInefficient(
 | 
			
		||||
    const HloInstruction* consumer, const HloInstruction* producer) {
 | 
			
		||||
  if (consumer->opcode() != HloOpcode::kFusion) {
 | 
			
		||||
@ -189,4 +157,39 @@ bool FusedIrEmitter::IsFusedIrEmitterInefficient(
 | 
			
		||||
  return eval_producer.MaxCodeDuplicationTooHigh();
 | 
			
		||||
}
 | 
			
		||||

StatusOr<FusedIrEmitter::IndexedGenerator> FusedIrEmitter::GetGenerator(
    const HloInstruction* instruction) {
  std::vector<const HloInstruction*> stack;
  stack.push_back(instruction);
  while (!stack.empty()) {
    const HloInstruction* instr = stack.back();
    stack.pop_back();
    if (indexed_generators_.count(instr)) {
      continue;
    }
    for (const HloInstruction* operand : instr->operands()) {
      stack.push_back(operand);
    }
    switch (instr->opcode()) {
      case HloOpcode::kConstant:
        TF_RETURN_IF_ERROR(HandleConstant(instr));
        break;
      case HloOpcode::kGetTupleElement:
        TF_RETURN_IF_ERROR(HandleGetTupleElement(instr));
        break;
      case HloOpcode::kParameter:
        TF_RETURN_IF_ERROR(HandleParameter(instr));
        break;
      case HloOpcode::kTuple:
        TF_RETURN_IF_ERROR(HandleTuple(instr));
        break;
      default:
        TF_RETURN_IF_ERROR(DefaultAction(instr));
        break;
    }
    CHECK(indexed_generators_.count(instr));
  }
  return indexed_generators_.at(instruction);
}

}  // namespace xla

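To make the new FusedIrEmitter contract concrete (an illustration only, using toy types rather than the real XLA classes): callers bind a generator for every fused parameter up front and then ask for a generator of any instruction; GetGenerator walks the operand graph on demand, as in the worklist above, instead of requiring a prior visitor pass over the whole fused computation. A self-contained sketch under those assumptions:

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct Node {
  std::string name;
  std::vector<const Node*> operands;  // empty for parameters
};

using Generator = std::function<int(int)>;  // index -> value (stand-in for IR)

class ToyFusedEmitter {
 public:
  void BindGenerator(const Node* n, Generator g) {
    generators_[n] = std::move(g);
  }

  // Walks the operand graph with a worklist, creating a default generator for
  // every node that was not explicitly bound (the real emitter instead calls
  // the Handle*/DefaultAction emitters and fails on unbound parameters).
  Generator GetGenerator(const Node* n) {
    std::vector<const Node*> stack{n};
    while (!stack.empty()) {
      const Node* cur = stack.back();
      stack.pop_back();
      if (generators_.count(cur)) continue;
      for (const Node* op : cur->operands) stack.push_back(op);
      std::vector<const Node*> ops = cur->operands;
      generators_[cur] = [this, ops](int index) {
        int v = 0;  // default action: sum the operands' generated values
        for (const Node* op : ops) v += generators_.at(op)(index);
        return v;
      };
    }
    return generators_.at(n);
  }

 private:
  std::unordered_map<const Node*, Generator> generators_;
};

int main() {
  Node p0{"p0", {}}, p1{"p1", {}};
  Node add{"add", {&p0, &p1}};
  ToyFusedEmitter emitter;
  emitter.BindGenerator(&p0, [](int i) { return i; });       // parameter 0
  emitter.BindGenerator(&p1, [](int i) { return 10 * i; });  // parameter 1
  std::cout << emitter.GetGenerator(&add)(3) << "\n";        // prints 33
}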
@ -24,7 +24,6 @@ limitations under the License.
 | 
			
		||||
#include "absl/types/span.h"
 | 
			
		||||
#include "llvm/IR/IRBuilder.h"
 | 
			
		||||
#include "llvm/IR/Value.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
 | 
			
		||||
@ -51,47 +50,22 @@ namespace xla {
 | 
			
		||||
// created produces an LLVM struct with N elements, one for each element of the
 | 
			
		||||
// arrays in the tuple.  It follows that the arrays in the tuple must have the
 | 
			
		||||
// same length.
 | 
			
		||||
class FusedIrEmitter : public ConstDfsHloVisitorWithDefault {
 | 
			
		||||
class FusedIrEmitter {
 | 
			
		||||
 public:
 | 
			
		||||
  using IndexedGenerator = llvm_ir::ElementGenerator;
 | 
			
		||||
  using NonIndexedGenerator = std::function<StatusOr<llvm::Value*>()>;
 | 
			
		||||
  using GeneratorForOperandIrArrays =
 | 
			
		||||
      std::function<std::vector<llvm_ir::IrArray>()>;
 | 
			
		||||
 | 
			
		||||
  FusedIrEmitter(GeneratorForOperandIrArrays operand_arrays_generator,
 | 
			
		||||
                 ElementalIrEmitter* elemental_emitter,
 | 
			
		||||
                 llvm::Value* thread_id_x = nullptr,
 | 
			
		||||
                 llvm::Value* thread_id_y = nullptr,
 | 
			
		||||
                 absl::Span<llvm::Value* const> param_shmem_buffers = {})
 | 
			
		||||
      : operand_arrays_(),
 | 
			
		||||
        operand_arrays_generator_(std::move(operand_arrays_generator)),
 | 
			
		||||
        thread_id_x_(thread_id_x),
 | 
			
		||||
        thread_id_y_(thread_id_y),
 | 
			
		||||
        param_shmem_buffers_(param_shmem_buffers.begin(),
 | 
			
		||||
                             param_shmem_buffers.end()),
 | 
			
		||||
        elemental_emitter_(elemental_emitter),
 | 
			
		||||
  explicit FusedIrEmitter(ElementalIrEmitter* elemental_emitter)
 | 
			
		||||
      : elemental_emitter_(elemental_emitter),
 | 
			
		||||
        b_(elemental_emitter->b()),
 | 
			
		||||
        module_(elemental_emitter->module()) {}
 | 
			
		||||
 | 
			
		||||
  Status DefaultAction(const HloInstruction* hlo) override;
 | 
			
		||||
 | 
			
		||||
  Status HandleConstant(const HloInstruction* constant) override;
 | 
			
		||||
 | 
			
		||||
  Status HandleGetTupleElement(
 | 
			
		||||
      const HloInstruction* get_tuple_element) override;
 | 
			
		||||
 | 
			
		||||
  Status HandleParameter(const HloInstruction* parameter) override;
 | 
			
		||||
 | 
			
		||||
  // Emits the ir value for each element in the tuple.
 | 
			
		||||
  Status HandleTuple(const HloInstruction* tuple) override;
 | 
			
		||||
 | 
			
		||||
  Status FinishVisit(const HloInstruction* root) override;
 | 
			
		||||
 | 
			
		||||
  // Returns the generator function for the root of the fused computation.
 | 
			
		||||
  IndexedGenerator GetRootGenerator() const;
 | 
			
		||||
  void BindGenerator(const HloInstruction* hlo,
 | 
			
		||||
                     llvm_ir::ElementGenerator generator) {
 | 
			
		||||
    indexed_generators_[hlo] = std::move(generator);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Returns the generator function for the given instruction.
 | 
			
		||||
  IndexedGenerator GetGenerator(const HloInstruction* instruction) const;
 | 
			
		||||
  StatusOr<IndexedGenerator> GetGenerator(const HloInstruction* instruction);
 | 
			
		||||
 | 
			
		||||
  // Evaluates whether fusing 'producer' into 'consumer' might cause exponential
 | 
			
		||||
  // behavior in FusedIrEmitter. We currently can have exponential time/memory
 | 
			
		||||
@ -101,40 +75,20 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault {
 | 
			
		||||
  static bool IsFusedIrEmitterInefficient(const HloInstruction* consumer,
 | 
			
		||||
                                          const HloInstruction* producer);
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
  // Returns the IrArrays for the fusion instruction operands.
 | 
			
		||||
  llvm_ir::IrArray& GetIrArrayForFusedParameter(int64 parameter_number) {
 | 
			
		||||
    if (!operand_arrays_.has_value()) {
 | 
			
		||||
      operand_arrays_ = operand_arrays_generator_();
 | 
			
		||||
    }
 | 
			
		||||
    return operand_arrays_.value()[parameter_number];
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  llvm::Value* GetBasePointerForFusedParameter(int64 parameter_number) {
 | 
			
		||||
    return GetIrArrayForFusedParameter(parameter_number).GetBasePointer();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  // IrArrays for the fusion instruction operands, whose base addresses are the
 | 
			
		||||
  // base address of the corresponding parameters in the fused computation.
 | 
			
		||||
  absl::optional<std::vector<llvm_ir::IrArray>> operand_arrays_;
 | 
			
		||||
  GeneratorForOperandIrArrays operand_arrays_generator_;
 | 
			
		||||
  Status DefaultAction(const HloInstruction* hlo);
 | 
			
		||||
 | 
			
		||||
  // The x coordinate within a tile.
 | 
			
		||||
  llvm::Value* thread_id_x_;
 | 
			
		||||
  Status HandleConstant(const HloInstruction* constant);
 | 
			
		||||
 | 
			
		||||
  // The y coordinate within a tile.
 | 
			
		||||
  llvm::Value* thread_id_y_;
 | 
			
		||||
  Status HandleGetTupleElement(const HloInstruction* get_tuple_element);
 | 
			
		||||
 | 
			
		||||
  // Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr
 | 
			
		||||
  // if the parameter is not tiled.
 | 
			
		||||
  std::vector<llvm::Value*> param_shmem_buffers_;
 | 
			
		||||
  Status HandleParameter(const HloInstruction* parameter);
 | 
			
		||||
 | 
			
		||||
  // Emits the ir value for each element in the tuple.
 | 
			
		||||
  Status HandleTuple(const HloInstruction* tuple);
 | 
			
		||||
 | 
			
		||||
  ElementalIrEmitter* elemental_emitter_;
 | 
			
		||||
 | 
			
		||||
  // This member will be set by FinishVisit and used in GetRootGenerator.
 | 
			
		||||
  const HloInstruction* fused_root_ = nullptr;
 | 
			
		||||
 | 
			
		||||
  // Borrowed
 | 
			
		||||
  llvm::IRBuilder<>* b_;
 | 
			
		||||
  llvm::Module* module_;
 | 
			
		||||
@ -145,12 +99,6 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault {
 | 
			
		||||
  std::unordered_map<const HloInstruction*, IndexedGenerator>
 | 
			
		||||
      indexed_generators_;
 | 
			
		||||
 | 
			
		||||
  // Map from tuple-result-producing GetTupleElement instructions to functions
 | 
			
		||||
  // that generate the base pointers for the output elements. This is used to
 | 
			
		||||
  // support the translation of nested GetTupleElement instructions.
 | 
			
		||||
  std::unordered_map<const HloInstruction*, NonIndexedGenerator>
 | 
			
		||||
      non_indexed_generators_;
 | 
			
		||||
 | 
			
		||||
  // Cache of generated values, lest we regenerate an element of a node with
 | 
			
		||||
  // multiple outgoing edges
 | 
			
		||||
  absl::flat_hash_map<
 | 
			
		||||
 | 
			
		||||
@ -521,8 +521,7 @@ XLA_TEST_F(ConcatTest, ConcatDeeplyNested) {
 | 
			
		||||
  ComputeAndCompareR1<float>(&builder, expected, {a_data.get()});
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO(b/169314478): Enable the test when the slow compilation is fixed.
 | 
			
		||||
XLA_TEST_F(ConcatTestHlo, DISABLED_ConcatWithBitcast) {
 | 
			
		||||
XLA_TEST_F(ConcatTestHlo, ConcatWithBitcast) {
 | 
			
		||||
  auto module = ParseAndReturnVerifiedModule(R"(
 | 
			
		||||
HloModule jit_broken.874
 | 
			
		||||
 | 
			
		||||
@ -762,7 +761,7 @@ ENTRY jit_broken.874 {
 | 
			
		||||
  auto input_array = absl::make_unique<Array2D<float>>(4, 2);
 | 
			
		||||
  input_array->FillUnique(1.0f);
 | 
			
		||||
  auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
 | 
			
		||||
  EXPECT_TRUE(RunAndCompare(std::move(module), {&input}, absl::nullopt));
 | 
			
		||||
  EXPECT_TRUE(RunAndCompare(std::move(module), {&input}, error_spec_));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Describes a binary rank-2 concatenation test.
 | 
			
		||||
 | 
			
		||||
@ -354,8 +354,11 @@ void BaseCollectiveExecutor::CompleteParamsAsync(
 | 
			
		||||
  // timeout callback executes, done_safe will become a no-op and the timeout
 | 
			
		||||
  // callback is responsible for invoking done() at the end.
 | 
			
		||||
  const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
 | 
			
		||||
  auto done_safe = [this, is_callback_called, cancel_mgr,
 | 
			
		||||
  auto trace_id =
 | 
			
		||||
      profiler::TraceMe::ActivityStart("CollectiveExecutor::CompleteParams");
 | 
			
		||||
  auto done_safe = [this, is_callback_called, cancel_mgr, trace_id,
 | 
			
		||||
                    done](const Status& s) {
 | 
			
		||||
    profiler::TraceMe::ActivityEnd(trace_id);
 | 
			
		||||
    bool called = is_callback_called->exchange(true);
 | 
			
		||||
    if (!called) {
 | 
			
		||||
      if (!s.ok() && !IsCancelled(cancel_mgr)) {
 | 
			
		||||
 | 
			
		||||
@ -2587,11 +2587,9 @@ TEST(DirectSessionTest,

// A simple benchmark for the overhead of `DirectSession::Run()` calls
// with varying numbers of feeds/fetches.
void FeedFetchBenchmarkHelper(int iters, int num_feeds, bool use_make_callable,
                              int inter_op_threads,
void FeedFetchBenchmarkHelper(::testing::benchmark::State& state, int num_feeds,
                              bool use_make_callable, int inter_op_threads,
                              bool use_single_threaded_executor) {
  testing::StopTiming();

  Tensor value(DT_FLOAT, TensorShape());
  value.flat<float>()(0) = 37.0;

@ -2643,13 +2641,11 @@ void FeedFetchBenchmarkHelper(int iters, int num_feeds, bool use_make_callable,
    }
    TF_CHECK_OK(session->MakeCallable(callable_options, &handle));

    testing::StartTiming();
    for (int i = 0; i < iters; ++i) {
    for (auto s : state) {
      std::vector<Tensor> output_values;
      TF_CHECK_OK(
          session->RunCallable(handle, input_tensors, &output_values, nullptr));
    }
    testing::StopTiming();
  } else {
    {
      // NOTE(mrry): Ignore the first run, which will incur the graph
@ -2661,32 +2657,40 @@ void FeedFetchBenchmarkHelper(int iters, int num_feeds, bool use_make_callable,
      std::vector<Tensor> output_values;
      TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
    }
    testing::StartTiming();
    for (int i = 0; i < iters; ++i) {

    for (auto s : state) {
      std::vector<Tensor> output_values;
      TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
    }
    testing::StopTiming();
  }
}

void BM_FeedFetch(int iters, int num_feeds) {
  FeedFetchBenchmarkHelper(iters, num_feeds, /* use_make_callable */ false,
void BM_FeedFetch(::testing::benchmark::State& state) {
  const int num_feeds = state.range(0);

  FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ false,
                           /* inter_op_threads */ 0,
                           /* use_single_threaded_executor */ false);
}
void BM_FeedFetchCallable(int iters, int num_feeds) {
  FeedFetchBenchmarkHelper(iters, num_feeds, /* use_make_callable */ true,
void BM_FeedFetchCallable(::testing::benchmark::State& state) {
  const int num_feeds = state.range(0);

  FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
                           /* inter_op_threads */ 0,
                           /* use_single_threaded_executor */ false);
}
void BM_FeedFetchCallableSingleThread(int iters, int num_feeds) {
  FeedFetchBenchmarkHelper(iters, num_feeds, /* use_make_callable */ true,
void BM_FeedFetchCallableSingleThread(::testing::benchmark::State& state) {
  const int num_feeds = state.range(0);

  FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
                           /* inter_op_threads */ -1,
                           /* use_single_threaded_executor */ false);
}
void BM_FeedFetchCallableSingleThreadExecutor(int iters, int num_feeds) {
  FeedFetchBenchmarkHelper(iters, num_feeds, /* use_make_callable */ true,
void BM_FeedFetchCallableSingleThreadExecutor(
    ::testing::benchmark::State& state) {
  const int num_feeds = state.range(0);

  FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
                           /* inter_op_threads */ -1,
                           /* use_single_threaded_executor */ true);
}

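These hunks migrate the benchmarks from the legacy `void BM_Foo(int iters, ...)` signature to `::testing::benchmark::State&`: per-benchmark arguments come from `state.range(n)`, the timed region is the `for (auto s : state)` loop (so the explicit StartTiming/StopTiming calls disappear), and counters are reported through the state object. A minimal sketch of the pattern under those assumptions; `BM_Example` is hypothetical, not part of this commit:

// Sketch of the new-style benchmark used throughout these hunks.
void BM_Example(::testing::benchmark::State& state) {
  const int num_feeds = state.range(0);  // value supplied by ->Arg()/->ArgPair()
  // Setup outside the loop is untimed; no StopTiming()/StartTiming() needed.
  for (auto s : state) {
    // Timed body: exactly one benchmark iteration per pass.
  }
  state.SetItemsProcessed(num_feeds * state.iterations());
}
BENCHMARK(BM_Example)->Arg(1)->Arg(10);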
@ -378,6 +378,12 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) {
  } else if (!input_def.type_attr().empty() &&
             !input_def.number_attr().empty()) {
    InferSingleTypeInputListAttrs(input_def, inputs_[start]->dtype, num_inputs);
  } else if (!input_def.number_attr().empty()) {
    if (inference_attrs_.find(input_def.number_attr()) ==
        inference_attrs_.end()) {
      MutableAttrs()->Set(input_def.number_attr(), num_inputs);
      inference_attrs_.insert(input_def.number_attr());
    }
  } else {
    return errors::InvalidArgument("Invalid input list definition");
  }

@ -69,8 +69,8 @@ class TestEnv {
  Device* cpu_device_;
};

void BM_CreateGraph(int iters) {
  for (int i = 0; i < iters; ++i) {
void BM_CreateGraph(::testing::benchmark::State& state) {
  for (auto s : state) {
    Scope root = Scope::NewRootScope();
    auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
    auto M = ops::MatMul(root, C, C);
@ -79,8 +79,7 @@ void BM_CreateGraph(int iters) {
}
BENCHMARK(BM_CreateGraph);

void BM_RunGraph(int iters) {
  tensorflow::testing::StopTiming();
void BM_RunGraph(::testing::benchmark::State& state) {
  Scope root = Scope::NewRootScope();
  auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
  auto M = ops::MatMul(root, C, C);
@ -89,28 +88,24 @@ void BM_RunGraph(int iters) {
  opts.config.set_intra_op_parallelism_threads(1);
  ClientSession sess(root, opts);
  std::vector<Tensor> outputs;
  tensorflow::testing::StartTiming();
  for (int i = 0; i < iters; ++i) {
  for (auto s : state) {
    outputs.clear();
    TF_CHECK_OK(sess.Run({M}, &outputs));
  }
}
BENCHMARK(BM_RunGraph);

void BM_CreateAndDestroySession(int iters) {
  tensorflow::testing::StopTiming();
void BM_CreateAndDestroySession(::testing::benchmark::State& state) {
  Scope root = Scope::NewRootScope();
  auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
  auto M = ops::MatMul(root, C, C);
  tensorflow::testing::StartTiming();
  for (int i = 0; i < iters; ++i) {
  for (auto s : state) {
    ClientSession sess(root);
  }
}
BENCHMARK(BM_CreateAndDestroySession);

void BM_KernelAndDeviceInit(int iters) {
  tensorflow::testing::StopTiming();
void BM_KernelAndDeviceInit(::testing::benchmark::State& state) {
  NodeDef ndef(AttrBuilder("MatMul")
                   .Set("T", DT_FLOAT)
                   .Set("transpose_a", false)
@ -120,15 +115,13 @@ void BM_KernelAndDeviceInit(int iters) {
  TestEnv env;
  KernelAndDeviceOp k(nullptr, false, env.function_library_runtime(), nullptr,
                      nullptr, env.cpu_device());
  tensorflow::testing::StartTiming();
  for (int i = 0; i < iters; ++i) {
  for (auto s : state) {
    TF_CHECK_OK(k.Init({}, ndef, nullptr));
  }
}
BENCHMARK(BM_KernelAndDeviceInit);

void BM_KernelAndDeviceRun(int iters) {
  tensorflow::testing::StopTiming();
void BM_KernelAndDeviceRun(::testing::benchmark::State& state) {
  Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor());
  gtl::InlinedVector<TensorValue, 4> inputs;
  inputs.push_back(TensorValue(&t));
@ -145,8 +138,7 @@ void BM_KernelAndDeviceRun(int iters) {
                      nullptr, env.cpu_device());
  TF_CHECK_OK(k.Init({}, ndef, nullptr));
  const EagerKernelArgs args(std::move(inputs));
  tensorflow::testing::StartTiming();
  for (int i = 0; i < iters; ++i) {
  for (auto s : state) {
    TF_CHECK_OK(k.Run(nullptr, args, &outputs, nullptr, absl::nullopt));
  }
}

@ -433,11 +433,10 @@ TEST_F(ExecutorTest, NoInputTensors) {
// Create a graph that is 'depth' deep. At each level, fan-in and fan-out a
// maximum of 'width' nodes. All nodes are no-ops and all dependencies are
// control dependencies.
static void BM_executor(int iters, int width, int depth) {
  testing::StopTiming();
#ifdef PLATFORM_GOOGLE
  BenchmarkUseRealTime();
#endif  // PLATFORM_GOOGLE
static void BM_executor(::testing::benchmark::State& state) {
  const int width = state.range(0);
  const int depth = state.range(1);

  Graph* g = new Graph(OpRegistry::Global());
  random::PhiloxRandom philox(1729, 17);
  random::SimplePhilox rand(&philox);
@ -466,30 +465,29 @@ static void BM_executor(int iters, int width, int depth) {
      ++cur;
    }
  }
#ifdef PLATFORM_GOOGLE
  SetBenchmarkLabel(strings::StrCat("Nodes = ", cur));
  SetBenchmarkItemsProcessed(cur * static_cast<int64>(iters));
#endif  // PLATFORM_GOOGLE

  FixupSourceAndSinkEdges(g);
  testing::StartTiming();
  test::Benchmark("cpu", g).Run(iters);
  test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state);

  state.SetLabel(strings::StrCat("Nodes = ", cur));
  state.SetItemsProcessed(cur * static_cast<int64>(state.iterations()));
}

// Tall skinny graphs
BENCHMARK(BM_executor)->ArgPair(16, 1024);
BENCHMARK(BM_executor)->ArgPair(32, 8192);
BENCHMARK(BM_executor)->UseRealTime()->ArgPair(16, 1024);
BENCHMARK(BM_executor)->UseRealTime()->ArgPair(32, 8192);

// Short fat graphs
BENCHMARK(BM_executor)->ArgPair(1024, 16);
BENCHMARK(BM_executor)->ArgPair(8192, 32);
BENCHMARK(BM_executor)->UseRealTime()->ArgPair(1024, 16);
BENCHMARK(BM_executor)->UseRealTime()->ArgPair(8192, 32);

// Tall fat graph
BENCHMARK(BM_executor)->ArgPair(1024, 1024);
BENCHMARK(BM_executor)->UseRealTime()->ArgPair(1024, 1024);

static void BM_const_identity(::testing::benchmark::State& state) {
  const int width = state.range(0);
  const int outputs_per_const = state.range(1);

static void BM_const_identity(int iters, int width, int outputs_per_const) {
#ifdef PLATFORM_GOOGLE
  BenchmarkUseRealTime();
#endif  // PLATFORM_GOOGLE
  Graph* g = new Graph(OpRegistry::Global());
  for (int i = 0; i < width; ++i) {
    Tensor i_t(i);
@ -499,23 +497,21 @@ static void BM_const_identity(int iters, int width, int outputs_per_const) {
    }
  }
  FixupSourceAndSinkEdges(g);
#ifdef PLATFORM_GOOGLE
  SetBenchmarkLabel(
      strings::StrCat("Nodes = ", (1 + outputs_per_const) * width));
  SetBenchmarkItemsProcessed((1 + outputs_per_const) * width *
                             static_cast<int64>(iters));
#endif  // PLATFORM_GOOGLE
  test::Benchmark("cpu", g).Run(iters);
  test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state);
  state.SetLabel(strings::StrCat("Nodes = ", (1 + outputs_per_const) * width));
  state.SetItemsProcessed((1 + outputs_per_const) * width *
                          static_cast<int64>(state.iterations()));
}

// Graph with actual op execution.
BENCHMARK(BM_const_identity)->ArgPair(1, 1);
BENCHMARK(BM_const_identity)->ArgPair(1, 100);
BENCHMARK(BM_const_identity)->ArgPair(100, 1);
BENCHMARK(BM_const_identity)->ArgPair(100, 100);
BENCHMARK(BM_const_identity)
    ->UseRealTime()
    ->ArgPair(1, 1)
    ->ArgPair(1, 100)
    ->ArgPair(100, 1)
    ->ArgPair(100, 100);

static void BM_FeedInputFetchOutput(int iters) {
  testing::StopTiming();
static void BM_FeedInputFetchOutput(::testing::benchmark::State& state) {
  Graph* g = new Graph(OpRegistry::Global());
  // z = x + y: x and y are provided as benchmark inputs.  z is the
  // output of the benchmark.  Conceptually, the caller is ALICE, the
@ -531,13 +527,10 @@ static void BM_FeedInputFetchOutput(int iters) {

  Tensor val(DT_FLOAT, TensorShape({}));
  val.scalar<float>()() = 3.14;
#ifdef PLATFORM_GOOGLE
  SetBenchmarkItemsProcessed(static_cast<int64>(iters));
#endif  // PLATFORM_GOOGLE
  FixupSourceAndSinkEdges(g);
  testing::StartTiming();
  test::Benchmark("cpu", g).RunWithRendezvousArgs({{x_key, val}, {y_key, val}},
                                                  {z_key}, iters);
  test::Benchmark("cpu", g, /*old_benchmark_api=*/false)
      .RunWithRendezvousArgs({{x_key, val}, {y_key, val}}, {z_key}, state);
  state.SetItemsProcessed(static_cast<int64>(state.iterations()));
}
BENCHMARK(BM_FeedInputFetchOutput);

@ -549,9 +542,8 @@ BENCHMARK(BM_FeedInputFetchOutput);
//
// ...using the functional `WhileOp` (if `lower` is false) or the
// `Switch`/`Merge`-style of control flow (if `lower` is true).
static void BM_WhileLoopHelper(int iters, int loop_iters, int loop_vars,
                               bool lower) {
  testing::StopTiming();
static void BM_WhileLoopHelper(::testing::benchmark::State& state,
                               int loop_iters, int loop_vars, bool lower) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  // Add test functions for cond and body.
@ -661,12 +653,15 @@ static void BM_WhileLoopHelper(int iters, int loop_iters, int loop_vars,
  }

  FixupSourceAndSinkEdges(graph.get());
  testing::StartTiming();
  test::Benchmark("cpu", graph.release()).Run(iters);
  test::Benchmark("cpu", graph.release(), /*old_benchmark_api=*/false)
      .Run(state);
}

static void BM_LoweredWhileLoop(int iters, int loop_iters, int loop_vars) {
  BM_WhileLoopHelper(iters, loop_iters, loop_vars, /* lower= */ true);
static void BM_LoweredWhileLoop(::testing::benchmark::State& state) {
  const int loop_iters = state.range(0);
  const int loop_vars = state.range(1);

  BM_WhileLoopHelper(state, loop_iters, loop_vars, /* lower= */ true);
}
BENCHMARK(BM_LoweredWhileLoop)
    ->ArgPair(0, 1)
@ -680,8 +675,11 @@ BENCHMARK(BM_LoweredWhileLoop)
    ->ArgPair(100, 100)
    ->ArgPair(1000, 100);

static void BM_FunctionalWhileLoop(int iters, int loop_iters, int loop_vars) {
  BM_WhileLoopHelper(iters, loop_iters, loop_vars, /* lower= */ false);
static void BM_FunctionalWhileLoop(::testing::benchmark::State& state) {
  const int loop_iters = state.range(0);
  const int loop_vars = state.range(1);

  BM_WhileLoopHelper(state, loop_iters, loop_vars, /* lower= */ false);
}
BENCHMARK(BM_FunctionalWhileLoop)
    ->ArgPair(0, 1)

@ -931,7 +931,6 @@ TEST(SessionTest, InvalidOpInputName) {
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
      attr { key: '_kernel' value { s: 'eigen' } }
    }
  )",
                     "Illegal op input name");
@ -950,7 +949,6 @@ TEST(SessionTest, InvalidOpInputName) {
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
      attr { key: '_kernel' value { s: 'eigen' } }
    }
  )",
                     "Illegal op input name");
@ -969,7 +967,6 @@ TEST(SessionTest, InvalidOpInputName) {
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
      attr { key: '_kernel' value { s: 'eigen' } }
    }
  )",
                     "Illegal op input name");
@ -988,7 +985,6 @@ TEST(SessionTest, InvalidOpInputName) {
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
      attr { key: '_kernel' value { s: 'eigen' } }
    }
  )",
                     "Illegal op input name");
@ -1026,7 +1022,6 @@ TEST(SessionTest, ExtendValidation) {
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
      attr { key: '_kernel' value { s: 'eigen' } }
    }
  )",
                                                  &extension);
@ -1043,7 +1038,6 @@ TEST(SessionTest, ExtendValidation) {
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
      attr { key: '_kernel' value { s: 'eigen' } }
    }
  )",
                                                  &extension);
@ -1057,7 +1051,6 @@ TEST(SessionTest, ExtendValidation) {
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
      attr { key: '_kernel' value { s: 'eigen' } }
    }
  )",
                                                  &extension);

@ -221,14 +221,16 @@ TEST(CustomAllocatorAttributes, TestSetterAndGetter) {
  EXPECT_FALSE(HasDeviceAllocatorAttribute(AllocatorAttributes()));
}

static void BM_Allocation(int iters, int arg) {
static void BM_Allocation(::testing::benchmark::State& state) {
  const int arg = state.range(0);

  Allocator* a = cpu_allocator();
  // Exercise a few different allocation sizes
  std::vector<int> sizes = {256, 4096, 16384, 524288, 512, 1048576};
  int size_index = 0;

  if (arg) EnableCPUAllocatorStats();
  while (--iters > 0) {
  for (auto s : state) {
    int bytes = sizes[size_index++ % sizes.size()];
    void* p = a->AllocateRaw(1, bytes);
    a->DeallocateRaw(p);

@ -39,60 +39,60 @@ TEST(Bfloat16Test, Conversion) {
  }
}

static void BM_FloatToBFloat16(int iters) {
  testing::StopTiming();
void BM_FloatToBFloat16(::testing::benchmark::State& state) {
  static const int N = 32 << 20;
  const int64 tot = static_cast<int64>(iters) * N;
  testing::ItemsProcessed(tot);
  testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));

  float* inp = new float[N];
  bfloat16* out = new bfloat16[N];

  testing::StartTiming();
  while (iters--) {
  for (auto s : state) {
    FloatToBFloat16(inp, out, N);
  }

  const int64 tot = static_cast<int64>(state.iterations()) * N;
  state.SetItemsProcessed(tot);
  state.SetBytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));

  delete[] inp;
  delete[] out;
}
BENCHMARK(BM_FloatToBFloat16);

static void BM_RoundFloatToBFloat16(int iters) {
  testing::StopTiming();
void BM_RoundFloatToBFloat16(::testing::benchmark::State& state) {
  static const int N = 32 << 20;
  const int64 tot = static_cast<int64>(iters) * N;
  testing::ItemsProcessed(tot);
  testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));

  float* inp = new float[N];
  bfloat16* out = new bfloat16[N];

  testing::StartTiming();
  while (iters--) {
  for (auto s : state) {
    RoundFloatToBFloat16(inp, out, N);
    tensorflow::testing::DoNotOptimize(inp);
    tensorflow::testing::DoNotOptimize(out);
  }

  const int64 tot = static_cast<int64>(state.iterations()) * N;
  state.SetItemsProcessed(tot);
  state.SetBytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));

  delete[] inp;
  delete[] out;
}
BENCHMARK(BM_RoundFloatToBFloat16);

static void BM_BFloat16ToFloat(int iters) {
  testing::StopTiming();
void BM_BFloat16ToFloat(::testing::benchmark::State& state) {
  static const int N = 32 << 20;
  const int64 tot = static_cast<int64>(iters) * N;
  testing::ItemsProcessed(tot);
  testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));

  bfloat16* inp = new bfloat16[N];
  float* out = new float[N];

  testing::StartTiming();
  while (iters--) {
  for (auto s : state) {
    BFloat16ToFloat(inp, out, N);
  }

  const int64 tot = static_cast<int64>(state.iterations()) * N;
  state.SetItemsProcessed(tot);
  state.SetBytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));

  delete[] inp;
  delete[] out;
}

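The RoundFloatToBFloat16 hunk also wraps the buffers in `tensorflow::testing::DoNotOptimize(...)` inside the timed loop, which stops the compiler from proving the converted output unused and deleting the work being measured. A small hedged sketch of the same guard; the `Convert` routine is hypothetical and only the DoNotOptimize/SetBytesProcessed calls mirror this commit:

// Sketch: keep benchmarked work observable so the optimizer cannot drop it.
#include <cstdint>
#include <vector>

void BM_Convert(::testing::benchmark::State& state) {
  std::vector<float> in(1 << 20, 1.0f);
  std::vector<float> out(in.size());
  for (auto s : state) {
    Convert(in.data(), out.data(), in.size());       // hypothetical routine
    tensorflow::testing::DoNotOptimize(in.data());   // inputs stay "live"
    tensorflow::testing::DoNotOptimize(out.data());  // results count as used
  }
  state.SetBytesProcessed(static_cast<int64_t>(state.iterations()) *
                          static_cast<int64_t>(in.size()) * sizeof(float) * 2);
}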
@ -406,7 +406,7 @@ XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
TEST(TFunc, WXPlusB) {
  auto expect = R"P(
WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) {
  mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x)
  mm = MatMul[T=$T, transpose_a=false, transpose_b=false](w, x)
  y = Add[T=$T](mm:product:0, b)
  return y = y:z:0
}

@ -346,10 +346,7 @@ FunctionDef WXPlusB() {
      {{{"mm"},
        "MatMul",
        {"w", "x"},
        {{"T", "$T"},
         {"transpose_a", false},
         {"transpose_b", false},
         {"_kernel", "eigen"}}},
        {{"T", "$T"}, {"transpose_a", false}, {"transpose_b", false}}},
       {{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}});
}


@ -1002,9 +1002,9 @@ TEST_F(LabelTest, Duplicate) {
                error::INVALID_ARGUMENT);
}

void BM_InputRangeHelper(int iters, const NodeDef& node_def,
                         const char* input_name, int expected_start,
                         int expected_stop) {
void BM_InputRangeHelper(::testing::benchmark::State& state,
                         const NodeDef& node_def, const char* input_name,
                         int expected_start, int expected_stop) {
  Status status;
  auto device = absl::make_unique<DummyDevice>(Env::Default());

@ -1013,24 +1013,20 @@ void BM_InputRangeHelper(int iters, const NodeDef& node_def,
                                              TF_GRAPH_DEF_VERSION, &status));
  TF_CHECK_OK(status);

  testing::StartTiming();
  for (int i = 0; i < iters; ++i) {
  for (auto s : state) {
    int start;
    int stop;
    TF_CHECK_OK(op->InputRange(input_name, &start, &stop));
    EXPECT_EQ(expected_start, start);
    EXPECT_EQ(expected_stop, stop);
  }
  testing::StopTiming();
}

REGISTER_KERNEL_BUILDER(Name("ConcatV2").Device(DEVICE_CPU), DummyKernel);
REGISTER_KERNEL_BUILDER(Name("Select").Device(DEVICE_CPU), DummyKernel);
REGISTER_KERNEL_BUILDER(Name("MatMul").Device(DEVICE_CPU), DummyKernel);

void BM_ConcatInputRange(int iters) {
  testing::StopTiming();

void BM_ConcatInputRange(::testing::benchmark::State& state) {
  // Create a ConcatV2 NodeDef with 4 inputs (plus the axis).
  NodeDef node_def;
  node_def.set_name("concat-op");
@ -1048,12 +1044,10 @@ void BM_ConcatInputRange(int iters) {
    node_def.add_input(strings::StrCat("a:", i));
  }

  BM_InputRangeHelper(iters, node_def, "values", 0, 4);
  BM_InputRangeHelper(state, node_def, "values", 0, 4);
}

void BM_SelectInputRange(int iters) {
  testing::StopTiming();

void BM_SelectInputRange(::testing::benchmark::State& state) {
  // Create a Select NodeDef with 3 inputs.
  NodeDef node_def;
  node_def.set_name("select-op");
@ -1065,11 +1059,11 @@ void BM_SelectInputRange(int iters) {
    node_def.add_input(strings::StrCat("a:", i));
  }

  BM_InputRangeHelper(iters, node_def, "condition", 0, 1);
  BM_InputRangeHelper(state, node_def, "condition", 0, 1);
}

void BM_TraceString(const int iters, const int verbose) {
  testing::StopTiming();
void BM_TraceString(::testing::benchmark::State& state) {
  const int verbose = state.range(0);

  // Create a MatMul NodeDef with 2 inputs.
  NodeDef node_def;
@ -1103,11 +1097,9 @@ void BM_TraceString(const int iters, const int verbose) {
  params.inputs = &inputs;
  auto ctx = absl::make_unique<OpKernelContext>(&params);

  testing::StartTiming();
  for (int i = 0; i < iters; ++i) {
  for (auto s : state) {
    auto trace = op->TraceString(*ctx, verbose);
  }
  testing::StopTiming();
}

BENCHMARK(BM_ConcatInputRange);

@ -434,83 +434,89 @@ TEST_F(LocalRendezvousTest, TransferDummyDeviceContext) {
  args1.device_context->Unref();
}

void BM_SendRecv(int iters) {
void BM_SendRecv(::testing::benchmark::State& state) {
  Rendezvous* rendez = NewLocalRendezvous();
  Tensor orig = V("val");
  Tensor val(DT_STRING, TensorShape({}));
  bool is_dead = false;
  Rendezvous::Args args;
  if (iters > 0) {
    while (iters--) {
      TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead));
      TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &val, &is_dead));
    }
    CHECK_EQ(V(val), V(orig));

  for (auto s : state) {
    TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead));
    TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &val, &is_dead));
  }
  CHECK_EQ(V(val), V(orig));

  rendez->Unref();
}
BENCHMARK(BM_SendRecv);

void BM_RecvSend(int iters) {
void BM_RecvSend(::testing::benchmark::State& state) {
  Rendezvous* rendez = NewLocalRendezvous();
  Tensor orig = V("val");
  Tensor val(DT_STRING, TensorShape({}));
  bool is_dead = false;
  Rendezvous::Args args;
  if (iters > 0) {
    while (iters--) {
      bool received = false;
      rendez->RecvAsync(
          KeyFoo(), args,
          [&val, &received](const Status& s, const Rendezvous::Args& send_args,
                            const Rendezvous::Args& recv_args,
                            const Tensor& tensor, bool is_dead) {
            val = tensor;
            received = true;
          });
      TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead));
      CHECK(received);
    }
    CHECK_EQ(V(val), V(orig));

  for (auto s : state) {
    bool received = false;
    rendez->RecvAsync(
        KeyFoo(), args,
        [&val, &received](const Status& /*s*/,
                          const Rendezvous::Args& /*send_args*/,
                          const Rendezvous::Args& /*recv_args*/,
                          const Tensor& tensor, bool /*is_dead*/) {
          val = tensor;
          received = true;
        });
    TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead));
    CHECK(received);
  }
  CHECK_EQ(V(val), V(orig));

  rendez->Unref();
}
BENCHMARK(BM_RecvSend);

void BM_PingPong(int iters) {
  CHECK_GT(iters, 0);
void BM_PingPong(::testing::benchmark::State& state) {
  const int messages_count = state.range(0);
  auto* cm = new CancellationManager();
  thread::ThreadPool* pool = new thread::ThreadPool(Env::Default(), "test", 1);

  // The main thread sends "foo" for iters times and receives "bar"
  // for iters times.  The other thread sends "bar" for iters times
  // and receives "foo" for iters times.
  Rendezvous* rendez = NewLocalRendezvous();
  pool->Schedule([rendez, iters]() {
    Tensor bar = V("bar");
    Tensor foo(DT_STRING, TensorShape({}));
  // Benchmark loop
  // In each iteration:
  // The main thread sends "foo" for messages_count times and receives "bar"
  // for messages_count times.  The other thread sends "bar" for
  // messages_count times and receives "foo" for messages_count times.
  for (auto s : state) {
    Rendezvous* rendez = NewLocalRendezvous();
    pool->Schedule([rendez, messages_count]() {
      Tensor bar = V("bar");
      Tensor foo(DT_STRING, TensorShape({}));
      bool is_dead = false;
      Rendezvous::Args args;
      for (int i = 0; i < messages_count; ++i) {
        TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &foo, &is_dead));
        TF_CHECK_OK(rendez->Send(KeyBar(), args, bar, is_dead));
      }
      CHECK_EQ("foo", V(foo));
    });
    Tensor foo = V("foo");
    Tensor bar(DT_STRING, TensorShape({}));
    bool is_dead = false;
    Rendezvous::Args args;
    for (int i = 0; i < iters; ++i) {
      TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &foo, &is_dead));
      TF_CHECK_OK(rendez->Send(KeyBar(), args, bar, is_dead));
    args.cancellation_manager = cm;
    for (int i = 0; i < messages_count; ++i) {
      TF_CHECK_OK(rendez->Send(KeyFoo(), args, foo, is_dead));
      TF_CHECK_OK(rendez->Recv(KeyBar(), args, &bar, &is_dead));
    }
    CHECK_EQ("foo", V(foo));
  });
  Tensor foo = V("foo");
  Tensor bar(DT_STRING, TensorShape({}));
  bool is_dead = false;
  Rendezvous::Args args;
  args.cancellation_manager = cm;
  for (int i = 0; i < iters; ++i) {
    TF_CHECK_OK(rendez->Send(KeyFoo(), args, foo, is_dead));
    TF_CHECK_OK(rendez->Recv(KeyBar(), args, &bar, &is_dead));
    CHECK_EQ("bar", V(bar));
  }
  CHECK_EQ("bar", V(bar));
  state.SetItemsProcessed(messages_count * state.iterations());
  delete pool;
  delete cm;
}
BENCHMARK(BM_PingPong);
BENCHMARK(BM_PingPong)->Arg(100)->Arg(200)->Arg(300);

}  // namespace
}  // namespace tensorflow

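For context on the rendezvous hunks above: the benchmarks drive a local Rendezvous through matching Send/Recv pairs on the same key, and the new BM_PingPong takes the per-iteration message count from a benchmark argument. A rough sketch of one blocking round trip, assuming hypothetical MakeParsedKey() and MakeStringTensor() helpers standing in for the tests' KeyFoo()/V("val"):

// Sketch only: one Send/Recv round trip on a local rendezvous.
void SendRecvOnce() {
  tensorflow::Rendezvous* rendez = tensorflow::NewLocalRendezvous();
  tensorflow::Rendezvous::Args args;
  tensorflow::Tensor sent = MakeStringTensor("val");  // hypothetical helper
  tensorflow::Tensor received;
  bool is_dead = false;
  TF_CHECK_OK(rendez->Send(MakeParsedKey(), args, sent, is_dead));
  TF_CHECK_OK(rendez->Recv(MakeParsedKey(), args, &received, &is_dead));
  rendez->Unref();  // Rendezvous objects are reference counted.
}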
@ -684,19 +684,24 @@ static std::vector<int64> MakeSizes(int arg) {
  return sizes;
}

static void BM_TensorShape_Init(int iters, int arg) {
void BM_TensorShape_Init(::testing::benchmark::State& state) {
  const int arg = state.range(0);

  auto sizes = MakeSizes(arg);
  while (--iters > 0) {
  for (auto s : state) {
    TensorShape shape(sizes);
    tensorflow::testing::DoNotOptimize(shape.num_elements());
  }
}
BENCHMARK(BM_TensorShape_Init)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4);

static void BM_TensorShape_Assign(int iters, int arg) {
  TensorShape s(MakeSizes(arg));
  while (--iters > 0) {
    TensorShape s2 = s;
void BM_TensorShape_Assign(::testing::benchmark::State& state) {
  const int arg = state.range(0);

  TensorShape shape(MakeSizes(arg));
  for (auto s : state) {
    const TensorShape s2 = shape;
    tensorflow::testing::DoNotOptimize(s2);
  }
}
BENCHMARK(BM_TensorShape_Assign)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4);

@ -1468,19 +1468,19 @@ TEST(SummarizeValue, STRING_PRINT_V2) {
            x.SummarizeValue(16, true));
}

void BM_CreateAndDestroy(int iters) {
void BM_CreateAndDestroy(::testing::benchmark::State& state) {
  TensorShape shape({10, 20});
  while (--iters) {
  for (auto s : state) {
    Tensor t(DT_FLOAT, shape);
  }
}
BENCHMARK(BM_CreateAndDestroy);

void BM_Assign(int iters) {
void BM_Assign(::testing::benchmark::State& state) {
  Tensor a(DT_FLOAT, TensorShape({10, 20}));
  Tensor b(DT_FLOAT, TensorShape({10, 20}));
  bool a_to_b = true;
  while (--iters) {
  for (auto s : state) {
    if (a_to_b) {
      b = a;
    } else {
@ -1498,20 +1498,20 @@ TEST(Tensor, EmptyTensorData) {
}

// Benchmark create and destroy a tensor, with an allocated buffer.
void BM_CreateAndDestroyWithBuf(int iters) {
void BM_CreateAndDestroyWithBuf(::testing::benchmark::State& state) {
  TensorShape shape({10, 20});
  Allocator* allocator = cpu_allocator();
  while (--iters) {
  for (auto s : state) {
    Tensor a(allocator, DT_FLOAT, shape);
  }
}
BENCHMARK(BM_CreateAndDestroyWithBuf);

// Benchmark create+copy a tensor, with an allocated buffer.
void BM_CreateAndCopyCtrWithBuf(int iters) {
void BM_CreateAndCopyCtrWithBuf(::testing::benchmark::State& state) {
  TensorShape shape({10, 20});
  Allocator* allocator = cpu_allocator();
  while (--iters) {
  for (auto s : state) {
    Tensor a(allocator, DT_FLOAT, shape);
    Tensor b(a);
  }
@ -1519,10 +1519,10 @@ void BM_CreateAndCopyCtrWithBuf(int iters) {
BENCHMARK(BM_CreateAndCopyCtrWithBuf);

// Benchmark create+move a tensor, with an allocated buffer.
void BM_CreateAndMoveCtrWithBuf(int iters) {
void BM_CreateAndMoveCtrWithBuf(::testing::benchmark::State& state) {
  TensorShape shape({10, 20});
  Allocator* allocator = cpu_allocator();
  while (--iters) {
  for (auto s : state) {
    Tensor a(allocator, DT_FLOAT, shape);
    Tensor b(std::move(a));
  }
@ -1531,10 +1531,11 @@ BENCHMARK(BM_CreateAndMoveCtrWithBuf);

// Benchmark creating and destroy a host-scalar tensor, using the allocator
// interface.
void BM_CreateAndDestroyHostScalarNonOptimized(int iters) {
void BM_CreateAndDestroyHostScalarNonOptimized(
    ::testing::benchmark::State& state) {
  TensorShape shape({});
  Allocator* allocator = cpu_allocator();
  while (--iters) {
  for (auto s : state) {
    Tensor a(allocator, DT_FLOAT, shape);
    a.scalar<float>()() = 37.0;
  }
@ -1543,32 +1544,33 @@ BENCHMARK(BM_CreateAndDestroyHostScalarNonOptimized);

// Benchmark creating and destroy a host-scalar tensor, using the specialized
// constructor.
void BM_CreateAndDestroyHostScalarOptimized(int iters) {
  while (--iters) {
void BM_CreateAndDestroyHostScalarOptimized(
    ::testing::benchmark::State& state) {
  for (auto s : state) {
    Tensor a(37.0);
  }
}
BENCHMARK(BM_CreateAndDestroyHostScalarOptimized);

static void BM_FromProto(int iters, int size) {
  testing::StopTiming();
void BM_FromProto(::testing::benchmark::State& state) {
  const int size = state.range(0);

  TensorShape shape({size});
  Allocator* allocator = cpu_allocator();
  Tensor a(allocator, DT_FLOAT, shape);
  std::fill_n(a.flat<float>().data(), size, 42.0);
  TensorProto p;
  a.AsProtoField(&p);
  testing::StartTiming();
  while (--iters) {
  for (auto s : state) {
    Tensor b;
    ASSERT_TRUE(b.FromProto(p));
  }
  testing::StopTiming();
}
BENCHMARK(BM_FromProto)->Range(1, 1 << 20);

static void BM_FromProtoCompressed(int iters, int size) {
  testing::StopTiming();
void BM_FromProtoCompressed(::testing::benchmark::State& state) {
  const int size = state.range(0);

  TensorShape shape({size});
  Allocator* allocator = cpu_allocator();
  Tensor a(allocator, DT_FLOAT, shape);
@ -1576,17 +1578,16 @@ static void BM_FromProtoCompressed(int iters, int size) {
  TensorProto p;
  a.AsProtoField(&p);
  tensor::CompressTensorProtoInPlace(&p);
  testing::StartTiming();
  while (--iters) {
  for (auto s : state) {
    Tensor b;
    ASSERT_TRUE(b.FromProto(p));
  }
  testing::StopTiming();
}
BENCHMARK(BM_FromProtoCompressed)->Range(1, 1 << 20);

static void BM_FromProtoCompressedZero(int iters, int size) {
  testing::StopTiming();
void BM_FromProtoCompressedZero(::testing::benchmark::State& state) {
  const int size = state.range(0);

  TensorShape shape({size});
  Allocator* allocator = cpu_allocator();
  Tensor a(allocator, DT_FLOAT, shape);
@ -1595,12 +1596,10 @@ static void BM_FromProtoCompressedZero(int iters, int size) {
  TensorProto p;
  a.AsProtoField(&p);
  tensor::CompressTensorProtoInPlace(&p);
  testing::StartTiming();
  while (--iters) {
  for (auto s : state) {
    Tensor b;
    ASSERT_TRUE(b.FromProto(p));
  }
  testing::StopTiming();
}
BENCHMARK(BM_FromProtoCompressedZero)->Range(1, 1 << 20);


@ -767,16 +767,49 @@ Status AvgPoolGradTransposer::TransposeNode(TransposeContext* context,
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status BiasAddGradTransposer::TransposeNode(TransposeContext* context,
                                            utils::MutableNodeView* node) {
  DCHECK(IsBiasAddGrad(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4)) {
Status BiasAddTransposer::TransposeNode(TransposeContext* context,
                                        utils::MutableNodeView* node) {
  DCHECK(IsBiasAdd(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  // BiasAdd itself only needs NCHW/NHWC to determine whether C dim is the
  // second or the last dim. Therefore, we use the original 4D data format in
  // the context to update the node. For the input/output tensor, the
  // corresponding 4D or 5D data format is needed.
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status BiasAddGradTransposer::TransposeNode(TransposeContext* context,
                                            utils::MutableNodeView* node) {
  DCHECK(IsBiasAddGrad(*node->node()));
  const int rank = GetFaninPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  if (!ShouldProcess(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  // BiasAddGrad itself only needs NCHW/NHWC to determine whether C dim is the
  // second or the last dim. Therefore, we use the original 4D data format in
  // the context to update the node. For the input tensor, the corresponding 4D
  // or 5D data format is needed.
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  // No need to update output shape, as it is always of shape 1-D with size the
  // feature dimension of `out_backprop`, regardless of whether NCHW or NHWC is
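The transposer hunks in this file all follow the same shape: read the fanin/fanout rank, bail out unless it is 4 or 5, and then install a ScopedDataFormatUpgrader so the 4D source/destination formats are temporarily widened to their 5D counterparts while edges are rewritten. A self-contained sketch of that RAII idea with hypothetical types; the real TransposeContext and ScopedDataFormatUpgrader live in the generic layout optimizer sources, and the NHWC->NDHWC / NCHW->NCDHW mapping below is an assumption for illustration:

// Hypothetical stand-ins illustrating the rank-gate + scoped-upgrade pattern.
#include <string>

struct Context {
  std::string src_format = "NHWC";
  std::string dst_format = "NCHW";
};

// Temporarily widens the formats to a 5D spelling; restores them on scope exit.
class ScopedFormatUpgrader {
 public:
  ScopedFormatUpgrader(Context* ctx, int rank) : ctx_(ctx) {
    if (rank == 5) {
      saved_src_ = ctx_->src_format;
      saved_dst_ = ctx_->dst_format;
      ctx_->src_format = "NDHWC";
      ctx_->dst_format = "NCDHW";
      upgraded_ = true;
    }
  }
  ~ScopedFormatUpgrader() {
    if (upgraded_) {
      ctx_->src_format = saved_src_;
      ctx_->dst_format = saved_dst_;
    }
  }

 private:
  Context* ctx_;
  std::string saved_src_, saved_dst_;
  bool upgraded_ = false;
};

bool TransposeIfSupported(Context* ctx, int rank) {
  if (rank != 4 && rank != 5) return false;  // rank gate, as in the hunks above
  ScopedFormatUpgrader upgrader(ctx, rank);
  // ... rewrite fanin/fanout edges using ctx->src_format / ctx->dst_format ...
  return true;
}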
@ -839,7 +872,12 @@ Status Conv2DBackpropInputTransposer::TransposeNode(
Status Conv3DTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsConv3D(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 5)) {
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@ -854,7 +892,12 @@ Status Conv3DTransposer::TransposeNode(TransposeContext* context,
Status Conv3DBackpropFilterTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsConv3DBackpropFilterV2(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 5)) {
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@ -872,7 +915,12 @@ Status Conv3DBackpropFilterTransposer::TransposeNode(
Status Conv3DBackpropInputTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsConv3DBackpropInputV2(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 5)) {
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@ -1081,8 +1129,9 @@ bool LayoutAgnosticOpTransposer::IsAfterDstToSrcTransform(
  return false;
}

std::vector<int> LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts(
    const TransposeContext& context, const utils::MutableNodeView& node) const {
std::vector<int> LayoutAgnosticOpTransposer::GetVariadicNDFaninPorts(
    const TransposeContext& context, const utils::MutableNodeView& node,
    int rank) const {
  std::vector<int> ports;
  const int num_regular_fanins = node.NumRegularFanins();
  ports.reserve(num_regular_fanins);
@ -1090,7 +1139,7 @@ std::vector<int> LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts(
    const auto& regular_fanin = node.GetRegularFanin(i);
    auto* regular_fanin_node = regular_fanin.node_view();
    int regular_fanin_port = regular_fanin.index();
    if (IsFanoutPortRankN(*regular_fanin_node, regular_fanin_port, 4) &&
    if ((IsFanoutPortRankN(*regular_fanin_node, regular_fanin_port, rank)) &&
        ((IsAfterDstToSrcTransform(context, *regular_fanin_node) &&
          IsLayoutAgnosticOp(*regular_fanin_node->node())) ||
         IsLayoutOptimizerAddedDstToSrcTranspose(context,
@ -1124,10 +1173,18 @@ Status DefaultLayoutAgnosticOpTransposer::TransposeNode(
Status AddNTransposer::TransposeNode(TransposeContext* context,
                                     utils::MutableNodeView* node) {
  DCHECK(IsAddN(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, GetDataFaninPorts(*node),
                                            node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
@ -1284,7 +1341,12 @@ Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
Status ConcatOpTransposer::TransposeNode(TransposeContext* context,
                                         utils::MutableNodeView* node) {
  DCHECK(IsConcat(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
@ -1297,6 +1359,9 @@ Status ConcatOpTransposer::TransposeNode(TransposeContext* context,
      axis_node = n_attr->i();
    }
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {axis_node}, node, kOpDataFormatDimMap));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
@ -1320,14 +1385,33 @@ Status FillOpTransposer::TransposeNode(TransposeContext* context,
Status IdentityNTransposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsIdentityN(*node->node()));
  const auto ports = GetVariadic4DFaninPorts(*context, *node);
  if (!ShouldProcess(*context, *node) || ports.empty()) {
  const auto ports_4d = GetVariadicNDFaninPorts(*context, *node, 4);

  // Temporarily upgrade the context to obtain the number of 5D fanin ports.
  std::vector<int> ports_5d;
  {
    ScopedDataFormatUpgrader data_format_upgrader(context, 5);
    ports_5d = GetVariadicNDFaninPorts(*context, *node, 5);
  }

  if (!ShouldProcess(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, ports, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFanoutEdgesWithOp(context, ports, node, kOpTranspose));

  if (!ports_4d.empty()) {
    TF_RETURN_IF_ERROR(
        UpdateFaninEdgesWithOp(context, ports_4d, node, kOpTranspose));
    TF_RETURN_IF_ERROR(
        UpdateFanoutEdgesWithOp(context, ports_4d, node, kOpTranspose));
  }

  if (!ports_5d.empty()) {
    ScopedDataFormatUpgrader data_format_upgrader(context, 5);
    TF_RETURN_IF_ERROR(
        UpdateFaninEdgesWithOp(context, ports_5d, node, kOpTranspose));
    TF_RETURN_IF_ERROR(
        UpdateFanoutEdgesWithOp(context, ports_5d, node, kOpTranspose));
  }
  return context->graph_view->GetMutationBuilder()->Apply();
}

@ -1519,10 +1603,18 @@ Status SelectTransposer::TransposeNode(TransposeContext* context,
Status ShapeTransposer::TransposeNode(TransposeContext* context,
                                      utils::MutableNodeView* node) {
  DCHECK(IsShape(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4) ||
  const int rank = GetFaninPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFanoutEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
@ -1532,10 +1624,20 @@ Status ShapeTransposer::TransposeNode(TransposeContext* context,
Status ShapeNTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsShapeN(*node->node()));
  const auto ports = GetVariadic4DFaninPorts(*context, *node);
  // ShapeN requires all input tensors to have the same dimensions. Therefore,
  // we simply use the 0th fanin port.
  const int rank = GetFaninPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  const auto ports = GetVariadicNDFaninPorts(*context, *node, rank);
  if (!ShouldProcess(*context, *node) || ports.empty()) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, ports, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
@ -1546,11 +1648,19 @@ Status ShapeNTransposer::TransposeNode(TransposeContext* context,
Status SliceTransposer::TransposeNode(TransposeContext* context,
                                      utils::MutableNodeView* node) {
  DCHECK(IsSlice(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninPortsDimsNIfConst(*node, {1, 2}, {4}) ||
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsFaninPortsDimsNIfConst(*node, {1, 2}, {rank}) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
 | 
			
		||||
          << "' with op '" << node->GetOp() << "' from data format '"
 | 
			
		||||
          << context->src_format << "' to '" << context->dst_format << "'";
 | 
			
		||||
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
 | 
			
		||||
  TF_RETURN_IF_ERROR(
 | 
			
		||||
      UpdateFaninEdgesWithOp(context, {1, 2}, node, kOpDataFormatVecPermute));
 | 
			
		||||
@ -1839,18 +1949,17 @@ string GetDeviceName(const VirtualPlacer* virtual_placer, const NodeDef& node) {
 | 
			
		||||
bool IsDefaultLayoutSensitiveOp(const NodeDef& node) {
 | 
			
		||||
  static absl::flat_hash_set<string>* default_layout_sensitive_ops =
 | 
			
		||||
      new absl::flat_hash_set<std::string>(
 | 
			
		||||
          {"AvgPool", "BiasAdd", "Conv2D", "DepthwiseConv2dNative",
 | 
			
		||||
           "DepthToSpace", "FusedBatchNorm", "FusedBatchNormV2",
 | 
			
		||||
           "FusedBatchNormV3", "FusedConv2DBiasActivation", "MaxPool",
 | 
			
		||||
           "SpaceToDepth"});
 | 
			
		||||
          {"AvgPool", "Conv2D", "DepthwiseConv2dNative", "DepthToSpace",
 | 
			
		||||
           "FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3",
 | 
			
		||||
           "FusedConv2DBiasActivation", "MaxPool", "SpaceToDepth"});
 | 
			
		||||
  return default_layout_sensitive_ops->find(node.op()) !=
 | 
			
		||||
         default_layout_sensitive_ops->end();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool IsLayoutSensitiveOp(const NodeDef& node) {
 | 
			
		||||
  return IsDefaultLayoutSensitiveOp(node) || IsAvgPoolGrad(node) ||
 | 
			
		||||
         IsBiasAddGrad(node) || IsConv2DBackpropFilter(node) ||
 | 
			
		||||
         IsConv2DBackpropInput(node) ||
 | 
			
		||||
         IsBiasAdd(node) || IsBiasAddGrad(node) ||
 | 
			
		||||
         IsConv2DBackpropFilter(node) || IsConv2DBackpropInput(node) ||
 | 
			
		||||
         IsDepthwiseConv2dNativeBackpropFilter(node) ||
 | 
			
		||||
         IsDepthwiseConv2dNativeBackpropInput(node) ||
 | 
			
		||||
         IsFusedBatchNormEx(node) || IsFusedBatchNormGrad(node) ||
 | 
			
		||||
 | 
			
		||||
@ -210,6 +210,14 @@ class DefaultLayoutSensitiveOpTransposer : public LayoutSensitiveOpTransposer {
 | 
			
		||||
                       utils::MutableNodeView* node) override;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class BiasAddTransposer : public LayoutSensitiveOpTransposer {
 | 
			
		||||
 public:
 | 
			
		||||
  explicit BiasAddTransposer() : LayoutSensitiveOpTransposer() {}
 | 
			
		||||
 | 
			
		||||
  Status TransposeNode(TransposeContext* context,
 | 
			
		||||
                       utils::MutableNodeView* node) override;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class AvgPoolGradTransposer : public LayoutSensitiveOpTransposer {
 | 
			
		||||
 public:
 | 
			
		||||
  explicit AvgPoolGradTransposer() : LayoutSensitiveOpTransposer() {}
 | 
			
		||||
@ -319,9 +327,9 @@ class LayoutAgnosticOpTransposer : public Transposer {
 | 
			
		||||
  bool IsAfterDstToSrcTransform(const TransposeContext& context,
 | 
			
		||||
                                const utils::MutableNodeView& node) const;
 | 
			
		||||
 | 
			
		||||
  std::vector<int> GetVariadic4DFaninPorts(
 | 
			
		||||
      const TransposeContext& context,
 | 
			
		||||
      const utils::MutableNodeView& node) const;
 | 
			
		||||
  std::vector<int> GetVariadicNDFaninPorts(const TransposeContext& context,
 | 
			
		||||
                                           const utils::MutableNodeView& node,
 | 
			
		||||
                                           int rank) const;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class DefaultLayoutAgnosticOpTransposer : public LayoutAgnosticOpTransposer {
 | 
			
		||||
 | 
			
		||||
@ -27,6 +27,9 @@ std::shared_ptr<Transposer> TransposerFactory::GetTransposer(
 | 
			
		||||
    return GetOrCreateIfNotFound<DefaultLayoutSensitiveOpTransposer>(
 | 
			
		||||
        "DefaultLayoutSensitiveOp");
 | 
			
		||||
  }
 | 
			
		||||
  if (IsBiasAdd(node)) {
 | 
			
		||||
    return GetOrCreateIfNotFound<BiasAddTransposer>("BiasAdd");
 | 
			
		||||
  }
 | 
			
		||||
  if (IsAvgPoolGrad(node)) {
 | 
			
		||||
    return GetOrCreateIfNotFound<AvgPoolGradTransposer>("AvgPoolGrad");
 | 
			
		||||
  }
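The factory dispatch above hands back one shared transposer instance per op kind. Below is a minimal sketch of the create-if-absent caching that a helper such as GetOrCreateIfNotFound typically performs; the container type and method shape are assumptions for illustration only.

    #include <map>
    #include <memory>
    #include <string>

    class Transposer {};  // stand-in for the real Transposer base class

    class TransposerFactorySketch {
     public:
      // Returns the cached instance for `key`, creating it on first use.
      template <typename T>
      std::shared_ptr<Transposer> GetOrCreateIfNotFound(const std::string& key) {
        auto& entry = cache_[key];
        if (entry == nullptr) {
          entry = std::make_shared<T>();
        }
        return entry;
      }

     private:
      std::map<std::string, std::shared_ptr<Transposer>> cache_;
    };
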
@ -48,7 +48,6 @@ load(
 | 
			
		||||
)
 | 
			
		||||
load(
 | 
			
		||||
    "//third_party/mkl:build_defs.bzl",
 | 
			
		||||
    "if_mkl_ml",
 | 
			
		||||
    "mkl_deps",
 | 
			
		||||
)
 | 
			
		||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 | 
			
		||||
@ -3241,7 +3240,6 @@ cc_library(
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":aggregate_ops",
 | 
			
		||||
        ":argmax_op",
 | 
			
		||||
        ":batch_matmul_op",
 | 
			
		||||
        ":betainc_op",
 | 
			
		||||
        ":bincount_op",
 | 
			
		||||
        ":bucketize_op",
 | 
			
		||||
@ -3337,14 +3335,27 @@ tf_kernel_library(
 | 
			
		||||
 | 
			
		||||
tf_kernel_library(
 | 
			
		||||
    name = "batch_matmul_op",
 | 
			
		||||
    deps = [":matmul_op"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
tf_kernel_library(
 | 
			
		||||
    name = "matmul_op",
 | 
			
		||||
    # <prefix>*impl.h are excluded by default from the CPU build, add explicitly.
 | 
			
		||||
    hdrs = ["batch_matmul_op_impl.h"],
 | 
			
		||||
    prefix = "batch_matmul_op",
 | 
			
		||||
    deps = MATH_DEPS + [":eigen_contraction_kernel"] + if_mkl_ml([
 | 
			
		||||
        "//third_party/mkl:intel_binary_blob",
 | 
			
		||||
    ]) + if_cuda_or_rocm([
 | 
			
		||||
        "//tensorflow/core/kernels:gpu_utils",
 | 
			
		||||
    ]),
 | 
			
		||||
    hdrs = ["matmul_op_impl.h"],
 | 
			
		||||
    defines = select({
 | 
			
		||||
        ":xsmm": ["TENSORFLOW_USE_LIBXSMM"],
 | 
			
		||||
        "//conditions:default": [],
 | 
			
		||||
    }),
 | 
			
		||||
    prefix = "matmul_op",
 | 
			
		||||
    deps = MATH_DEPS + [
 | 
			
		||||
        ":eigen_contraction_kernel",
 | 
			
		||||
        ":fused_eigen_output_kernels",
 | 
			
		||||
    ] + select({
 | 
			
		||||
        ":xsmm": ["@libxsmm_archive//:xsmm_avx"],
 | 
			
		||||
        "//conditions:default": [],
 | 
			
		||||
    }) + mkl_deps() + if_cuda([
 | 
			
		||||
        "//tensorflow/core/platform/default/build_config:cublas_plugin",
 | 
			
		||||
    ]) + if_cuda_or_rocm([":gpu_utils"]),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
tf_kernel_library(
 | 
			
		||||
@ -3406,28 +3417,6 @@ tf_kernel_library(
 | 
			
		||||
    ]),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
tf_kernel_library(
 | 
			
		||||
    name = "matmul_op",
 | 
			
		||||
    srcs = [
 | 
			
		||||
        "matmul_op.cc",
 | 
			
		||||
        "matmul_op_fused.cc",
 | 
			
		||||
    ],
 | 
			
		||||
    hdrs = ["matmul_op.h"],
 | 
			
		||||
    defines = select({
 | 
			
		||||
        ":xsmm": ["TENSORFLOW_USE_LIBXSMM"],
 | 
			
		||||
        "//conditions:default": [],
 | 
			
		||||
    }),
 | 
			
		||||
    deps = MATH_DEPS + [
 | 
			
		||||
        ":eigen_contraction_kernel",
 | 
			
		||||
        ":fused_eigen_output_kernels",
 | 
			
		||||
    ] + select({
 | 
			
		||||
        ":xsmm": ["@libxsmm_archive//:xsmm_avx"],
 | 
			
		||||
        "//conditions:default": [],
 | 
			
		||||
    }) + mkl_deps() + if_cuda([
 | 
			
		||||
        "//tensorflow/core/platform/default/build_config:cublas_plugin",
 | 
			
		||||
    ]) + if_cuda_or_rocm([":gpu_utils"]),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
tf_kernel_library(
 | 
			
		||||
    name = "reduction_ops",
 | 
			
		||||
    gpu_srcs = ["reduction_gpu_kernels.cu.h"],
 | 
			
		||||
@ -3620,25 +3609,6 @@ tf_cuda_cc_test(
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
tf_cuda_cc_test(
 | 
			
		||||
    name = "batch_matmul_op_test",
 | 
			
		||||
    size = "small",
 | 
			
		||||
    srcs = ["batch_matmul_op_test.cc"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":batch_matmul_op",
 | 
			
		||||
        ":broadcast_to_op",
 | 
			
		||||
        ":ops_testutil",
 | 
			
		||||
        ":ops_util",
 | 
			
		||||
        "//tensorflow/core:core_cpu",
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
        "//tensorflow/core:protos_all_cc",
 | 
			
		||||
        "//tensorflow/core:test",
 | 
			
		||||
        "//tensorflow/core:test_main",
 | 
			
		||||
        "//tensorflow/core:testlib",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
tf_cuda_cc_test(
 | 
			
		||||
    name = "scan_ops_test",
 | 
			
		||||
    size = "small",
 | 
			
		||||
@ -5868,8 +5838,8 @@ filegroup(
 | 
			
		||||
        "identity_op.h",
 | 
			
		||||
        "immutable_constant_op.cc",
 | 
			
		||||
        "immutable_constant_op.h",
 | 
			
		||||
        "matmul_op.cc",
 | 
			
		||||
        "matmul_op.h",
 | 
			
		||||
        "matmul_op_impl.h",
 | 
			
		||||
        "matmul_op_real.cc",
 | 
			
		||||
        "no_op.cc",
 | 
			
		||||
        "no_op.h",
 | 
			
		||||
        "one_hot_op.cc",
 | 
			
		||||
@ -5948,7 +5918,6 @@ filegroup(
 | 
			
		||||
    srcs = [
 | 
			
		||||
        "argmax_op.h",
 | 
			
		||||
        "avgpooling_op.h",
 | 
			
		||||
        "batch_matmul_op_impl.h",
 | 
			
		||||
        "batch_norm_op.h",
 | 
			
		||||
        "bincount_op.h",
 | 
			
		||||
        "broadcast_to_op.h",
 | 
			
		||||
@ -6039,7 +6008,6 @@ filegroup(
 | 
			
		||||
        ":android_extended_ops_headers",
 | 
			
		||||
        "argmax_op.cc",
 | 
			
		||||
        "avgpooling_op.cc",
 | 
			
		||||
        "batch_matmul_op_real.cc",
 | 
			
		||||
        "batch_norm_op.cc",
 | 
			
		||||
        "bcast_ops.cc",
 | 
			
		||||
        "check_numerics_op.cc",
 | 
			
		||||
@ -6480,6 +6448,7 @@ cc_library(
 | 
			
		||||
        "//tensorflow/core/platform:strong_hash",
 | 
			
		||||
        "//third_party/eigen3",
 | 
			
		||||
        "//third_party/fft2d:fft2d_headers",
 | 
			
		||||
        "//third_party/icu/data:conversion_data",
 | 
			
		||||
        "@com_google_absl//absl/base",
 | 
			
		||||
        "@com_google_protobuf//:protobuf",
 | 
			
		||||
        "@fft2d",
 | 
			
		||||
@ -7431,7 +7400,6 @@ test_suite(
 | 
			
		||||
        "manual",  # Avoid redundancy when using wildcard test patterns.
 | 
			
		||||
    ],
 | 
			
		||||
    tests = [
 | 
			
		||||
        ":batch_matmul_op_test",
 | 
			
		||||
        ":batch_norm_op_test",
 | 
			
		||||
        ":broadcast_to_op_test",
 | 
			
		||||
        ":cast_op_test",
 | 
			
		||||
 | 
			
		||||
@ -1,257 +0,0 @@
 | 
			
		||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
 | 
			
		||||
#include "tensorflow/core/framework/op.h"
 | 
			
		||||
#include "tensorflow/core/framework/tensor.h"
 | 
			
		||||
#include "tensorflow/core/framework/tensor_shape.h"
 | 
			
		||||
#include "tensorflow/core/framework/types.pb.h"
 | 
			
		||||
#include "tensorflow/core/graph/graph.h"
 | 
			
		||||
#include "tensorflow/core/graph/node_builder.h"
 | 
			
		||||
#include "tensorflow/core/graph/testlib.h"
 | 
			
		||||
#include "tensorflow/core/kernels/broadcast_to_op.h"
 | 
			
		||||
#include "tensorflow/core/lib/core/status.h"
 | 
			
		||||
#include "tensorflow/core/platform/test.h"
 | 
			
		||||
#include "tensorflow/core/platform/test_benchmark.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
Node* BroadcastTo(Graph* g, Node* input, Node* shape) {
 | 
			
		||||
  Node* ret;
 | 
			
		||||
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastTo")
 | 
			
		||||
                  .Input(input)
 | 
			
		||||
                  .Input(shape)
 | 
			
		||||
                  .Finalize(g, &ret));
 | 
			
		||||
  return ret;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Node* BatchMatmulV2(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) {
 | 
			
		||||
  Node* ret;
 | 
			
		||||
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMulV2")
 | 
			
		||||
                  .Input(in0)
 | 
			
		||||
                  .Input(in1)
 | 
			
		||||
                  .Attr("adj_x", adj_x)
 | 
			
		||||
                  .Attr("adj_y", adj_y)
 | 
			
		||||
                  .Finalize(g, &ret));
 | 
			
		||||
  return ret;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a,
 | 
			
		||||
                          bool adjoint_b, DataType type) {
 | 
			
		||||
  Graph* g = new Graph(OpRegistry::Global());
 | 
			
		||||
  Tensor in0(type, adjoint_a ? TensorShape({b, k, m}) : TensorShape({b, m, k}));
 | 
			
		||||
  in0.flat<T>().setRandom();
 | 
			
		||||
  Tensor in1(type, adjoint_b ? TensorShape({b, n, k}) : TensorShape({b, k, n}));
 | 
			
		||||
  in1.flat<T>().setRandom();
 | 
			
		||||
  test::graph::BatchMatmul(g, test::graph::Constant(g, in0),
 | 
			
		||||
                           test::graph::Constant(g, in1), adjoint_a, adjoint_b);
 | 
			
		||||
  return g;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
static Graph* BatchMatmulWithBroadcast(int b0, int b1, int m, int k, int n,
 | 
			
		||||
                                       bool manual_broadcast, DataType type) {
 | 
			
		||||
  Graph* g = new Graph(OpRegistry::Global());
 | 
			
		||||
  Tensor in0(type, TensorShape({b0, m, k}));
 | 
			
		||||
  in0.flat<T>().setRandom();
 | 
			
		||||
  Tensor in1(type, TensorShape({b1, k, n}));
 | 
			
		||||
  in1.flat<T>().setRandom();
 | 
			
		||||
 | 
			
		||||
  Tensor broadcasted_in0_shape(DT_INT64, TensorShape({3}));
 | 
			
		||||
  Tensor broadcasted_in1_shape(DT_INT64, TensorShape({3}));
 | 
			
		||||
 | 
			
		||||
  Node* in0_node = nullptr;
 | 
			
		||||
  Node* in1_node = nullptr;
 | 
			
		||||
  if (manual_broadcast) {
 | 
			
		||||
    for (int i = 0; i < 3; ++i) {
 | 
			
		||||
      auto vec0 = broadcasted_in0_shape.vec<int64>();
 | 
			
		||||
      auto vec1 = broadcasted_in1_shape.vec<int64>();
 | 
			
		||||
      vec0(i) = (i == 0 ? std::max(b0, b1) : in0.shape().dim_size(i));
 | 
			
		||||
      vec1(i) = (i == 0 ? std::max(b0, b1) : in1.shape().dim_size(i));
 | 
			
		||||
    }
 | 
			
		||||
    in0_node = BroadcastTo(g, test::graph::Constant(g, in0),
 | 
			
		||||
                           test::graph::Constant(g, broadcasted_in0_shape));
 | 
			
		||||
    in1_node = BroadcastTo(g, test::graph::Constant(g, in1),
 | 
			
		||||
                           test::graph::Constant(g, broadcasted_in1_shape));
 | 
			
		||||
  } else {
 | 
			
		||||
    in0_node = test::graph::Constant(g, in0);
 | 
			
		||||
    in1_node = test::graph::Constant(g, in1);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  BatchMatmulV2(g, in0_node, in1_node, false, false);
 | 
			
		||||
  return g;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define BM_BatchMatmulDev(B, M, K, N, TA, TB, T, TFTYPE, DEVICE)                  \
 | 
			
		||||
  static void                                                                     \
 | 
			
		||||
      BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE( \
 | 
			
		||||
          int iters) {                                                            \
 | 
			
		||||
    testing::UseRealTime();                                                       \
 | 
			
		||||
    testing::ItemsProcessed(static_cast<int64>(iters) * B * M * K * N * 2);       \
 | 
			
		||||
    test::Benchmark(#DEVICE, BatchMatmul<T>(B, M, K, N, TA, TB, TFTYPE))          \
 | 
			
		||||
        .Run(iters);                                                              \
 | 
			
		||||
  }                                                                               \
 | 
			
		||||
  BENCHMARK(                                                                      \
 | 
			
		||||
      BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE);
 | 
			
		||||
 | 
			
		||||
#define BM_BatchMatmul(B, M, K, N, TA, TB) \
 | 
			
		||||
  BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, cpu);
 | 
			
		||||
// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
 | 
			
		||||
// cpu);
 | 
			
		||||
//  BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, gpu);
 | 
			
		||||
/* Uncomment to enable benchmarks for double & complex types: */
 | 
			
		||||
// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
 | 
			
		||||
// gpu);
 | 
			
		||||
// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, cpu); \
 | 
			
		||||
// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, cpu);
 | 
			
		||||
// \
 | 
			
		||||
// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu); \
 | 
			
		||||
// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, gpu);
 | 
			
		||||
 | 
			
		||||
// Macro arguments names: --------------------------------------------------- //
 | 
			
		||||
//   B1: batch size of LHS
 | 
			
		||||
//   B2: batch size of RHS
 | 
			
		||||
//    M: outer dimension of LHS
 | 
			
		||||
//    K: inner dimensions of LHS and RHS
 | 
			
		||||
//    N: outer dimension of RHS
 | 
			
		||||
//   MB: boolean indicating whether to use manual broadcasting
 | 
			
		||||
//    T: C++ type of scalars (e.g. float, std::complex)
 | 
			
		||||
//   TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128
 | 
			
		||||
//    D: Device (e.g. cpu, gpu)
 | 
			
		||||
#define BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, T, TT, D)                  \
 | 
			
		||||
  static void                                                                  \
 | 
			
		||||
      BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D( \
 | 
			
		||||
          int iters) {                                                         \
 | 
			
		||||
    testing::UseRealTime();                                                    \
 | 
			
		||||
    testing::ItemsProcessed(static_cast<int64>(iters) * std::max(B1, B2) * M * \
 | 
			
		||||
                            K * N * 2);                                        \
 | 
			
		||||
    test::Benchmark(#D, BatchMatmulWithBroadcast<T>(B1, B2, M, K, N, MB, TT))  \
 | 
			
		||||
        .Run(iters);                                                           \
 | 
			
		||||
  }                                                                            \
 | 
			
		||||
  BENCHMARK(                                                                   \
 | 
			
		||||
      BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D);
 | 
			
		||||
 | 
			
		||||
#define BM_BatchMatmulBCast(B1, B2, M, K, N, MB) \
 | 
			
		||||
  BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, float, DT_FLOAT, cpu);
 | 
			
		||||
 | 
			
		||||
// Typical fully connected layers
 | 
			
		||||
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, true);
 | 
			
		||||
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, false);
 | 
			
		||||
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, true);
 | 
			
		||||
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, false);
 | 
			
		||||
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, true);
 | 
			
		||||
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, false);
 | 
			
		||||
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, true);
 | 
			
		||||
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, false);
 | 
			
		||||
 | 
			
		||||
// Square matmul.
 | 
			
		||||
BM_BatchMatmulBCast(1, 128, 512, 512, 512, true);
 | 
			
		||||
BM_BatchMatmulBCast(1, 128, 512, 512, 512, false);
 | 
			
		||||
BM_BatchMatmulBCast(128, 1, 512, 512, 512, true);
 | 
			
		||||
BM_BatchMatmulBCast(128, 1, 512, 512, 512, false);
 | 
			
		||||
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, true);
 | 
			
		||||
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, false);
 | 
			
		||||
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, true);
 | 
			
		||||
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, false);
 | 
			
		||||
 | 
			
		||||
// Matrix-vector multiplies.
 | 
			
		||||
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, true);
 | 
			
		||||
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, false);
 | 
			
		||||
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, true);
 | 
			
		||||
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, false);
 | 
			
		||||
 | 
			
		||||
// Vector-matrix multiplies.
 | 
			
		||||
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, true);
 | 
			
		||||
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, false);
 | 
			
		||||
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, true);
 | 
			
		||||
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, false);
 | 
			
		||||
 | 
			
		||||
// Typical fully connected layers
 | 
			
		||||
BM_BatchMatmul(1, 1, 1024, 1024, false, false);
 | 
			
		||||
BM_BatchMatmul(1, 8, 1024, 1024, false, false);
 | 
			
		||||
BM_BatchMatmul(1, 16, 1024, 1024, false, false);
 | 
			
		||||
BM_BatchMatmul(1, 128, 1024, 1024, false, false);
 | 
			
		||||
BM_BatchMatmul(2, 1, 1024, 1024, false, false);
 | 
			
		||||
BM_BatchMatmul(2, 8, 1024, 1024, false, false);
 | 
			
		||||
BM_BatchMatmul(2, 16, 1024, 1024, false, false);
 | 
			
		||||
BM_BatchMatmul(2, 128, 1024, 1024, false, false);
 | 
			
		||||
BM_BatchMatmul(8, 1, 1024, 1024, false, false);
 | 
			
		||||
BM_BatchMatmul(8, 8, 1024, 1024, false, false);
 | 
			
		||||
BM_BatchMatmul(8, 16, 1024, 1024, false, false);
 | 
			
		||||
BM_BatchMatmul(8, 128, 1024, 1024, false, false);
 | 
			
		||||
BM_BatchMatmul(32, 1, 1024, 1024, false, false);
 | 
			
		||||
BM_BatchMatmul(32, 8, 1024, 1024, false, false);
 | 
			
		||||
BM_BatchMatmul(32, 16, 1024, 1024, false, false);
 | 
			
		||||
BM_BatchMatmul(32, 128, 1024, 1024, false, false);
 | 
			
		||||
 | 
			
		||||
// Square matmul.
 | 
			
		||||
BM_BatchMatmul(1, 32, 32, 32, false, false);
 | 
			
		||||
BM_BatchMatmul(1, 128, 128, 128, false, false);
 | 
			
		||||
BM_BatchMatmul(1, 256, 256, 256, false, false);
 | 
			
		||||
BM_BatchMatmul(1, 1024, 1024, 1024, false, false);
 | 
			
		||||
BM_BatchMatmul(1, 2048, 2048, 2048, false, false);
 | 
			
		||||
BM_BatchMatmul(2, 32, 32, 32, false, false);
 | 
			
		||||
BM_BatchMatmul(2, 128, 128, 128, false, false);
 | 
			
		||||
BM_BatchMatmul(2, 256, 256, 256, false, false);
 | 
			
		||||
BM_BatchMatmul(2, 1024, 1024, 1024, false, false);
 | 
			
		||||
BM_BatchMatmul(2, 2048, 2048, 2048, false, false);
 | 
			
		||||
BM_BatchMatmul(4, 32, 32, 32, false, false);
 | 
			
		||||
BM_BatchMatmul(4, 128, 128, 128, false, false);
 | 
			
		||||
BM_BatchMatmul(4, 256, 256, 256, false, false);
 | 
			
		||||
BM_BatchMatmul(4, 1024, 1024, 1024, false, false);
 | 
			
		||||
BM_BatchMatmul(4, 2048, 2048, 2048, false, false);
 | 
			
		||||
BM_BatchMatmul(8, 32, 32, 32, false, false);
 | 
			
		||||
BM_BatchMatmul(8, 128, 128, 128, false, false);
 | 
			
		||||
BM_BatchMatmul(8, 256, 256, 256, false, false);
 | 
			
		||||
BM_BatchMatmul(8, 1024, 1024, 1024, false, false);
 | 
			
		||||
BM_BatchMatmul(8, 2048, 2048, 2048, false, false);
 | 
			
		||||
BM_BatchMatmul(32, 32, 32, 32, false, false);
 | 
			
		||||
BM_BatchMatmul(32, 128, 128, 128, false, false);
 | 
			
		||||
BM_BatchMatmul(32, 256, 256, 256, false, false);
 | 
			
		||||
BM_BatchMatmul(32, 1024, 1024, 1024, false, false);
 | 
			
		||||
BM_BatchMatmul(32, 2048, 2048, 2048, false, false);
 | 
			
		||||
 | 
			
		||||
// Matrix-vector multiplies.
 | 
			
		||||
BM_BatchMatmul(1, 10000, 200, 1, false, false);
 | 
			
		||||
BM_BatchMatmul(8, 10000, 200, 1, false, false);
 | 
			
		||||
BM_BatchMatmul(32, 10000, 200, 1, false, false);
 | 
			
		||||
BM_BatchMatmul(1, 10000, 200, 1, true, false);
 | 
			
		||||
BM_BatchMatmul(8, 10000, 200, 1, true, false);
 | 
			
		||||
BM_BatchMatmul(32, 10000, 200, 1, true, false);
 | 
			
		||||
BM_BatchMatmul(1, 10000, 200, 1, false, true);
 | 
			
		||||
BM_BatchMatmul(8, 10000, 200, 1, false, true);
 | 
			
		||||
BM_BatchMatmul(32, 10000, 200, 1, false, true);
 | 
			
		||||
BM_BatchMatmul(1, 10000, 200, 1, true, true);
 | 
			
		||||
BM_BatchMatmul(8, 10000, 200, 1, true, true);
 | 
			
		||||
BM_BatchMatmul(32, 10000, 200, 1, true, true);
 | 
			
		||||
 | 
			
		||||
// Vector-matrix multiplies.
 | 
			
		||||
BM_BatchMatmul(1, 1, 200, 10000, false, false);
 | 
			
		||||
BM_BatchMatmul(8, 1, 200, 10000, false, false);
 | 
			
		||||
BM_BatchMatmul(32, 1, 200, 10000, false, false);
 | 
			
		||||
BM_BatchMatmul(1, 1, 200, 10000, true, false);
 | 
			
		||||
BM_BatchMatmul(8, 1, 200, 10000, true, false);
 | 
			
		||||
BM_BatchMatmul(32, 1, 200, 10000, true, false);
 | 
			
		||||
BM_BatchMatmul(1, 1, 200, 10000, false, true);
 | 
			
		||||
BM_BatchMatmul(8, 1, 200, 10000, false, true);
 | 
			
		||||
BM_BatchMatmul(32, 1, 200, 10000, false, true);
 | 
			
		||||
BM_BatchMatmul(1, 1, 200, 10000, true, true);
 | 
			
		||||
BM_BatchMatmul(8, 1, 200, 10000, true, true);
 | 
			
		||||
BM_BatchMatmul(32, 1, 200, 10000, true, true);
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace tensorflow
 | 
			
		||||
@ -38,6 +38,8 @@ namespace data {
 | 
			
		||||
/* static */ constexpr const char* const CacheDatasetOp::kOutputTypes;
 | 
			
		||||
/* static */ constexpr const char* const CacheDatasetOp::kOutputShapes;
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
constexpr char kKeyStrFormat[] = "%%%zuzu_%%%zuzu";
 | 
			
		||||
constexpr char kPaddingSizeStrFormat[] = "%zu";
 | 
			
		||||
constexpr char kFileDatasetPrefix[] = "File";
 | 
			
		||||
@ -57,6 +59,14 @@ constexpr char kCacheCompleted[] = "cache_completed";
 | 
			
		||||
constexpr char kIndex[] = "index";
 | 
			
		||||
constexpr char kImpl[] = "Impl";
 | 
			
		||||
constexpr char kCacheDataset[] = "CacheDataset";
 | 
			
		||||
constexpr char kIncompleteCacheErrorMessage[] =
 | 
			
		||||
    "The calling iterator did not fully read the dataset being cached. In "
 | 
			
		||||
    "order to avoid unexpected truncation of the dataset, the partially cached "
 | 
			
		||||
    "contents of the dataset  will be discarded. This can happen if you have "
 | 
			
		||||
    "an input pipeline similar to `dataset.cache().take(k).repeat()`. You "
 | 
			
		||||
    "should use `dataset.take(k).cache().repeat()` instead.";
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
class CacheDatasetOp::FileDatasetBase : public DatasetBase {
 | 
			
		||||
 public:
 | 
			
		||||
@ -220,6 +230,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase {
 | 
			
		||||
 | 
			
		||||
      ~FileWriterIterator() override {
 | 
			
		||||
        if (!dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
 | 
			
		||||
          LOG(WARNING) << kIncompleteCacheErrorMessage;
 | 
			
		||||
          std::vector<string> cache_files;
 | 
			
		||||
          Status s = dataset()->env_->GetMatchingPaths(
 | 
			
		||||
              strings::StrCat(filename_, "*"), &cache_files);
 | 
			
		||||
@ -754,13 +765,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
 | 
			
		||||
      ~MemoryWriterIterator() override {
 | 
			
		||||
        mutex_lock l(mu_);
 | 
			
		||||
        if (!temp_cache_.empty() && !cache_->IsCompleted()) {
 | 
			
		||||
          LOG(WARNING)
 | 
			
		||||
              << "The calling iterator did not fully read the dataset being "
 | 
			
		||||
                 "cached. In order to avoid unexpected truncation of the "
 | 
			
		||||
                 "dataset, the partially cached contents of the dataset "
 | 
			
		||||
                 "will be discarded. This can happen if you have an input "
 | 
			
		||||
                 "pipeline similar to `dataset.cache().take(k).repeat()`. "
 | 
			
		||||
                 "You should use `dataset.take(k).cache().repeat()` instead.";
 | 
			
		||||
          LOG(WARNING) << kIncompleteCacheErrorMessage;
 | 
			
		||||
          cache_->Reset();
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
@ -482,7 +482,10 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
          VLOG(1) << "Failed to get element from worker "
 | 
			
		||||
                  << task_to_process->address << ": " << s;
 | 
			
		||||
          task_to_process->in_use = false;
 | 
			
		||||
          status_ = s;
 | 
			
		||||
          status_ = Status(
 | 
			
		||||
              s.code(),
 | 
			
		||||
              absl::StrCat("Failed to get element from worker ",
 | 
			
		||||
                           task_to_process->address, ": ", s.error_message()));
 | 
			
		||||
          get_next_cv_.notify_all();
 | 
			
		||||
          return;
 | 
			
		||||
        }
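The change keeps the original error code but prepends the worker address, so a failed fetch is attributable in logs. A hedged sketch of the same status-wrapping pattern, factored into a free function (the helper name is illustrative, not a TensorFlow API):

    #include <string>

    #include "absl/strings/str_cat.h"
    #include "tensorflow/core/lib/core/status.h"

    namespace {

    // Rebuild a Status with call-site context while preserving its error code.
    tensorflow::Status WithWorkerContext(const tensorflow::Status& s,
                                         const std::string& worker_address) {
      return tensorflow::Status(
          s.code(), absl::StrCat("Failed to get element from worker ",
                                 worker_address, ": ", s.error_message()));
    }

    }  // namespace
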
@ -188,15 +188,19 @@ static Graph* DynamicPartition(int num_partitions, int dim) {
 | 
			
		||||
  return g;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define BM_DYNAMIC_PARTITION(DEVICE, T, num)                            \
 | 
			
		||||
  static void BM_##DEVICE##_dynpart_##T##_##num(int iters, int dim) {   \
 | 
			
		||||
    const int64 items = ((128 << 20) / sizeof(T));                      \
 | 
			
		||||
    const int64 tot = static_cast<int64>(iters) * items;                \
 | 
			
		||||
    testing::ItemsProcessed(tot);                                       \
 | 
			
		||||
    testing::UseRealTime();                                             \
 | 
			
		||||
    test::Benchmark(#DEVICE, DynamicPartition<T>(num, dim)).Run(iters); \
 | 
			
		||||
  }                                                                     \
 | 
			
		||||
  BENCHMARK(BM_##DEVICE##_dynpart_##T##_##num)->Arg(1)->Arg(256)
 | 
			
		||||
#define BM_DYNAMIC_PARTITION(DEVICE, T, num)                          \
 | 
			
		||||
  static void BM_##DEVICE##_dynpart_##T##_##num(                      \
 | 
			
		||||
      ::testing::benchmark::State& state) {                           \
 | 
			
		||||
    const int dim = state.range(0);                                   \
 | 
			
		||||
                                                                      \
 | 
			
		||||
    const int64 items = ((128 << 20) / sizeof(T));                    \
 | 
			
		||||
    test::Benchmark(#DEVICE, DynamicPartition<T>(num, dim),           \
 | 
			
		||||
                    /*old_benchmark_api=*/false)                      \
 | 
			
		||||
        .Run(state);                                                  \
 | 
			
		||||
    const int64 tot = static_cast<int64>(state.iterations()) * items; \
 | 
			
		||||
    state.SetItemsProcessed(tot);                                     \
 | 
			
		||||
  }                                                                   \
 | 
			
		||||
  BENCHMARK(BM_##DEVICE##_dynpart_##T##_##num)->UseRealTime()->Arg(1)->Arg(256)
 | 
			
		||||
 | 
			
		||||
BM_DYNAMIC_PARTITION(cpu, float, 2);
 | 
			
		||||
BM_DYNAMIC_PARTITION(cpu, float, 100);
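This hunk, like the remaining benchmark hunks in the commit, applies the same mechanical migration: the benchmark takes ::testing::benchmark::State& instead of (int iters, ...), reads its argument with state.range(0), runs via .Run(state), and reports throughput from state.iterations(), while UseRealTime() moves onto the BENCHMARK registration. A standalone sketch of the resulting shape using plain Google Benchmark rather than the TF test harness (the workload is made up for illustration):

    #include <cstdint>
    #include <vector>

    #include "benchmark/benchmark.h"

    // Hypothetical workload: sum a vector whose length comes from ->Arg(...).
    static void BM_VectorSum(::benchmark::State& state) {
      const int dim = state.range(0);            // was the old `int dim` parameter
      std::vector<float> v(dim, 1.0f);
      for (auto _ : state) {                     // replaces the manual `iters` loop
        float sum = 0.0f;
        for (float x : v) sum += x;
        ::benchmark::DoNotOptimize(sum);
      }
      state.SetItemsProcessed(static_cast<int64_t>(state.iterations()) * dim);
    }
    BENCHMARK(BM_VectorSum)->UseRealTime()->Arg(1)->Arg(256);

    BENCHMARK_MAIN();
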
@ -1376,7 +1376,7 @@ TEST(EigenSpatialConvolutionsTest, SpatialConvContractionMapper) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
static void PackRhsHelper(int iters,
 | 
			
		||||
static void PackRhsHelper(::testing::benchmark::State& state,
 | 
			
		||||
                          /* Input dimensions: */
 | 
			
		||||
                          int input_batches, int input_cols, int input_rows,
 | 
			
		||||
                          int input_depth,
 | 
			
		||||
@ -1393,9 +1393,6 @@ static void PackRhsHelper(int iters,
 | 
			
		||||
  // Set random seed for benchmark repeatability.
 | 
			
		||||
  srand(12345);
 | 
			
		||||
 | 
			
		||||
  tensorflow::testing::UseRealTime();
 | 
			
		||||
  tensorflow::testing::StopTiming();
 | 
			
		||||
 | 
			
		||||
  using Dimensions = Eigen::DSizes<Eigen::Index, 4>;
 | 
			
		||||
 | 
			
		||||
  // Default Eigen::Tensor layout is column major, so we configure dimensions
 | 
			
		||||
@ -1547,8 +1544,7 @@ static void PackRhsHelper(int iters,
 | 
			
		||||
    return (idx / packet_size) * packet_size;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  tensorflow::testing::StartTiming();
 | 
			
		||||
  for (int i = 0; i < iters; ++i) {
 | 
			
		||||
  for (auto s : state) {
 | 
			
		||||
    int input_idx =
 | 
			
		||||
        num_inputs == 1 ? 1 : internal::random<int>(0, num_inputs - 1);
 | 
			
		||||
 | 
			
		||||
@ -1571,15 +1567,15 @@ static void PackRhsHelper(int iters,
 | 
			
		||||
        input_mappers[input_idx].getSubMapper(depth_offset, col_offset);
 | 
			
		||||
    pack_rhs(packed.data() + packed_offset, sub_mapper, depth, cols);
 | 
			
		||||
  }
 | 
			
		||||
  tensorflow::testing::StopTiming();
 | 
			
		||||
  tensorflow::testing::SetLabel(
 | 
			
		||||
 | 
			
		||||
  state.SetLabel(
 | 
			
		||||
      absl::StrCat("patch: ", patch_rows, "x", patch_cols, " D", patch_depth,
 | 
			
		||||
                   "; num_patches=", num_patches, " patch_size=", patch_size,
 | 
			
		||||
                   " num_inputs=", num_inputs, " padding=", padding));
 | 
			
		||||
}
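PackRhsHelper no longer needs UseRealTime/StopTiming/StartTiming because with the State-based API only the body of the `for (auto s : state)` loop is timed, and the label moves to state.SetLabel. A short sketch of that structure with a made-up workload:

    #include <algorithm>
    #include <vector>

    #include "absl/strings/str_cat.h"
    #include "benchmark/benchmark.h"

    static void BM_PackSketch(::benchmark::State& state) {
      // Setup outside the measurement loop is untimed by construction.
      std::vector<float> src(4096, 1.0f), dst(4096, 0.0f);

      for (auto s : state) {                     // only this body is measured
        std::copy(src.begin(), src.end(), dst.begin());
        ::benchmark::DoNotOptimize(dst.data());
      }

      // Replaces tensorflow::testing::SetLabel(...).
      state.SetLabel(absl::StrCat("elements=", src.size()));
    }
    BENCHMARK(BM_PackSketch)->UseRealTime();
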
template <typename T>
 | 
			
		||||
static void PackLhsHelper(int iters,
 | 
			
		||||
static void PackLhsHelper(::testing::benchmark::State& state,
 | 
			
		||||
                          /* Input dimensions: */
 | 
			
		||||
                          int input_depth,
 | 
			
		||||
                          /* Filter (kernel) dimensions: */
 | 
			
		||||
@ -1592,9 +1588,6 @@ static void PackLhsHelper(int iters,
 | 
			
		||||
  eigen_assert(block_rows <= filter_count);
 | 
			
		||||
  eigen_assert(block_cols <= input_depth * filter_rows * filter_cols);
 | 
			
		||||
 | 
			
		||||
  tensorflow::testing::UseRealTime();
 | 
			
		||||
  tensorflow::testing::StopTiming();
 | 
			
		||||
 | 
			
		||||
  using Dimensions = Eigen::DSizes<Eigen::Index, 4>;
 | 
			
		||||
 | 
			
		||||
  // Default Eigen::Tensor layout is column major, so we configure dimensions
 | 
			
		||||
@ -1716,8 +1709,7 @@ static void PackLhsHelper(int iters,
 | 
			
		||||
  const Index max_row = filter_count;
 | 
			
		||||
  const Index max_col = filter_rows * filter_cols * input_depth;
 | 
			
		||||
 | 
			
		||||
  tensorflow::testing::StartTiming();
 | 
			
		||||
  for (int i = 0; i < iters; ++i) {
 | 
			
		||||
  for (auto s : state) {
 | 
			
		||||
    int filter_idx =
 | 
			
		||||
        num_filters == 1 ? 1 : internal::random<int>(0, num_filters - 1);
 | 
			
		||||
 | 
			
		||||
@ -1743,8 +1735,7 @@ static void PackLhsHelper(int iters,
 | 
			
		||||
    pack_lhs(packed.data() + packed_offset, sub_mapper, cols, rows);
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
  tensorflow::testing::StopTiming();
 | 
			
		||||
  tensorflow::testing::SetLabel(absl::StrCat(
 | 
			
		||||
  state.SetLabel(absl::StrCat(
 | 
			
		||||
      "filter: count=", filter_count, " dims=", filter_rows, "x", filter_cols,
 | 
			
		||||
      "; input: depth=", input_depth, "; num_filers=", num_filters));
 | 
			
		||||
}
 | 
			
		||||
@ -1777,12 +1768,14 @@ static void PackLhsHelper(int iters,
 | 
			
		||||
 | 
			
		||||
#define BM_PackRhs(T, N, H, W, C, FC, FH, FW, PAD, SH, SW, ISH, ISW, BR, BC)  \
 | 
			
		||||
  static void BM_RHS_NAME(PackRhs, T, N, H, W, C, FC, FH, FW, PAD, SH, SW,    \
 | 
			
		||||
                          ISH, ISW, BR, BC)(int iters) {                      \
 | 
			
		||||
    PackRhsHelper<T>(iters, N, H, W, C, FC, FH, FW, PADDING_##PAD, SH, SW,    \
 | 
			
		||||
                          ISH, ISW, BR,                                       \
 | 
			
		||||
                          BC)(::testing::benchmark::State & state) {          \
 | 
			
		||||
    PackRhsHelper<T>(state, N, H, W, C, FC, FH, FW, PADDING_##PAD, SH, SW,    \
 | 
			
		||||
                     ISH, ISW, BR, BC);                                       \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  BENCHMARK(BM_RHS_NAME(PackRhs, T, N, H, W, C, FC, FH, FW, PAD, SH, SW, ISH, \
 | 
			
		||||
                        ISW, BR, BC))
 | 
			
		||||
                        ISW, BR, BC))                                         \
 | 
			
		||||
      ->UseRealTime()
 | 
			
		||||
 | 
			
		||||
// Number of input channels (input depth) is equal to the number of patch
// channels (patch depth).
 | 
			
		||||
@ -2019,11 +2012,12 @@ BM_PackRhs(/*type*/ qint8,                 //
 | 
			
		||||
#define BM_LHS_NAME(prefix, T, C, FC, FH, FW, BR, BC) \
 | 
			
		||||
  BM_CONCAT(BM_##prefix##_##T##_##C##_FC##FC##_##FH##x##FW, _B##BR##x##BC)
 | 
			
		||||
 | 
			
		||||
#define BM_PackLhs(T, C, FC, FH, FW, BR, BC)                              \
 | 
			
		||||
  static void BM_LHS_NAME(PackLhs, T, C, FC, FH, FW, BR, BC)(int iters) { \
 | 
			
		||||
    PackLhsHelper<T>(iters, C, FC, FH, FW, BR, BC);                       \
 | 
			
		||||
  }                                                                       \
 | 
			
		||||
  BENCHMARK(BM_LHS_NAME(PackLhs, T, C, FC, FH, FW, BR, BC))
 | 
			
		||||
#define BM_PackLhs(T, C, FC, FH, FW, BR, BC)                         \
 | 
			
		||||
  static void BM_LHS_NAME(PackLhs, T, C, FC, FH, FW, BR,             \
 | 
			
		||||
                          BC)(::testing::benchmark::State & state) { \
 | 
			
		||||
    PackLhsHelper<T>(state, C, FC, FH, FW, BR, BC);                  \
 | 
			
		||||
  }                                                                  \
 | 
			
		||||
  BENCHMARK(BM_LHS_NAME(PackLhs, T, C, FC, FH, FW, BR, BC))->UseRealTime()
 | 
			
		||||
 | 
			
		||||
// Number of input channels (input depth) is equal to the number of patch
// channels (patch depth).
 | 
			
		||||
 | 
			
		||||
@ -349,15 +349,16 @@ typedef BenchmarkOptions<ExampleStore<FloatFiller>, kRagged> RaggedFloat;
 | 
			
		||||
// B == batch_size, K == num_keys. F == feature_size.
 | 
			
		||||
// K must be one of 10, 100, 1000
 | 
			
		||||
#define BM_ParseExample(TYPE, B, K, F)                                    \
 | 
			
		||||
  static void BM_ParseExample##_##TYPE##_##B##_##K##_##F(int iters) {     \
 | 
			
		||||
  static void BM_ParseExample##_##TYPE##_##B##_##K##_##F(                 \
 | 
			
		||||
      ::testing::benchmark::State& state) {                               \
 | 
			
		||||
    int64 items_per_iter = static_cast<int64>(B) * K * F;                 \
 | 
			
		||||
    testing::UseRealTime();                                               \
 | 
			
		||||
    testing::ItemsProcessed(static_cast<int64>(iters) * items_per_iter);  \
 | 
			
		||||
    test::Benchmark("cpu", ParseExample<TYPE>(B, K, F), nullptr, nullptr, \
 | 
			
		||||
                    nullptr, "SINGLE_THREADED_EXECUTOR")                  \
 | 
			
		||||
        .Run(iters);                                                      \
 | 
			
		||||
                    nullptr, "SINGLE_THREADED_EXECUTOR", false)           \
 | 
			
		||||
        .Run(state);                                                      \
 | 
			
		||||
    state.SetItemsProcessed(static_cast<int64>(state.iterations()) *      \
 | 
			
		||||
                            items_per_iter);                              \
 | 
			
		||||
  }                                                                       \
 | 
			
		||||
  BENCHMARK(BM_ParseExample##_##TYPE##_##B##_##K##_##F);
 | 
			
		||||
  BENCHMARK(BM_ParseExample##_##TYPE##_##B##_##K##_##F)->UseRealTime();
 | 
			
		||||
 | 
			
		||||
#define BM_AllParseExample(Type)       \
 | 
			
		||||
  BM_ParseExample(Type, 1, 10, 1);     \
 | 
			
		||||
@ -385,15 +386,17 @@ BM_AllParseExample(VarLenDenseFloat);
 | 
			
		||||
// K must be one of 10, 100, 1000
 | 
			
		||||
// B=0 indicates that a scalar input should be used (instead of a vector).
 | 
			
		||||
#define BM_ParseExampleV2(TYPE, B, K, F)                                    \
 | 
			
		||||
  static void BM_ParseExampleV2##_##TYPE##_##B##_##K##_##F(int iters) {     \
 | 
			
		||||
  static void BM_ParseExampleV2##_##TYPE##_##B##_##K##_##F(                 \
 | 
			
		||||
      ::testing::benchmark::State& state) {                                 \
 | 
			
		||||
    int64 items_per_iter = static_cast<int64>(std::max(B, 1)) * K * F;      \
 | 
			
		||||
    testing::UseRealTime();                                                 \
 | 
			
		||||
    testing::ItemsProcessed(static_cast<int64>(iters) * items_per_iter);    \
 | 
			
		||||
    test::Benchmark("cpu", ParseExampleV2<TYPE>(B, K, F), nullptr, nullptr, \
 | 
			
		||||
                    nullptr, "SINGLE_THREADED_EXECUTOR")                    \
 | 
			
		||||
        .Run(iters);                                                        \
 | 
			
		||||
                    nullptr, "SINGLE_THREADED_EXECUTOR",                    \
 | 
			
		||||
                    /*old_benchmark_api=*/false)                            \
 | 
			
		||||
        .Run(state);                                                        \
 | 
			
		||||
    state.SetItemsProcessed(static_cast<int64>(state.iterations()) *        \
 | 
			
		||||
                            items_per_iter);                                \
 | 
			
		||||
  }                                                                         \
 | 
			
		||||
  BENCHMARK(BM_ParseExampleV2##_##TYPE##_##B##_##K##_##F);
 | 
			
		||||
  BENCHMARK(BM_ParseExampleV2##_##TYPE##_##B##_##K##_##F)->UseRealTime();
 | 
			
		||||
 | 
			
		||||
#define BM_AllParseExampleV2(Type)        \
 | 
			
		||||
  /* Vector Inputs */                     \
 | 
			
		||||
@ -437,15 +440,17 @@ BM_AllParseExampleV2(RaggedFloat);
 | 
			
		||||
// K == num_keys. F == feature_size.
 | 
			
		||||
// K must be one of 10, 100, 1000
 | 
			
		||||
#define BM_ParseSingleExample(TYPE, K, F)                                    \
 | 
			
		||||
  static void BM_ParseSingleExample##_##TYPE##_1_##K##_##F(int iters) {      \
 | 
			
		||||
  void BM_ParseSingleExample##_##TYPE##_1_##K##_##F(                         \
 | 
			
		||||
      ::testing::benchmark::State& state) {                                  \
 | 
			
		||||
    int64 items_per_iter = K * F;                                            \
 | 
			
		||||
    testing::UseRealTime();                                                  \
 | 
			
		||||
    testing::ItemsProcessed(static_cast<int64>(iters) * items_per_iter);     \
 | 
			
		||||
    test::Benchmark("cpu", ParseSingleExample<TYPE>(K, F), nullptr, nullptr, \
 | 
			
		||||
                    nullptr, "SINGLE_THREADED_EXECUTOR")                     \
 | 
			
		||||
        .Run(iters);                                                         \
 | 
			
		||||
                    nullptr, "SINGLE_THREADED_EXECUTOR",                     \
 | 
			
		||||
                    /*old_benchmark_api=*/false)                             \
 | 
			
		||||
        .Run(state);                                                         \
 | 
			
		||||
    state.SetItemsProcessed(static_cast<int64>(state.iterations()) *         \
 | 
			
		||||
                            items_per_iter);                                 \
 | 
			
		||||
  }                                                                          \
 | 
			
		||||
  BENCHMARK(BM_ParseSingleExample##_##TYPE##_1_##K##_##F);
 | 
			
		||||
  BENCHMARK(BM_ParseSingleExample##_##TYPE##_1_##K##_##F)->UseRealTime();
 | 
			
		||||
 | 
			
		||||
#define BM_AllParseSingleExample(Type)     \
 | 
			
		||||
  BM_ParseSingleExample(Type, 10, 1);      \
 | 
			
		||||
 | 
			
		||||
@ -132,18 +132,22 @@ static Graph* GatherNd(int dim) {
 | 
			
		||||
  return g;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define BM_GATHER_ND(DEVICE, INDEX)                                 \
 | 
			
		||||
  static void BM_##DEVICE##_gather_nd_##INDEX(int iters, int dim) { \
 | 
			
		||||
    const int64 tot = static_cast<int64>(iters) * kLookups * 4;     \
 | 
			
		||||
    testing::ItemsProcessed(tot);                                   \
 | 
			
		||||
    testing::BytesProcessed(tot * sizeof(float));                   \
 | 
			
		||||
    testing::UseRealTime();                                         \
 | 
			
		||||
    test::Benchmark(#DEVICE, GatherNd<INDEX>(dim)).Run(iters);      \
 | 
			
		||||
  }                                                                 \
 | 
			
		||||
  BENCHMARK(BM_##DEVICE##_gather_nd_##INDEX)                        \
 | 
			
		||||
      ->Arg(10)                                                     \
 | 
			
		||||
      ->Arg(100)                                                    \
 | 
			
		||||
      ->Arg(1000)                                                   \
 | 
			
		||||
#define BM_GATHER_ND(DEVICE, INDEX)                                          \
 | 
			
		||||
  static void BM_##DEVICE##_gather_nd_##INDEX(                               \
 | 
			
		||||
      ::testing::benchmark::State& state) {                                  \
 | 
			
		||||
    const int dim = state.range(0);                                          \
 | 
			
		||||
    test::Benchmark(#DEVICE, GatherNd<INDEX>(dim),                           \
 | 
			
		||||
                    /*old_benchmark_api=*/false)                             \
 | 
			
		||||
        .Run(state);                                                         \
 | 
			
		||||
    const int64 tot = static_cast<int64>(state.iterations()) * kLookups * 4; \
 | 
			
		||||
    state.SetItemsProcessed(tot);                                            \
 | 
			
		||||
    state.SetBytesProcessed(tot * sizeof(float));                            \
 | 
			
		||||
  }                                                                          \
 | 
			
		||||
  BENCHMARK(BM_##DEVICE##_gather_nd_##INDEX)                                 \
 | 
			
		||||
      ->UseRealTime()                                                        \
 | 
			
		||||
      ->Arg(10)                                                              \
 | 
			
		||||
      ->Arg(100)                                                             \
 | 
			
		||||
      ->Arg(1000)                                                            \
 | 
			
		||||
      ->Arg(10000)
 | 
			
		||||
 | 
			
		||||
BM_GATHER_ND(cpu, int32);
 | 
			
		||||
 | 
			
		||||
@ -222,21 +222,24 @@ static Graph* Gather(int dim) {
 | 
			
		||||
  return g;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define BM_GATHER(DEVICE, INDEX)                                  \
 | 
			
		||||
  static void BM_##DEVICE##_gather_##INDEX(int iters, int dim) {  \
 | 
			
		||||
    const int64 tot = static_cast<int64>(iters) * kLookups * dim; \
 | 
			
		||||
    testing::ItemsProcessed(tot);                                 \
 | 
			
		||||
    testing::BytesProcessed(tot * sizeof(float));                 \
 | 
			
		||||
    testing::UseRealTime();                                       \
 | 
			
		||||
    test::Benchmark(#DEVICE, Gather<INDEX>(dim)).Run(iters);      \
 | 
			
		||||
  }                                                               \
 | 
			
		||||
  BENCHMARK(BM_##DEVICE##_gather_##INDEX)                         \
 | 
			
		||||
      ->Arg(1)                                                    \
 | 
			
		||||
      ->Arg(10)                                                   \
 | 
			
		||||
      ->Arg(20)                                                   \
 | 
			
		||||
      ->Arg(64)                                                   \
 | 
			
		||||
      ->Arg(100)                                                  \
 | 
			
		||||
      ->Arg(200)                                                  \
 | 
			
		||||
#define BM_GATHER(DEVICE, INDEX)                                               \
 | 
			
		||||
  static void BM_##DEVICE##_gather_##INDEX(                                    \
 | 
			
		||||
      ::testing::benchmark::State& state) {                                    \
 | 
			
		||||
    const int dim = state.range(0);                                            \
 | 
			
		||||
    test::Benchmark(#DEVICE, Gather<INDEX>(dim), /*old_benchmark_api=*/false)  \
 | 
			
		||||
        .Run(state);                                                           \
 | 
			
		||||
    const int64 tot = static_cast<int64>(state.iterations()) * kLookups * dim; \
 | 
			
		||||
    state.SetItemsProcessed(tot);                                              \
 | 
			
		||||
    state.SetBytesProcessed(tot * sizeof(float));                              \
 | 
			
		||||
  }                                                                            \
 | 
			
		||||
  BENCHMARK(BM_##DEVICE##_gather_##INDEX)                                      \
 | 
			
		||||
      ->UseRealTime()                                                          \
 | 
			
		||||
      ->Arg(1)                                                                 \
 | 
			
		||||
      ->Arg(10)                                                                \
 | 
			
		||||
      ->Arg(20)                                                                \
 | 
			
		||||
      ->Arg(64)                                                                \
 | 
			
		||||
      ->Arg(100)                                                               \
 | 
			
		||||
      ->Arg(200)                                                               \
 | 
			
		||||
      ->Arg(1000)
 | 
			
		||||
 | 
			
		||||
BM_GATHER(cpu, int32);
 | 
			
		||||
 | 
			
		||||
@ -65,13 +65,16 @@ static Graph* InTopK(int num_targets, int num_classes, T top_k) {
 | 
			
		||||
#define BM_NAME(T, TARGETS, CLASSES, K, DEVICE) \
 | 
			
		||||
  BM_InTopK##_##T##_##TARGETS##_##CLASSES##_##K##_##DEVICE
 | 
			
		||||
 | 
			
		||||
#define BM_InTopK(T, TARGETS, CLASSES, K, DEVICE)                           \
 | 
			
		||||
  static void BM_NAME(T, TARGETS, CLASSES, K, DEVICE)(int iters) {          \
 | 
			
		||||
    testing::UseRealTime();                                                 \
 | 
			
		||||
    testing::ItemsProcessed(static_cast<int64>(iters) * TARGETS * CLASSES); \
 | 
			
		||||
    test::Benchmark(#DEVICE, InTopK<T>(TARGETS, CLASSES, K)).Run(iters);    \
 | 
			
		||||
  }                                                                         \
 | 
			
		||||
  BENCHMARK(BM_NAME(T, TARGETS, CLASSES, K, DEVICE));
 | 
			
		||||
#define BM_InTopK(T, TARGETS, CLASSES, K, DEVICE)                              \
 | 
			
		||||
  static void BM_NAME(T, TARGETS, CLASSES, K,                                  \
 | 
			
		||||
                      DEVICE)(::testing::benchmark::State & state) {           \
 | 
			
		||||
    test::Benchmark(#DEVICE, InTopK<T>(TARGETS, CLASSES, K),                   \
 | 
			
		||||
                    /*old_benchmark_api=*/false)                               \
 | 
			
		||||
        .Run(state);                                                           \
 | 
			
		||||
    state.SetItemsProcessed(static_cast<int64>(state.iterations()) * TARGETS * \
 | 
			
		||||
                            CLASSES);                                          \
 | 
			
		||||
  }                                                                            \
 | 
			
		||||
  BENCHMARK(BM_NAME(T, TARGETS, CLASSES, K, DEVICE))->UseRealTime();
 | 
			
		||||
 | 
			
		||||
BM_InTopK(int64, 64, 1000, 10, cpu);
 | 
			
		||||
BM_InTopK(int64, 64, 10000, 10, cpu);
 | 
			
		||||
 | 
			
		||||
@ -30,9 +30,9 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/linalg/einsum_op.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"
#include "tensorflow/core/kernels/reduction_ops_common.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"

@ -199,7 +199,7 @@ TCASE(T3, 128,   4,     3,            2.0f, 1.0f,  1.0f)

#undef TCASE

static Graph* BM_LRNGrad(int batches, int rows, int cols, int depth,
static Graph* MakeRNGrad(int batches, int rows, int cols, int depth,
                         int depth_radius) {
  Graph* g = new Graph(OpRegistry::Global());
  Tensor grads(DT_FLOAT, TensorShape({batches, rows, cols, depth}));
@ -223,12 +223,15 @@ static Graph* BM_LRNGrad(int batches, int rows, int cols, int depth,
  return g;
}

#define BM_LRNGradDev(DEVICE, B, R, C, D, DR)                                 \
  static void BM_LRNGrad_##DEVICE##_##B##_##R##_##C##_##D##_##DR(int iters) { \
    testing::ItemsProcessed(static_cast<int64>(iters) * B * R * C * D * DR *  \
                            4);                                               \
    test::Benchmark(#DEVICE, BM_LRNGrad(B, R, C, D, DR)).Run(iters);          \
  }                                                                           \
#define BM_LRNGradDev(DEVICE, B, R, C, D, DR)                                \
  static void BM_LRNGrad_##DEVICE##_##B##_##R##_##C##_##D##_##DR(            \
      ::testing::benchmark::State& state) {                                  \
    test::Benchmark(#DEVICE, MakeRNGrad(B, R, C, D, DR),                     \
                    /*old_benchmark_api*/ false)                             \
        .Run(state);                                                         \
    state.SetItemsProcessed(static_cast<int64>(state.iterations()) * B * R * \
                            C * D * DR * 4);                                 \
  }                                                                          \
  BENCHMARK(BM_LRNGrad_##DEVICE##_##B##_##R##_##C##_##D##_##DR)

BM_LRNGradDev(cpu, 128, 12, 12, 64, 4);

@ -1,567 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/matmul_op.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/util/matmul_autotune.h"
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T, bool USE_CUBLAS>
struct LaunchMatMul;

namespace {
// Converts a TensorFlow Tensor to an Eigen Matrix.
template <typename T>
Eigen::Map<
    const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
ToEigenMatrix(const Tensor& tensor) {
  auto matrix = tensor.matrix<T>();
  return Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>::Map(
      matrix.data(), matrix.dimension(0), matrix.dimension(1));
}

// Converts a TensorFlow Tensor to an Eigen Vector.
template <typename T>
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(Tensor* tensor) {
  auto v = tensor->flat<T>();
  return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0));
}
template <typename T>
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(
    const Tensor& tensor) {
  auto v = tensor.flat<T>();
  return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0));
}
}  // namespace

// If either side can be represented as a vector, do an explicit vector
// matrix multiply and return true; else return false.
//
// Note: this uses plain Eigen and not Eigen Tensor because it is more
// efficient.
template <typename T>
bool ExplicitVectorMatrixOptimization(
    const Tensor& a, const Tensor& b,
    const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
    Tensor* out) {
  if (out->dim_size(0) == 1) {
    if (dim_pair[0].second == 0) {
      // Note: this case is optimized in Eigen Tensors.
      return false;
    } else {
      auto out_v = ToEigenVector<T>(out);
      auto a_v = ToEigenVector<T>(a);
      auto b_m = ToEigenMatrix<T>(b);
      out_v.noalias() = b_m * a_v;
    }
    return true;
  } else if (out->dim_size(1) == 1) {
    auto out_v = ToEigenVector<T>(out);
    auto a_m = ToEigenMatrix<T>(a);
    auto b_v = ToEigenVector<T>(b);
    if (dim_pair[0].first == 0) {
      out_v.noalias() = a_m.transpose() * b_v;
    } else {
      out_v.noalias() = a_m * b_v;
    }
    return true;
  }
  return false;
}
// Half is not supported.
template <>
bool ExplicitVectorMatrixOptimization<Eigen::half>(
    const Tensor& a, const Tensor& b,
    const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
    Tensor* out) {
  return false;
}

template <typename Device, typename T>
struct LaunchMatMulBase {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  typedef se::blas::AlgorithmType AlgorithmType;
#else
  typedef int64 AlgorithmType;
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

  static void launch(
      OpKernelContext* ctx, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      std::vector<AlgorithmType>* algorithms, bool use_autotune, Tensor* out) {
    // An explicit vector-matrix multiply is much better optimized than an
    // implicit one and this is a bottleneck during non-batched inference.
    bool was_vector = ExplicitVectorMatrixOptimization<T>(a, b, dim_pair, out);
    if (!was_vector) {
      functor::MatMulFunctor<Device, T>()(ctx->eigen_device<Device>(),
                                          out->matrix<T>(), a.matrix<T>(),
                                          b.matrix<T>(), dim_pair);
    }
  }

  static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
                                   std::vector<int64>* algorithms,
                                   bool* algorithm_set_flag) {}
};
// On CPUs, we ignore USE_CUBLAS
template <typename T>
struct LaunchMatMulCPU : LaunchMatMulBase<CPUDevice, T> {};

template <typename T, bool USE_CUBLAS>
struct LaunchMatMul<CPUDevice, T, USE_CUBLAS> : public LaunchMatMulCPU<T> {};


#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace {

template <typename T>
struct LaunchBlasGemv {
  static void Compute(OpKernelContext* ctx, se::Stream* stream, bool trans,
                      uint64 m, uint64 n, const se::DeviceMemory<T>& a,
                      const se::DeviceMemory<T>& b, se::DeviceMemory<T>* c,
                      se::blas::ProfileResult* output_profile) {
    const auto blas_trans = trans ? se::blas::Transpose::kTranspose
                                  : se::blas::Transpose::kNoTranspose;
    if (output_profile == nullptr) {
      bool blas_launch_status =
          stream
              ->ThenBlasGemv(blas_trans, m, n, static_cast<T>(1.0), a, m, b, 1,
                             static_cast<T>(0.0), c, 1)
              .ok();
      if (!blas_launch_status) {
        ctx->SetStatus(
            errors::Internal("Blas GEMV launch failed:  m=", m, ", n=", n));
      }
    } else {
      bool blas_launch_status =
          stream
              ->ThenBlasGemvWithProfiling(blas_trans, m, n, static_cast<T>(1.0),
                                          a, m, b, 1, static_cast<T>(0.0), c, 1,
                                          output_profile)
              .ok();
      if (!blas_launch_status) {
        ctx->SetStatus(errors::Internal(
            "Blas GEMV with profiling launch failed:  m=", m, ", n=", n));
      }
    }
  }

  static bool IsSupported() { return true; }
};

template <>
void LaunchBlasGemv<Eigen::half>::Compute(
    OpKernelContext* ctx, se::Stream* stream, bool trans, uint64 m, uint64 n,
    const se::DeviceMemory<Eigen::half>& a,
    const se::DeviceMemory<Eigen::half>& b, se::DeviceMemory<Eigen::half>* c,
    se::blas::ProfileResult* output_profile) {
  ctx->SetStatus(errors::Internal(
      "Blas GEMV launch failed: GEMV is not implemented for float16."));
}

template <>
bool LaunchBlasGemv<Eigen::half>::IsSupported() {
  return false;
}

template <typename T>
bool ShouldUseGemv(uint64 n) {
  return (LaunchBlasGemv<T>::IsSupported() && n == 1);
}

}  // namespace

bool GetCublasAutotuneComputationType(const DataType& dtype,
                                      se::blas::ComputationType* compute_type) {
  using se::blas::ComputationType;
  switch (dtype) {
    case DT_HALF:
    case DT_BFLOAT16:
      static bool use_f32_for_f16_computation =
          MatmulDoFP32ComputationFP16Input();
      if (use_f32_for_f16_computation) {
        *compute_type = ComputationType::kF32;
      } else {
        *compute_type = ComputationType::kF16;
      }
      return false;
    case DT_FLOAT:
      *compute_type = ComputationType::kF32;
      return true;
    case DT_DOUBLE:
      *compute_type = ComputationType::kF64;
      return true;
    default:
      // Unsupported compute_type, return false.
      return false;
  }
}

// A dummy type to group matmul autotune results together.
struct MatmulAutoTuneGroup {
  static string name() { return "Matmul"; }
};
typedef AutoTuneSingleton<MatmulAutoTuneGroup, MatmulParameters,
                          se::blas::AlgorithmConfig>
    AutoTuneMatmul;

template <typename T>
struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
  static void launch(
      OpKernelContext* ctx, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      std::vector<int64>* algorithms, bool use_autotune, Tensor* out) {
    using se::blas::AlgorithmConfig;
    using se::blas::ComputationType;
    using se::blas::kDefaultAlgorithm;
    using se::blas::kDefaultBlasGemm;
    using se::blas::kDefaultBlasGemv;
    using se::blas::kNoAlgorithm;
    using se::blas::ProfileResult;
    using se::blas::Transpose;
    Transpose trans[] = {Transpose::kNoTranspose, Transpose::kTranspose};
    const uint64 m = a.dim_size(1 - dim_pair[0].first);
    const uint64 k = a.dim_size(dim_pair[0].first);
    const uint64 n = b.dim_size(1 - dim_pair[0].second);
    bool transpose_a = dim_pair[0].first == 0;
    bool transpose_b = dim_pair[0].second == 1;
    auto blas_transpose_a = trans[transpose_a];
    auto blas_transpose_b = trans[transpose_b];

    auto* stream = ctx->op_device_context()->stream();
    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

    auto a_ptr = AsDeviceMemory(a.template flat<T>().data(),
                                a.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(b.template flat<T>().data(),
                                b.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(out->template flat<T>().data(),
                                out->template flat<T>().size());
    auto alpha = static_cast<T>(1.0);
    auto beta = static_cast<T>(0.0);

    int device_id = stream->parent()->device_ordinal();
    DataType dtype = a.dtype();
    MatmulParameters matmul_parameters = {
        transpose_a, transpose_b, m, n, k, dtype, device_id,
    };
    AlgorithmConfig algorithm_config(kNoAlgorithm);

    ComputationType computation_type;
    bool compute_type_supported =
        GetCublasAutotuneComputationType(dtype, &computation_type);
    if (use_autotune && compute_type_supported && !algorithms->empty()) {
      ProfileResult best_result;
      // TODO(yangzihao): Unify this code with conv autotuning.
      if (!AutoTuneMatmul::GetInstance()->Find(matmul_parameters,
                                               &algorithm_config)) {
        ProfileResult profile_result;
        for (auto profile_algorithm : (*algorithms)) {
          // Cublas does
          // C = A x B
          // where A, B and C are assumed to be in column major.
          // We want the output to be in row-major, so we can compute
          // C' = B' x A' (' stands for transpose)
          bool cublas_launch_status =
              stream
                  ->ThenBlasGemmWithAlgorithm(
                      blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
                      transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
                      &c_ptr, n, computation_type, profile_algorithm,
                      &profile_result)
                  .ok();
          if (cublas_launch_status) {
            if (profile_result.is_valid()) {
              if (profile_result.elapsed_time_in_ms() <
                  best_result.elapsed_time_in_ms()) {
                best_result = profile_result;
              }
            }
          }
        }
        // Try BlasGemmWithProfiling
        bool cublas_launch_status =
            stream
                ->ThenBlasGemmWithProfiling(
                    blas_transpose_b, blas_transpose_a, n, m, k, 1.0, b_ptr,
                    transpose_b ? k : n, a_ptr, transpose_a ? m : k, 0.0,
                    &c_ptr, n, &profile_result)
                .ok();
        if (cublas_launch_status) {
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
          }
        }
        // Try BlasGemvWithProfiling
        if (ShouldUseGemv<T>(n)) {
          LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
                                     transpose_a ? m : k, transpose_a ? k : m,
                                     a_ptr, b_ptr, &c_ptr, &profile_result);
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
          }
        }
      }
      // We make sure that each matmul parameter set only gets one pass of
      // autotune. If the best result is found, assign it to algorithm_type
      // and insert it to autotune map. If all internal kernels of
      // cublasGemmEx() returns invalid results, we add kNoAlgorithm to the
      // autotune map.
      if (best_result.is_valid()) {
        algorithm_config.set_algorithm(best_result.algorithm());
      }
      AutoTuneMatmul::GetInstance()->Insert(matmul_parameters,
                                            algorithm_config);
      if (algorithm_config.algorithm() != kNoAlgorithm &&
          algorithm_config.algorithm() != kDefaultBlasGemm &&
          algorithm_config.algorithm() != kDefaultBlasGemv) {
        bool cublas_launch_status =
            stream
                ->ThenBlasGemmWithAlgorithm(
                    blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
                    transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
                    &c_ptr, n, computation_type, algorithm_config.algorithm(),
                    nullptr)
                .ok();
        if (!cublas_launch_status) {
          ctx->SetStatus(errors::Internal(
              "Blas GEMM with algorithm launch failed : a.shape=(",
              a.dim_size(0), ", ", a.dim_size(1), "), b.shape=(", b.dim_size(0),
              ", ", b.dim_size(1), "), m=", m, ", n=", n, ", k=", k));
        }
      }
    }
    // For the following case, we use normal BlasGemm():
    //  1) We didn't set the use_autotune flag;
    //  2) compute type does not support autotune;
    //  3) no algorithm is found;
    //  4) all internal kernels in autotune return invalid results.
    //  For the following case, we use normal BlasGemv():
    //  1) We didn't set the use_autotune flag but LaunchBlasGemv is supported
    //     and n == 1.
    //  2) We set the use_autotune flag and it picked up BlasGemv() and set the
    //     algorithm_config.algorithm() to be kDefaultBlasGemv.
    if (!use_autotune || !compute_type_supported || algorithms->empty() ||
        algorithm_config.algorithm() == kNoAlgorithm ||
        algorithm_config.algorithm() == kDefaultBlasGemm ||
        algorithm_config.algorithm() == kDefaultBlasGemv) {
      if (algorithm_config.algorithm() == kDefaultBlasGemv ||
          ShouldUseGemv<T>(n)) {
        // This is a matrix*vector multiply so use GEMV to compute A * b.
        // Here we are multiplying in the natural order, so we have to flip
        // the transposition flag to compensate for the tensor being stored
        // row-major.
        // TODO(yangzihao): Add Gemv as an autotuning option too.
        LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
                                   transpose_a ? m : k, transpose_a ? k : m,
                                   a_ptr, b_ptr, &c_ptr, nullptr);
      } else {
        // Use C' = B' x A' (' stands for transpose)
        bool blas_launch_status =
            stream
                ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                               1.0f, b_ptr, transpose_b ? k : n, a_ptr,
                               transpose_a ? m : k, 0.0f, &c_ptr, n)
                .ok();
        if (!blas_launch_status) {
          ctx->SetStatus(errors::Internal(
              "Blas GEMM launch failed : a.shape=(", a.dim_size(0), ", ",
              a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1),
              "), m=", m, ", n=", n, ", k=", k));
        }
      }
    }
  }

  static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
                                   std::vector<int64>* algorithms,
                                   bool* algorithm_set_flag) {
    if (*algorithm_set_flag == false) {
      auto* stream = ctx->device()->tensorflow_gpu_device_info()->stream;
      stream->parent()->GetBlasGemmAlgorithms(algorithms);
      *algorithm_set_flag = true;
    }
  }
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename T, bool USE_CUBLAS>
class MatMulOp : public OpKernel {
 public:
  explicit MatMulOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), algorithms_set_already_(false) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));

    LaunchMatMul<Device, T, USE_CUBLAS>::GetBlasGemmAlgorithm(
        ctx, &algorithms_, &algorithms_set_already_);
    use_autotune_ = MatmulAutotuneEnable();
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);

    // Check that the dimensions of the two matrices are valid.
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsMatrix(a.shape()),
        errors::InvalidArgument("In[0] is not a matrix. Instead it has shape ",
                                a.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsMatrix(b.shape()),
        errors::InvalidArgument("In[1] is not a matrix. Instead it has shape ",
                                b.shape().DebugString()));
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    dim_pair[0].first = transpose_a_ ? 0 : 1;
    dim_pair[0].second = transpose_b_ ? 1 : 0;

    OP_REQUIRES(
        ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
        errors::InvalidArgument(
            "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
            ", In[1]: ", b.shape().DebugString()));
    int a_dim_remaining = 1 - dim_pair[0].first;
    int b_dim_remaining = 1 - dim_pair[0].second;
    TensorShape out_shape(
        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));

    if (out->NumElements() == 0) {
      // If a has shape [0, x] or b has shape [x, 0], the output shape
      // is a 0-element matrix, so there is nothing to do.
      return;
    }

    if (a.NumElements() == 0 && b.NumElements() == 0) {
      // If a has shape [x, 0] and b has shape [0, y], the
      // output shape is [x, y] where x and y are non-zero, so we fill
      // the output with zeros.
      functor::SetZeroFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), out->flat<T>());
      return;
    }

    if (std::is_same<T, bfloat16>::value) {
      bool is_cpu = std::is_same<Device, CPUDevice>::value;
      OP_REQUIRES(ctx, is_cpu,
                  errors::Internal("bfloat16 matmul is not supported by GPU"));
      Tensor a_float, b_float, out_float;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, a.shape(), &a_float));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, b.shape(), &b_float));
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DT_FLOAT, out->shape(), &out_float));

      // TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
      BFloat16ToFloat(a.flat<bfloat16>().data(), a_float.flat<float>().data(),
                      a.NumElements());
      BFloat16ToFloat(b.flat<bfloat16>().data(), b_float.flat<float>().data(),
                      b.NumElements());

      LaunchMatMul<Device, float, USE_CUBLAS>::launch(
          ctx, a_float, b_float, dim_pair, &algorithms_, use_autotune_,
          &out_float);
      FloatToBFloat16(out_float.flat<float>().data(),
                      out->flat<bfloat16>().data(), out->NumElements());
    } else {
      LaunchMatMul<Device, T, USE_CUBLAS>::launch(
          ctx, a, b, dim_pair, &algorithms_, use_autotune_, out);
    }
  }

 private:
  std::vector<int64> algorithms_;
  bool algorithms_set_already_;
  bool use_autotune_;
  bool transpose_a_;
  bool transpose_b_;
};

namespace functor {

// Partial specialization MatMulFunctor<Device=CPUDevice, T>.
template <typename T>
struct MatMulFunctor<CPUDevice, T> {
  void operator()(
      const CPUDevice& d, typename MatMulTypes<T>::out_type out,
      typename MatMulTypes<T>::in_type in0,
      typename MatMulTypes<T>::in_type in1,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
    MatMul<CPUDevice>(d, out, in0, in1, dim_pair);
  }
};


}  // end namespace functor

#define REGISTER_CPU_EIGEN(T)                                                  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen"), \
      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);

#define REGISTER_CPU(T)                                             \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"),     \
      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); \
  REGISTER_CPU_EIGEN(T);

#define REGISTER_GPU(T)                                            \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("MatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"),    \
      MatMulOp<GPUDevice, T, true /* cublas, true by default */>); \
  REGISTER_KERNEL_BUILDER(Name("MatMul")                           \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T")              \
                              .Label("cublas"),                    \
                          MatMulOp<GPUDevice, T, true /* cublas */>)

TF_CALL_int32(REGISTER_CPU);
TF_CALL_int64(REGISTER_CPU);
TF_CALL_FLOAT_TYPES(REGISTER_CPU);
TF_CALL_COMPLEX_TYPES(REGISTER_CPU);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_COMPLEX_TYPES(REGISTER_GPU);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"

namespace tensorflow {

@ -15,8 +15,8 @@ limitations under the License.

// See docs in ../ops/math_ops.cc.

#ifndef TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_

#define EIGEN_USE_THREADS

@ -633,10 +633,21 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
template <typename Device, typename Scalar>
class BaseBatchMatMulOp : public OpKernel {
 public:
  explicit BaseBatchMatMulOp(OpKernelConstruction* context)
  explicit BaseBatchMatMulOp(OpKernelConstruction* context,
                             bool is_legacy_matmul)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
    OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
    if (is_legacy_matmul) {
      // The old MatMul kernel has "transpose_a/transpose_b" attributes.
      OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &trans_x_));
      OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &trans_y_));
      adj_x_ = false;
      adj_y_ = false;
    } else {
      OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
      OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
      trans_x_ = false;
      trans_y_ = false;
    }
  }

  ~BaseBatchMatMulOp() override {}
@ -672,8 +683,8 @@ class BaseBatchMatMulOp : public OpKernel {
        in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
        errors::Internal("Failed to reshape In[1] from ",
                         in1.shape().DebugString()));
    if (adj_x_) std::swap(d0, d1);
    if (adj_y_) std::swap(d2, d3);
    if (adj_x_ || trans_x_) std::swap(d0, d1);
    if (adj_y_ || trans_y_) std::swap(d2, d3);
    OP_REQUIRES(ctx, d1 == d2,
                errors::InvalidArgument(
                    "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
@ -696,9 +707,36 @@ class BaseBatchMatMulOp : public OpKernel {
                out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
                errors::Internal("Failed to reshape output from ",
                                 out->shape().DebugString()));
    LaunchBatchMatMul<Device, Scalar>::Launch(
        ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, /*trans_x=*/false,
        /*trans_y=*/false, bcast, &out_reshaped);
    if (std::is_same<Scalar, bfloat16>::value) {
      bool is_cpu = std::is_same<Device, CPUDevice>::value;
      OP_REQUIRES(ctx, is_cpu,
                  errors::Internal("bfloat16 matmul is not supported by GPU"));
      Tensor in0_reshaped_float, in1_reshaped_float, out_reshaped_float;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in0_reshaped.shape(),
                                             &in0_reshaped_float));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in1_reshaped.shape(),
                                             &in1_reshaped_float));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, out_reshaped.shape(),
                                             &out_reshaped_float));

      // TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
      BFloat16ToFloat(in0_reshaped.flat<bfloat16>().data(),
                      in0_reshaped_float.flat<float>().data(),
                      in0_reshaped.NumElements());
      BFloat16ToFloat(in1_reshaped.flat<bfloat16>().data(),
                      in1_reshaped_float.flat<float>().data(),
                      in1_reshaped.NumElements());

      LaunchBatchMatMul<Device, float>::Launch(
          ctx, in0_reshaped_float, in1_reshaped_float, adj_x_, adj_y_, trans_x_,
          trans_y_, bcast, &out_reshaped_float);
      FloatToBFloat16(out_reshaped_float.flat<float>().data(),
                      out_reshaped.flat<bfloat16>().data(), out->NumElements());
    } else {
      LaunchBatchMatMul<Device, Scalar>::Launch(ctx, in0_reshaped, in1_reshaped,
                                                adj_x_, adj_y_, trans_x_,
                                                trans_y_, bcast, &out_reshaped);
    }
  }

 protected:
@ -706,16 +744,19 @@ class BaseBatchMatMulOp : public OpKernel {
                                    const Tensor& in1) = 0;

 private:
  // TODO(171979567) Make the ops take both adj and transpose attributes.
  bool adj_x_;
  bool adj_y_;
  bool trans_x_;
  bool trans_y_;
};

// BatchMatMul Op implementation which disallows broadcasting.
template <typename Device, typename Scalar>
template <typename Device, typename Scalar, bool is_legacy_matmul = false>
class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> {
 public:
  explicit BatchMatMulOp(OpKernelConstruction* context)
      : BaseBatchMatMulOp<Device, Scalar>(context) {}
      : BaseBatchMatMulOp<Device, Scalar>(context, is_legacy_matmul) {}

  ~BatchMatMulOp() override {}

@ -729,15 +770,21 @@ class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> {
                                        in0.shape().DebugString(), " vs. ",
                                        in1.shape().DebugString()));
    const int ndims = in0.dims();
    OP_REQUIRES(
        ctx, ndims >= 2,
        errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
    for (int i = 0; i < ndims - 2; ++i) {
      OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
    if (is_legacy_matmul) {
      OP_REQUIRES(ctx, ndims == 2,
                  errors::InvalidArgument(
                      "In[0].dim(", i, ") and In[1].dim(", i,
                      ") must be the same: ", in0.shape().DebugString(), " vs ",
                      in1.shape().DebugString()));
                      "In[0] and In[1] ndims must be == 2: ", ndims));
    } else {
      OP_REQUIRES(ctx, ndims >= 2,
                  errors::InvalidArgument(
                      "In[0] and In[1] ndims must be >= 2: ", ndims));
      for (int i = 0; i < ndims - 2; ++i) {
        OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
                    errors::InvalidArgument(
                        "In[0].dim(", i, ") and In[1].dim(", i,
                        ") must be the same: ", in0.shape().DebugString(),
                        " vs ", in1.shape().DebugString()));
      }
    }
  }
};
@ -747,7 +794,8 @@ template <typename Device, typename Scalar>
class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
 public:
  explicit BatchMatMulV2Op(OpKernelConstruction* context)
      : BaseBatchMatMulOp<Device, Scalar>(context) {}
      : BaseBatchMatMulOp<Device, Scalar>(context,
                                          /* is_legacy_matmul= */ false) {}

  ~BatchMatMulV2Op() override {}

@ -771,7 +819,10 @@ class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
      BatchMatMulOp<CPUDevice, TYPE>);                                    \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
      BatchMatMulV2Op<CPUDevice, TYPE>)
      BatchMatMulV2Op<CPUDevice, TYPE>);                                  \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),        \
      BatchMatMulOp<CPUDevice, TYPE, /* is_legacy_matmul=*/true>)

#define REGISTER_BATCH_MATMUL_GPU(TYPE)                                   \
  REGISTER_KERNEL_BUILDER(                                                \
@ -779,8 +830,11 @@ class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
      BatchMatMulOp<GPUDevice, TYPE>);                                    \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMulV2").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
      BatchMatMulV2Op<GPUDevice, TYPE>)
      BatchMatMulV2Op<GPUDevice, TYPE>);                                  \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("MatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),        \
      BatchMatMulOp<GPUDevice, TYPE, /* is_legacy_matmul=*/true>)

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
#endif  // TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
@ -21,17 +21,13 @@ limitations under the License.

namespace tensorflow {

TF_CALL_float(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_double(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_half(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_FLOAT_TYPES(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_int16(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_int64(REGISTER_BATCH_MATMUL_CPU);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_float(REGISTER_BATCH_MATMUL_GPU);
TF_CALL_double(REGISTER_BATCH_MATMUL_GPU);
TF_CALL_half(REGISTER_BATCH_MATMUL_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_BATCH_MATMUL_GPU);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow
@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/public/session.h"

namespace tensorflow {
namespace {

template <typename T>
class FusedMatMulOpTest : public OpsTestBase {
@ -459,4 +460,230 @@ BM_Matmul(2000, 1, 2000, true, false);
BM_Matmul(2000, 1, 2000, false, true);
BM_Matmul(2000, 1, 2000, true, true);

}  // end namespace tensorflow
// Benchmarks for batched matmul with broadcasting.
Node* BroadcastTo(Graph* g, Node* input, Node* shape) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastTo")
                  .Input(input)
                  .Input(shape)
                  .Finalize(g, &ret));
  return ret;
}

Node* BatchMatmulV2(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMulV2")
                  .Input(in0)
                  .Input(in1)
                  .Attr("adj_x", adj_x)
                  .Attr("adj_y", adj_y)
                  .Finalize(g, &ret));
  return ret;
}

template <typename T>
static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a,
                          bool adjoint_b, DataType type) {
  Graph* g = new Graph(OpRegistry::Global());
  Tensor in0(type, adjoint_a ? TensorShape({b, k, m}) : TensorShape({b, m, k}));
  in0.flat<T>().setRandom();
  Tensor in1(type, adjoint_b ? TensorShape({b, n, k}) : TensorShape({b, k, n}));
  in1.flat<T>().setRandom();
  test::graph::BatchMatmul(g, test::graph::Constant(g, in0),
                           test::graph::Constant(g, in1), adjoint_a, adjoint_b);
  return g;
}

template <typename T>
static Graph* BatchMatmulWithBroadcast(int b0, int b1, int m, int k, int n,
                                       bool manual_broadcast, DataType type) {
  Graph* g = new Graph(OpRegistry::Global());
  Tensor in0(type, TensorShape({b0, m, k}));
  in0.flat<T>().setRandom();
  Tensor in1(type, TensorShape({b1, k, n}));
  in1.flat<T>().setRandom();

  Tensor broadcasted_in0_shape(DT_INT64, TensorShape({3}));
  Tensor broadcasted_in1_shape(DT_INT64, TensorShape({3}));

  Node* in0_node = nullptr;
  Node* in1_node = nullptr;
  if (manual_broadcast) {
    for (int i = 0; i < 3; ++i) {
      auto vec0 = broadcasted_in0_shape.vec<int64>();
      auto vec1 = broadcasted_in1_shape.vec<int64>();
      vec0(i) = (i == 0 ? std::max(b0, b1) : in0.shape().dim_size(i));
      vec1(i) = (i == 0 ? std::max(b0, b1) : in1.shape().dim_size(i));
    }
    in0_node = BroadcastTo(g, test::graph::Constant(g, in0),
                           test::graph::Constant(g, broadcasted_in0_shape));
    in1_node = BroadcastTo(g, test::graph::Constant(g, in1),
                           test::graph::Constant(g, broadcasted_in1_shape));
  } else {
    in0_node = test::graph::Constant(g, in0);
    in1_node = test::graph::Constant(g, in1);
  }

  BatchMatmulV2(g, in0_node, in1_node, false, false);
  return g;
}

#define BM_BatchMatmulDev(B, M, K, N, TA, TB, T, TFTYPE, DEVICE)                  \
  static void                                                                     \
      BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE( \
          int iters) {                                                            \
    testing::UseRealTime();                                                       \
    testing::ItemsProcessed(static_cast<int64>(iters) * B * M * K * N * 2);       \
    test::Benchmark(#DEVICE, BatchMatmul<T>(B, M, K, N, TA, TB, TFTYPE))          \
        .Run(iters);                                                              \
  }                                                                               \
  BENCHMARK(                                                                      \
      BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE);

#define BM_BatchMatmul(B, M, K, N, TA, TB) \
  BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, cpu);
// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
// cpu);
//  BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, gpu);
/* Uncomment to enable benchmarks for double & complex types: */
// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
// gpu);
// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, cpu); \
// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, cpu);
// \
// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu); \
// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, gpu);

// Macro arguments names: --------------------------------------------------- //
//   B1: batch size of LHS
//   B2: batch size of RHS
//    M: outer dimension of LHS
//    K: inner dimensions of LHS and RHS
//    N: outer dimension of RHS
//   MB: boolean indicating whether to use manual broadcasting
//    T: C++ type of scalars (e.g. float, std::complex)
//   TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128
//    D: Device (e.g. cpu, gpu)
#define BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, T, TT, D)                  \
  static void                                                                  \
      BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D( \
          int iters) {                                                         \
    testing::UseRealTime();                                                    \
    testing::ItemsProcessed(static_cast<int64>(iters) * std::max(B1, B2) * M * \
                            K * N * 2);                                        \
    test::Benchmark(#D, BatchMatmulWithBroadcast<T>(B1, B2, M, K, N, MB, TT))  \
        .Run(iters);                                                           \
  }                                                                            \
  BENCHMARK(                                                                   \
      BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D);

#define BM_BatchMatmulBCast(B1, B2, M, K, N, MB) \
  BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, float, DT_FLOAT, cpu);

// Typical fully connected layers
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, false);
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, false);

// Square matmul.
BM_BatchMatmulBCast(1, 128, 512, 512, 512, true);
BM_BatchMatmulBCast(1, 128, 512, 512, 512, false);
BM_BatchMatmulBCast(128, 1, 512, 512, 512, true);
BM_BatchMatmulBCast(128, 1, 512, 512, 512, false);
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, false);

// Matrix-vector multiplies.
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, true);
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, false);
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, true);
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, false);

// Vector-matrix multiplies.
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, true);
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, false);
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, true);
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, false);

// Typical fully connected layers
BM_BatchMatmul(1, 1, 1024, 1024, false, false);
BM_BatchMatmul(1, 8, 1024, 1024, false, false);
BM_BatchMatmul(1, 16, 1024, 1024, false, false);
BM_BatchMatmul(1, 128, 1024, 1024, false, false);
BM_BatchMatmul(2, 1, 1024, 1024, false, false);
BM_BatchMatmul(2, 8, 1024, 1024, false, false);
BM_BatchMatmul(2, 16, 1024, 1024, false, false);
BM_BatchMatmul(2, 128, 1024, 1024, false, false);
BM_BatchMatmul(8, 1, 1024, 1024, false, false);
BM_BatchMatmul(8, 8, 1024, 1024, false, false);
BM_BatchMatmul(8, 16, 1024, 1024, false, false);
BM_BatchMatmul(8, 128, 1024, 1024, false, false);
BM_BatchMatmul(32, 1, 1024, 1024, false, false);
BM_BatchMatmul(32, 8, 1024, 1024, false, false);
BM_BatchMatmul(32, 16, 1024, 1024, false, false);
BM_BatchMatmul(32, 128, 1024, 1024, false, false);

// Square matmul.
BM_BatchMatmul(1, 32, 32, 32, false, false);
BM_BatchMatmul(1, 128, 128, 128, false, false);
BM_BatchMatmul(1, 256, 256, 256, false, false);
BM_BatchMatmul(1, 1024, 1024, 1024, false, false);
BM_BatchMatmul(1, 2048, 2048, 2048, false, false);
BM_BatchMatmul(2, 32, 32, 32, false, false);
BM_BatchMatmul(2, 128, 128, 128, false, false);
BM_BatchMatmul(2, 256, 256, 256, false, false);
BM_BatchMatmul(2, 1024, 1024, 1024, false, false);
BM_BatchMatmul(2, 2048, 2048, 2048, false, false);
BM_BatchMatmul(4, 32, 32, 32, false, false);
BM_BatchMatmul(4, 128, 128, 128, false, false);
BM_BatchMatmul(4, 256, 256, 256, false, false);
BM_BatchMatmul(4, 1024, 1024, 1024, false, false);
BM_BatchMatmul(4, 2048, 2048, 2048, false, false);
BM_BatchMatmul(8, 32, 32, 32, false, false);
BM_BatchMatmul(8, 128, 128, 128, false, false);
BM_BatchMatmul(8, 256, 256, 256, false, false);
BM_BatchMatmul(8, 1024, 1024, 1024, false, false);
BM_BatchMatmul(8, 2048, 2048, 2048, false, false);
BM_BatchMatmul(32, 32, 32, 32, false, false);
BM_BatchMatmul(32, 128, 128, 128, false, false);
BM_BatchMatmul(32, 256, 256, 256, false, false);
BM_BatchMatmul(32, 1024, 1024, 1024, false, false);
BM_BatchMatmul(32, 2048, 2048, 2048, false, false);

// Matrix-vector multiplies.
BM_BatchMatmul(1, 10000, 200, 1, false, false);
BM_BatchMatmul(8, 10000, 200, 1, false, false);
BM_BatchMatmul(32, 10000, 200, 1, false, false);
BM_BatchMatmul(1, 10000, 200, 1, true, false);
BM_BatchMatmul(8, 10000, 200, 1, true, false);
BM_BatchMatmul(32, 10000, 200, 1, true, false);
BM_BatchMatmul(1, 10000, 200, 1, false, true);
BM_BatchMatmul(8, 10000, 200, 1, false, true);
BM_BatchMatmul(32, 10000, 200, 1, false, true);
BM_BatchMatmul(1, 10000, 200, 1, true, true);
BM_BatchMatmul(8, 10000, 200, 1, true, true);
BM_BatchMatmul(32, 10000, 200, 1, true, true);

// Vector-matrix multiplies.
BM_BatchMatmul(1, 1, 200, 10000, false, false);
BM_BatchMatmul(8, 1, 200, 10000, false, false);
BM_BatchMatmul(32, 1, 200, 10000, false, false);
BM_BatchMatmul(1, 1, 200, 10000, true, false);
BM_BatchMatmul(8, 1, 200, 10000, true, false);
BM_BatchMatmul(32, 1, 200, 10000, true, false);
BM_BatchMatmul(1, 1, 200, 10000, false, true);
BM_BatchMatmul(8, 1, 200, 10000, false, true);
BM_BatchMatmul(32, 1, 200, 10000, false, true);
BM_BatchMatmul(1, 1, 200, 10000, true, true);
BM_BatchMatmul(8, 1, 200, 10000, true, true);
BM_BatchMatmul(32, 1, 200, 10000, true, true);

}  // namespace
}  // namespace tensorflow

@ -33,8 +33,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"
#include "tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

@ -1,5 +1,5 @@
func @Isinf_elem_type(%arg0: tensor<*xelem_type>)
    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
    -> tensor<*xi1> attributes {tf_entry, llvm.emit_c_interface} {
  %0 = "tf.IsInf"(%arg0) : (tensor<*xelem_type>) -> tensor<*xi1>
  return %0 : tensor<*xi1>
}

@ -1,5 +1,5 @@
func @Isnan_elem_type(%arg0: tensor<*xelem_type>)
    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
    -> tensor<*xi1> attributes {tf_entry, llvm.emit_c_interface} {
  %0 = "tf.IsNan"(%arg0) : (tensor<*xelem_type>) -> tensor<*xi1>
  return %0 : tensor<*xi1>
}

@ -40,11 +40,15 @@ static Graph* Multinomial(int batch_size, int num_classes, int num_samples) {
  return g;
}

#define BM_MultinomialDev(DEVICE, B, C, S)                           \
  static void BM_Multinomial_##DEVICE##_##B##_##C##_##S(int iters) { \
    test::Benchmark(#DEVICE, Multinomial(B, C, S)).Run(iters);       \
    testing::ItemsProcessed(static_cast<int64>(B) * C * S * iters);  \
  }                                                                  \
#define BM_MultinomialDev(DEVICE, B, C, S)                  \
  static void BM_Multinomial_##DEVICE##_##B##_##C##_##S(    \
      ::testing::benchmark::State& state) {                 \
    test::Benchmark(#DEVICE, Multinomial(B, C, S),          \
                    /*old_benchmark_api*/ false)            \
        .Run(state);                                        \
    state.SetItemsProcessed(static_cast<int64>(B) * C * S * \
                            state.iterations());            \
  }                                                         \
  BENCHMARK(BM_Multinomial_##DEVICE##_##B##_##C##_##S);

#define BM_MultinomialBCS(B, C, S) \

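The benchmark hunks in this commit all apply the same mechanical migration: a
benchmark that used to receive an iteration count (int iters) now receives a
::testing::benchmark::State&, item counts are derived from state.iterations(),
and real-time measurement moves from a call inside the body to ->UseRealTime()
on the BENCHMARK registration. A minimal sketch of the before/after shape,
using hypothetical names (BM_Example, MakeGraph, kItemsPerRun) rather than
anything taken from this diff:

  // Old API: the framework passes the iteration count explicitly.
  static void BM_Example(int iters) {
    test::Benchmark("cpu", MakeGraph()).Run(iters);
    testing::ItemsProcessed(static_cast<int64>(iters) * kItemsPerRun);
  }
  BENCHMARK(BM_Example);

  // New API: the framework passes a State object; the old benchmark API is
  // explicitly opted out of, and items processed are computed from
  // state.iterations() after the run.
  static void BM_Example(::testing::benchmark::State& state) {
    test::Benchmark("cpu", MakeGraph(), /*old_benchmark_api=*/false)
        .Run(state);
    state.SetItemsProcessed(static_cast<int64>(state.iterations()) *
                            kItemsPerRun);
  }
  BENCHMARK(BM_Example)->UseRealTime();
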
@ -103,18 +103,18 @@ enum CONV_OP {

}  // namespace

static void BM_ConvFloat(int iters, int batch, int rows, int cols, int in_depth,
                         int out_depth, int filter_rows, int filter_cols,
                         CONV_OP op, int num_threads, int stride,
                         Padding padding, bool use_gpu, DataType data_type,
static void BM_ConvFloat(::testing::benchmark::State& state, int batch,
                         int rows, int cols, int in_depth, int out_depth,
                         int filter_rows, int filter_cols, CONV_OP op,
                         int num_threads, int stride, Padding padding,
                         bool use_gpu, DataType data_type,
                         const string& label) {
  testing::StopTiming();
  if (!IsGoogleCudaEnabled() && use_gpu) {
    testing::SetLabel(
    state.SetLabel(
        strings::StrCat("Skipping GPU test (no --config=cuda): ", label));
    return;
  }
  testing::SetLabel(label);
  state.SetLabel(label);

  // Set the number of threads
  SessionOptions options;
@ -221,10 +221,10 @@ static void BM_ConvFloat(int iters, int batch, int rows, int cols, int in_depth,
  TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph, g));

  string device = use_gpu ? "gpu" : "cpu";
  testing::UseRealTime();
  testing::StartTiming();
  test::Benchmark(device, g, &options).Run(iters);
  testing::ItemsProcessed(num_ops * iters);
  test::Benchmark(device, g, &options, nullptr, nullptr, "",
                  /*old_benchmark_api*/ false)
      .Run(state);
  state.SetItemsProcessed(num_ops * state.iterations());
}

// BS: batch_size
@ -235,48 +235,52 @@ static void BM_ConvFloat(int iters, int batch, int rows, int cols, int in_depth,
// KR: kernel_rows
// KC: kernel_cols
#define BM_ConvFloatFwd(BS, R, C, ID, OD, KR, KC, STR, PAD, LABEL)             \
  static void BM_ConvFloatFwdCPU1_##LABEL(int iters) {                         \
    BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR,     \
  static void BM_ConvFloatFwdCPU1_##LABEL(                                     \
      ::testing::benchmark::State& state) {                                    \
    BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR,     \
                 PAD, false, DT_FLOAT,                                         \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",    \
                                 KR, "_", KC, "_", STR, "_", PAD, "_f_cpu1")); \
  }                                                                            \
  static void BM_ConvFloatFwdCPU4_##LABEL(int iters) {                         \
    BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 4, STR,     \
  static void BM_ConvFloatFwdCPU4_##LABEL(                                     \
      ::testing::benchmark::State& state) {                                    \
    BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 4, STR,     \
                 PAD, false, DT_FLOAT,                                         \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",    \
                                 KR, "_", KC, "_", STR, "_", PAD, "_f_cpu4")); \
  }                                                                            \
  static void BM_ConvFloatFusedCPU1_##LABEL(int iters) {                       \
    BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FUSED, 1, STR, PAD,  \
  static void BM_ConvFloatFusedCPU1_##LABEL(                                   \
      ::testing::benchmark::State& state) {                                    \
    BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FUSED, 1, STR, PAD,  \
                 false, DT_FLOAT,                                              \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",    \
                                 KR, "_", KC, "_", STR, "_", PAD, "_f_cpu1")); \
  }                                                                            \
  static void BM_ConvFloatFusedCPU4_##LABEL(int iters) {                       \
    BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FUSED, 4, STR, PAD,  \
  static void BM_ConvFloatFusedCPU4_##LABEL(                                   \
      ::testing::benchmark::State& state) {                                    \
    BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FUSED, 4, STR, PAD,  \
                 false, DT_FLOAT,                                              \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",    \
                                 KR, "_", KC, "_", STR, "_", PAD, "_f_cpu4")); \
  }                                                                            \
  static void BM_ConvFloatFwdGPU_##LABEL(int iters) {                          \
    BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR,     \
  static void BM_ConvFloatFwdGPU_##LABEL(::testing::benchmark::State& state) { \
    BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR,     \
                 PAD, true, DT_FLOAT,                                          \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",    \
                                 KR, "_", KC, "_", STR, "_", PAD, "_f_gpu"));  \
  }                                                                            \
  static void BM_ConvHalfFwdGPU_##LABEL(int iters) {                           \
    BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR,     \
  static void BM_ConvHalfFwdGPU_##LABEL(::testing::benchmark::State& state) {  \
    BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR,     \
                 PAD, true, DT_HALF,                                           \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",    \
                                 KR, "_", KC, "_", STR, "_", PAD, "_h_gpu"));  \
  }                                                                            \
  BENCHMARK(BM_ConvFloatFwdCPU1_##LABEL);                                      \
  BENCHMARK(BM_ConvFloatFwdCPU4_##LABEL);                                      \
  BENCHMARK(BM_ConvFloatFusedCPU1_##LABEL);                                    \
  BENCHMARK(BM_ConvFloatFusedCPU4_##LABEL);                                    \
  BENCHMARK(BM_ConvFloatFwdGPU_##LABEL);                                       \
  BENCHMARK(BM_ConvHalfFwdGPU_##LABEL)
  BENCHMARK(BM_ConvFloatFwdCPU1_##LABEL)->UseRealTime();                       \
  BENCHMARK(BM_ConvFloatFwdCPU4_##LABEL)->UseRealTime();                       \
  BENCHMARK(BM_ConvFloatFusedCPU1_##LABEL)->UseRealTime();                     \
  BENCHMARK(BM_ConvFloatFusedCPU4_##LABEL)->UseRealTime();                     \
  BENCHMARK(BM_ConvFloatFwdGPU_##LABEL)->UseRealTime();                        \
  BENCHMARK(BM_ConvHalfFwdGPU_##LABEL)->UseRealTime()

BM_ConvFloatFwd(32, 5, 5, 1248, 128, 1, 1, 1, SAME, conv0);
BM_ConvFloatFwd(32, 8, 8, 384, 384, 1, 3, 1, SAME, conv1);
@ -334,63 +338,70 @@ BM_ConvFloatFwd(32, 73, 73, 64, 192, 3, 3, 1, VALID, conv52);
BM_ConvFloatFwd(32, 73, 73, 64, 64, 1, 1, 1, VALID, conv53);
BM_ConvFloatFwd(32, 147, 147, 24, 64, 1, 1, 1, VALID, conv54);

#define BM_ConvFloatBkInAndFilter(BS, R, C, ID, OD, KR, KC, STR, PAD, LABEL)  \
  static void BM_ConvFloatBkInCPU1_##LABEL(int iters) {                       \
    BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1,  \
                 STR, PAD, false, DT_FLOAT,                                   \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",   \
                                 KR, "_", KC, "_", STR, "_", PAD, "_cpu1"));  \
  }                                                                           \
  static void BM_ConvFloatBkInCPU4_##LABEL(int iters) {                       \
    BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 4,  \
                 STR, PAD, false, DT_FLOAT,                                   \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",   \
                                 KR, "_", KC, "_", STR, "_", PAD, "_cpu4"));  \
  }                                                                           \
  static void BM_ConvFloatBkInGPU_##LABEL(int iters) {                        \
    BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1,  \
                 STR, PAD, true, DT_FLOAT,                                    \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",   \
                                 KR, "_", KC, "_", STR, "_", PAD, "_gpu"));   \
  }                                                                           \
  static void BM_ConvFloatBkFilterCPU1_##LABEL(int iters) {                   \
    BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
                 STR, PAD, false, DT_FLOAT,                                   \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",   \
                                 KR, "_", KC, "_", STR, "_", PAD, "_cpu1"));  \
  }                                                                           \
  static void BM_ConvFloatBkFilterCPU4_##LABEL(int iters) {                   \
    BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 4, \
                 STR, PAD, false, DT_FLOAT,                                   \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",   \
                                 KR, "_", KC, "_", STR, "_", PAD, "_cpu4"));  \
  }                                                                           \
  static void BM_ConvFloatBkFilterGPU_##LABEL(int iters) {                    \
    BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
                 STR, PAD, true, DT_FLOAT,                                    \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",   \
                                 KR, "_", KC, "_", STR, "_", PAD, "_gpu"));   \
  }                                                                           \
  static void BM_ConvHalfBkInGPU_##LABEL(int iters) {                         \
    BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1,  \
                 STR, PAD, true, DT_HALF,                                     \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",   \
                                 KR, "_", KC, "_", STR, "_", PAD, "_gpu"));   \
  }                                                                           \
  static void BM_ConvHalfBkFilterGPU_##LABEL(int iters) {                     \
    BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
                 STR, PAD, true, DT_HALF,                                     \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",   \
                                 KR, "_", KC, "_", STR, "_", PAD, "_gpu"));   \
  }                                                                           \
  BENCHMARK(BM_ConvFloatBkInCPU1_##LABEL);                                    \
  BENCHMARK(BM_ConvFloatBkInCPU4_##LABEL);                                    \
  BENCHMARK(BM_ConvFloatBkInGPU_##LABEL);                                     \
  BENCHMARK(BM_ConvFloatBkFilterCPU1_##LABEL);                                \
  BENCHMARK(BM_ConvFloatBkFilterCPU4_##LABEL);                                \
  BENCHMARK(BM_ConvFloatBkFilterGPU_##LABEL);                                 \
  BENCHMARK(BM_ConvHalfBkInGPU_##LABEL);                                      \
  BENCHMARK(BM_ConvHalfBkFilterGPU_##LABEL)
#define BM_ConvFloatBkInAndFilter(BS, R, C, ID, OD, KR, KC, STR, PAD, LABEL)   \
  static void BM_ConvFloatBkInCPU1_##LABEL(                                    \
      ::testing::benchmark::State& state) {                                    \
    BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1,   \
                 STR, PAD, false, DT_FLOAT,                                    \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",    \
                                 KR, "_", KC, "_", STR, "_", PAD, "_cpu1"));   \
  }                                                                            \
  static void BM_ConvFloatBkInCPU4_##LABEL(                                    \
      ::testing::benchmark::State& state) {                                    \
    BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 4,   \
                 STR, PAD, false, DT_FLOAT,                                    \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",    \
                                 KR, "_", KC, "_", STR, "_", PAD, "_cpu4"));   \
  }                                                                            \
  static void BM_ConvFloatBkInGPU_##LABEL(                                     \
      ::testing::benchmark::State& state) {                                    \
    BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1,   \
                 STR, PAD, true, DT_FLOAT,                                     \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",    \
                                 KR, "_", KC, "_", STR, "_", PAD, "_gpu"));    \
  }                                                                            \
  static void BM_ConvFloatBkFilterCPU1_##LABEL(                                \
      ::testing::benchmark::State& state) {                                    \
    BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1,  \
                 STR, PAD, false, DT_FLOAT,                                    \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",    \
                                 KR, "_", KC, "_", STR, "_", PAD, "_cpu1"));   \
  }                                                                            \
  static void BM_ConvFloatBkFilterCPU4_##LABEL(                                \
      ::testing::benchmark::State& state) {                                    \
    BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 4,  \
                 STR, PAD, false, DT_FLOAT,                                    \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",    \
                                 KR, "_", KC, "_", STR, "_", PAD, "_cpu4"));   \
  }                                                                            \
  static void BM_ConvFloatBkFilterGPU_##LABEL(                                 \
      ::testing::benchmark::State& state) {                                    \
    BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1,  \
                 STR, PAD, true, DT_FLOAT,                                     \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",    \
                                 KR, "_", KC, "_", STR, "_", PAD, "_gpu"));    \
  }                                                                            \
  static void BM_ConvHalfBkInGPU_##LABEL(::testing::benchmark::State& state) { \
    BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1,   \
                 STR, PAD, true, DT_HALF,                                      \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",    \
                                 KR, "_", KC, "_", STR, "_", PAD, "_gpu"));    \
  }                                                                            \
  static void BM_ConvHalfBkFilterGPU_##LABEL(                                  \
      ::testing::benchmark::State& state) {                                    \
    BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1,  \
                 STR, PAD, true, DT_HALF,                                      \
                 strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_",    \
                                 KR, "_", KC, "_", STR, "_", PAD, "_gpu"));    \
  }                                                                            \
  BENCHMARK(BM_ConvFloatBkInCPU1_##LABEL)->UseRealTime();                      \
  BENCHMARK(BM_ConvFloatBkInCPU4_##LABEL)->UseRealTime();                      \
  BENCHMARK(BM_ConvFloatBkInGPU_##LABEL)->UseRealTime();                       \
  BENCHMARK(BM_ConvFloatBkFilterCPU1_##LABEL)->UseRealTime();                  \
  BENCHMARK(BM_ConvFloatBkFilterCPU4_##LABEL)->UseRealTime();                  \
  BENCHMARK(BM_ConvFloatBkFilterGPU_##LABEL)->UseRealTime();                   \
  BENCHMARK(BM_ConvHalfBkInGPU_##LABEL)->UseRealTime();                        \
  BENCHMARK(BM_ConvHalfBkFilterGPU_##LABEL)->UseRealTime()

// Benchmarks from the inception model

@ -453,8 +464,8 @@ BM_ConvFloatBkInAndFilter(32, 147, 147, 24, 64, 1, 1, 1, VALID, conv54);
#define BM_ConvFloatBkFCPU(BS, R, C, ID, OD, KR, KC, TH, LABEL)                \
  static void                                                                  \
      BM_ConvFloatBkFCPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC##_##TH(  \
          int iters) {                                                         \
    BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, TH, \
          ::testing::benchmark::State& state) {                                \
    BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, TH, \
                 1, VALID, false, DT_FLOAT, LABEL);                            \
  }                                                                            \
  BENCHMARK(                                                                   \
@ -469,17 +480,19 @@ BM_ConvFloatBkFCPU(128, 13, 13, 384, 384, 3, 3, 4, "convnet-layer5");

#define BM_ConvFloatBkFGPU(BS, R, C, ID, OD, KR, KC, LABEL)                    \
  static void BM_ConvFloatBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC( \
      int iters) {                                                             \
    BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1,  \
      ::testing::benchmark::State& state) {                                    \
    BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1,  \
                 1, VALID, true, DT_FLOAT, LABEL);                             \
  }                                                                            \
  static void BM_ConvHalfBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC(  \
      int iters) {                                                             \
    BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1,  \
      ::testing::benchmark::State& state) {                                    \
    BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1,  \
                 1, VALID, true, DT_HALF, LABEL);                              \
  }                                                                            \
  BENCHMARK(BM_ConvFloatBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC);  \
  BENCHMARK(BM_ConvHalfBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC)
  BENCHMARK(BM_ConvFloatBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC)   \
      ->UseRealTime();                                                         \
  BENCHMARK(BM_ConvHalfBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC)    \
      ->UseRealTime()

// Benchmarks from https://github.com/soumith/convnet-benchmarks
BM_ConvFloatBkFGPU(128, 128, 128, 3, 96, 11, 11, "convnet-layer1");
@ -498,19 +511,19 @@ enum DEPTHWISE_CONV_OP {

}  // namespace

static void BM_ConvFloatDepthwise(int iters, int batch, int rows, int cols,
                                  int in_depth, int depth_multiplier,
                                  int out_depth, int filter_rows,
                                  int filter_cols, DEPTHWISE_CONV_OP op,
                                  int num_threads, int stride, Padding padding,
                                  bool use_gpu, const string& label) {
  testing::StopTiming();
static void BM_ConvFloatDepthwise(::testing::benchmark::State& state, int batch,
                                  int rows, int cols, int in_depth,
                                  int depth_multiplier, int out_depth,
                                  int filter_rows, int filter_cols,
                                  DEPTHWISE_CONV_OP op, int num_threads,
                                  int stride, Padding padding, bool use_gpu,
                                  const string& label) {
  if (!IsGoogleCudaEnabled() && use_gpu) {
    testing::SetLabel(
    state.SetLabel(
        strings::StrCat("Skipping GPU test (no --config=cuda): ", label));
    return;
  }
  testing::SetLabel(label);
  state.SetLabel(label);

  // Set the number of threads
  SessionOptions options;
@ -603,10 +616,10 @@ static void BM_ConvFloatDepthwise(int iters, int batch, int rows, int cols,
  TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph, g));

  string device = use_gpu ? "gpu" : "cpu";
  testing::UseRealTime();
  testing::StartTiming();
  test::Benchmark(device, g, &options).Run(iters);
  testing::ItemsProcessed(num_ops * iters);
  test::Benchmark(device, g, &options, nullptr, nullptr, "",
                  /*old_benchmark_api=*/false)
      .Run(state);
  state.SetItemsProcessed(num_ops * state.iterations());
}

// BS: batch_size
@ -622,30 +635,33 @@ static void BM_ConvFloatDepthwise(int iters, int batch, int rows, int cols,

#define BM_ConvFloatDepthwiseFwd(BS, R, C, ID, DM, OD, KR, KC, STR, PAD,    \
                                 LABEL)                                     \
  static void BM_ConvFloatDepthwiseFwdCPU1_##LABEL(int iters) {             \
  static void BM_ConvFloatDepthwiseFwdCPU1_##LABEL(                         \
      ::testing::benchmark::State& state) {                                 \
    BM_ConvFloatDepthwise(                                                  \
        iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 1, STR, \
        state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 1, STR, \
        PAD, false,                                                         \
        strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
                        KR, "_", KC, "_", STR, "_", PAD, "_cpu1"));         \
  }                                                                         \
  static void BM_ConvFloatDepthwiseFwdCPU4_##LABEL(int iters) {             \
  static void BM_ConvFloatDepthwiseFwdCPU4_##LABEL(                         \
      ::testing::benchmark::State& state) {                                 \
    BM_ConvFloatDepthwise(                                                  \
        iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 4, STR, \
        state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 4, STR, \
        PAD, false,                                                         \
        strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
                        KR, "_", KC, "_", STR, "_", PAD, "_cpu4"));         \
  }                                                                         \
  static void BM_ConvFloatDepthwiseFwdGPU_##LABEL(int iters) {              \
  static void BM_ConvFloatDepthwiseFwdGPU_##LABEL(                          \
      ::testing::benchmark::State& state) {                                 \
    BM_ConvFloatDepthwise(                                                  \
        iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 1, STR, \
        state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 1, STR, \
        PAD, true,                                                          \
        strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
                        KR, "_", KC, "_", STR, "_", PAD, "_gpu"));          \
  }                                                                         \
  BENCHMARK(BM_ConvFloatDepthwiseFwdCPU1_##LABEL);                          \
  BENCHMARK(BM_ConvFloatDepthwiseFwdCPU4_##LABEL);                          \
  BENCHMARK(BM_ConvFloatDepthwiseFwdGPU_##LABEL);
  BENCHMARK(BM_ConvFloatDepthwiseFwdCPU1_##LABEL)->UseRealTime();           \
  BENCHMARK(BM_ConvFloatDepthwiseFwdCPU4_##LABEL)->UseRealTime();           \
  BENCHMARK(BM_ConvFloatDepthwiseFwdGPU_##LABEL)->UseRealTime();

// The configurations below are mostly from mobilenet models.
BM_ConvFloatDepthwiseFwd(32, 112, 112, 3, 8, 24, 3, 3, 1, SAME, conv0);
@ -662,53 +678,59 @@ BM_ConvFloatDepthwiseFwd(1, 100, 100, 72, 1, 72, 3, 3, 1, SAME, conv9);
BM_ConvFloatDepthwiseFwd(1, 100, 100, 72, 1, 72, 5, 5, 1, SAME, conv10);

#define BM_ConvFloatDepthwiseBk(BS, R, C, ID, DM, OD, KR, KC, STR, PAD, LABEL) \
  static void BM_ConvFloatDepthwiseBkInCPU1_##LABEL(int iters) {               \
  static void BM_ConvFloatDepthwiseBkInCPU1_##LABEL(                           \
      ::testing::benchmark::State& state) {                                    \
    BM_ConvFloatDepthwise(                                                     \
        iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \
        state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \
        1, STR, PAD, false,                                                    \
        strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_",    \
                        KR, "_", KC, "_", STR, "_", PAD, "_cpu1"));            \
  }                                                                            \
  static void BM_ConvFloatDepthwiseBkInCPU4_##LABEL(int iters) {               \
  static void BM_ConvFloatDepthwiseBkInCPU4_##LABEL(                           \
      ::testing::benchmark::State& state) {                                    \
    BM_ConvFloatDepthwise(                                                     \
        iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \
        state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \
        4, STR, PAD, false,                                                    \
        strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_",    \
                        KR, "_", KC, "_", STR, "_", PAD, "_cpu4"));            \
  }                                                                            \
  static void BM_ConvFloatDepthwiseBkInGPU_##LABEL(int iters) {                \
  static void BM_ConvFloatDepthwiseBkInGPU_##LABEL(                            \
      ::testing::benchmark::State& state) {                                    \
    BM_ConvFloatDepthwise(                                                     \
        iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \
        state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \
        4, STR, PAD, true,                                                     \
        strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_",    \
                        KR, "_", KC, "_", STR, "_", PAD, "_gpu"));             \
  }                                                                            \
  static void BM_ConvFloatDepthwiseBkFilterCPU1_##LABEL(int iters) {           \
  static void BM_ConvFloatDepthwiseBkFilterCPU1_##LABEL(                       \
      ::testing::benchmark::State& state) {                                    \
    BM_ConvFloatDepthwise(                                                     \
        iters, BS, R, C, ID, DM, OD, KR, KC,                                   \
        state, BS, R, C, ID, DM, OD, KR, KC,                                   \
        DEPTHWISE_CONV_OP_BACKPROP_FILTER, 1, STR, PAD, false,                 \
        strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_",    \
                        KR, "_", KC, "_", STR, "_", PAD, "_cpu1"));            \
  }                                                                            \
  static void BM_ConvFloatDepthwiseBkFilterCPU4_##LABEL(int iters) {           \
  static void BM_ConvFloatDepthwiseBkFilterCPU4_##LABEL(                       \
      ::testing::benchmark::State& state) {                                    \
    BM_ConvFloatDepthwise(                                                     \
        iters, BS, R, C, ID, DM, OD, KR, KC,                                   \
        state, BS, R, C, ID, DM, OD, KR, KC,                                   \
        DEPTHWISE_CONV_OP_BACKPROP_FILTER, 4, STR, PAD, false,                 \
        strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_",    \
                        KR, "_", KC, "_", STR, "_", PAD, "_cpu4"));            \
  }                                                                            \
  static void BM_ConvFloatDepthwiseBkFilterGPU_##LABEL(int iters) {            \
  static void BM_ConvFloatDepthwiseBkFilterGPU_##LABEL(                        \
      ::testing::benchmark::State& state) {                                    \
    BM_ConvFloatDepthwise(                                                     \
        iters, BS, R, C, ID, DM, OD, KR, KC,                                   \
        state, BS, R, C, ID, DM, OD, KR, KC,                                   \
        DEPTHWISE_CONV_OP_BACKPROP_FILTER, 4, STR, PAD, true,                  \
        strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_",    \
                        KR, "_", KC, "_", STR, "_", PAD, "_gpu"));             \
  }                                                                            \
  BENCHMARK(BM_ConvFloatDepthwiseBkInCPU1_##LABEL);                            \
  BENCHMARK(BM_ConvFloatDepthwiseBkInCPU4_##LABEL);                            \
  BENCHMARK(BM_ConvFloatDepthwiseBkInGPU_##LABEL);                             \
  BENCHMARK(BM_ConvFloatDepthwiseBkFilterCPU1_##LABEL);                        \
  BENCHMARK(BM_ConvFloatDepthwiseBkFilterCPU4_##LABEL);                        \
  BENCHMARK(BM_ConvFloatDepthwiseBkInCPU1_##LABEL)->UseRealTime();             \
  BENCHMARK(BM_ConvFloatDepthwiseBkInCPU4_##LABEL)->UseRealTime();             \
  BENCHMARK(BM_ConvFloatDepthwiseBkInGPU_##LABEL)->UseRealTime();              \
  BENCHMARK(BM_ConvFloatDepthwiseBkFilterCPU1_##LABEL)->UseRealTime();         \
  BENCHMARK(BM_ConvFloatDepthwiseBkFilterCPU4_##LABEL)->UseRealTime();         \
  BENCHMARK(BM_ConvFloatDepthwiseBkFilterGPU_##LABEL)

// The configurations below are mostly from mobilenet models.
@ -732,10 +754,9 @@ BM_ConvFloatDepthwiseBk(32, 112, 112, 8, 3, 24, 3, 3, 1, SAME, conv12);
BM_ConvFloatDepthwiseBk(32, 112, 112, 12, 2, 24, 3, 3, 1, SAME, conv13);
BM_ConvFloatDepthwiseBk(32, 112, 112, 24, 1, 24, 3, 3, 1, SAME, conv14);

static void BM_LRNFloat(int iters, int depth, int cols, int rows,
                        int batch_size, int range, int num_threads,
static void BM_LRNFloat(::testing::benchmark::State& state, int depth, int cols,
                        int rows, int batch_size, int range, int num_threads,
                        const string& label) {
  tensorflow::testing::StopTiming();
  std::unique_ptr<Device> device(
      DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));

@ -778,26 +799,24 @@ static void BM_LRNFloat(int iters, int depth, int cols, int rows,
  std::unique_ptr<OpKernelContext> context(new OpKernelContext(&params));

  op->Compute(context.get());
  testing::UseRealTime();
  tensorflow::testing::StartTiming();
  for (int i = 0; i < iters; ++i) {
  for (auto s : state) {
    delete context->release_output(0).tensor;
    op->Compute(context.get());
  }
  tensorflow::testing::StopTiming();
  testing::ItemsProcessed(context->mutable_output(0)->NumElements() * iters *
                          (2 * range + 1) * 2);
  testing::SetLabel(label);
  state.SetItemsProcessed(context->mutable_output(0)->NumElements() *
                          state.iterations() * (2 * range + 1) * 2);
  state.SetLabel(label);
}

#define BM_LRNFloatFwdCPU(DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL)   \
  static void                                                                \
      BM_LRNFloat_##DEPTH##_##COLS##_##ROWS##_##BATCH##_##RANGE##_##THREADS( \
          int iters) {                                                       \
    BM_LRNFloat(iters, DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL);     \
          ::testing::benchmark::State& state) {                              \
    BM_LRNFloat(state, DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL);     \
  }                                                                          \
  BENCHMARK(                                                                 \
      BM_LRNFloat_##DEPTH##_##COLS##_##ROWS##_##BATCH##_##RANGE##_##THREADS)
      BM_LRNFloat_##DEPTH##_##COLS##_##ROWS##_##BATCH##_##RANGE##_##THREADS) \
      ->UseRealTime()

// clang-format off
//                DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL
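
The kernel-level benchmarks above and below (BM_LRNFloat, BM_AvgPool,
BM_MaxPool) follow a second variant of the same migration: the explicit
StartTiming/StopTiming calls and the manual for (int i = 0; i < iters; ++i)
loop are replaced by ranging over the state, which starts and stops the timer
itself. An illustrative sketch with hypothetical names (BM_ExampleKernel, ctx,
op, kLabel), not code taken from this diff:

  static void BM_ExampleKernel(::testing::benchmark::State& state) {
    // ... construct the OpKernel op and OpKernelContext ctx as before ...
    op->Compute(ctx.get());  // warm-up run outside the timed region
    for (auto s : state) {   // only work inside this loop is timed
      delete ctx->release_output(0).tensor;  // drop the previous output
      op->Compute(ctx.get());
    }
    state.SetItemsProcessed(ctx->mutable_output(0)->NumElements() *
                            state.iterations());
    state.SetLabel(kLabel);
  }
  BENCHMARK(BM_ExampleKernel)->UseRealTime();
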
@ -815,10 +834,10 @@ BM_LRNFloatFwdCPU(192,   56,   56,   32,    5,     8,       "lrn 8 threads");
 | 
			
		||||
/*
 | 
			
		||||
AvgPooling Op
 | 
			
		||||
*/
 | 
			
		||||
static void BM_AvgPool(int iters, int batch_size, int rows, int cols, int depth,
 | 
			
		||||
                       int kernel_rows, int kernel_cols, int stride,
 | 
			
		||||
                       Padding padding, int num_threads, const string& label) {
 | 
			
		||||
  tensorflow::testing::StopTiming();
 | 
			
		||||
static void BM_AvgPool(::testing::benchmark::State& state, int batch_size,
 | 
			
		||||
                       int rows, int cols, int depth, int kernel_rows,
 | 
			
		||||
                       int kernel_cols, int stride, Padding padding,
 | 
			
		||||
                       int num_threads, const string& label) {
 | 
			
		||||
  std::unique_ptr<Device> device(
 | 
			
		||||
      DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
 | 
			
		||||
 | 
			
		||||
@ -860,16 +879,13 @@ static void BM_AvgPool(int iters, int batch_size, int rows, int cols, int depth,
 | 
			
		||||
      new OpKernelContext(¶ms));
 | 
			
		||||
 | 
			
		||||
  op->Compute(avgpool_context.get());
 | 
			
		||||
  testing::UseRealTime();
 | 
			
		||||
  tensorflow::testing::StartTiming();
 | 
			
		||||
  for (int i = 0; i < iters; ++i) {
 | 
			
		||||
  for (auto s : state) {
 | 
			
		||||
    delete avgpool_context->release_output(0).tensor;
 | 
			
		||||
    op->Compute(avgpool_context.get());
 | 
			
		||||
  }
 | 
			
		||||
  tensorflow::testing::StopTiming();
 | 
			
		||||
  testing::ItemsProcessed(avgpool_context->mutable_output(0)->NumElements() *
 | 
			
		||||
                          iters);
 | 
			
		||||
  testing::SetLabel(label);
 | 
			
		||||
  state.SetItemsProcessed(avgpool_context->mutable_output(0)->NumElements() *
 | 
			
		||||
                          state.iterations());
 | 
			
		||||
  state.SetLabel(label);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// BS: batch_size
 | 
			
		||||
@ -883,11 +899,12 @@ static void BM_AvgPool(int iters, int batch_size, int rows, int cols, int depth,
 | 
			
		||||
#define BM_AvgPoolFwdCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL)            \
 | 
			
		||||
  static void                                                                  \
 | 
			
		||||
      BM_AvgPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH( \
 | 
			
		||||
          int iters) {                                                         \
 | 
			
		||||
    BM_AvgPool(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL);              \
 | 
			
		||||
          ::testing::benchmark::State& state) {                                \
 | 
			
		||||
    BM_AvgPool(state, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL);              \
 | 
			
		||||
  }                                                                            \
 | 
			
		||||
  BENCHMARK(                                                                   \
 | 
			
		||||
      BM_AvgPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH)
 | 
			
		||||
      BM_AvgPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH) \
 | 
			
		||||
      ->UseRealTime()
 | 
			
		||||
 | 
			
		||||
// Labels are taken from the 2014-July-24 version of imagenet
 | 
			
		||||
BM_AvgPoolFwdCPU(32, 112, 112, 64, 3, 3, 2, VALID, 1, "avgpool0_VALID");
 | 
			
		||||
@ -907,11 +924,10 @@ BM_AvgPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, SAME, 4, "avgpool1_SAME");
 | 
			
		||||
BM_AvgPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, SAME, 4, "avgpool4_SAME");
 | 
			
		||||
BM_AvgPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, SAME, 4, "avgpool10_SAME");
 | 
			
		||||
 | 
			
		||||
static void BM_AvgPoolBk(int iters, int batch_size, int rows, int cols,
 | 
			
		||||
                         int depth, int kernel_rows, int kernel_cols,
 | 
			
		||||
                         int stride, Padding padding, int num_threads,
 | 
			
		||||
                         const string& label) {
 | 
			
		||||
  tensorflow::testing::StopTiming();
 | 
			
		||||
static void BM_AvgPoolBk(::testing::benchmark::State& state, int batch_size,
 | 
			
		||||
                         int rows, int cols, int depth, int kernel_rows,
 | 
			
		||||
                         int kernel_cols, int stride, Padding padding,
 | 
			
		||||
                         int num_threads, const string& label) {
 | 
			
		||||
  std::unique_ptr<Device> device(
 | 
			
		||||
      DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
 | 
			
		||||
 | 
			
		||||
@ -966,16 +982,13 @@ static void BM_AvgPoolBk(int iters, int batch_size, int rows, int cols,
 | 
			
		||||
      new OpKernelContext(¶ms));
 | 
			
		||||
 | 
			
		||||
  op->Compute(avgpool_context.get());
 | 
			
		||||
  testing::UseRealTime();
 | 
			
		||||
  tensorflow::testing::StartTiming();
 | 
			
		||||
  for (int i = 0; i < iters; ++i) {
 | 
			
		||||
  for (auto s : state) {
 | 
			
		||||
    delete avgpool_context->release_output(0).tensor;
 | 
			
		||||
    op->Compute(avgpool_context.get());
 | 
			
		||||
  }
 | 
			
		||||
  tensorflow::testing::StopTiming();
 | 
			
		||||
  testing::ItemsProcessed(avgpool_context->mutable_output(0)->NumElements() *
 | 
			
		||||
                          iters);
 | 
			
		||||
  testing::SetLabel(label);
 | 
			
		||||
  state.SetItemsProcessed(avgpool_context->mutable_output(0)->NumElements() *
 | 
			
		||||
                          state.iterations());
 | 
			
		||||
  state.SetLabel(label);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// BS: batch_size
 | 
			
		||||
@ -987,14 +1000,17 @@ static void BM_AvgPoolBk(int iters, int batch_size, int rows, int cols,
 | 
			
		||||
// ST: stride. We use the same stride for both directions.
 | 
			
		||||
// PT: padding
 | 
			
		||||
// The resulted symbol is too long. Need to use two macros to fit in 80-chars
 | 
			
		||||
// NOLINTBEGIN
 | 
			
		||||
#define BM_AvgPoolBkCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL)               \
 | 
			
		||||
  static void                                                                    \
 | 
			
		||||
      BM_AvgPoolBk_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH( \
 | 
			
		||||
          int iters) {                                                           \
 | 
			
		||||
    BM_AvgPoolBk(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL);              \
 | 
			
		||||
          ::testing::benchmark::State& state) {                                  \
 | 
			
		||||
    BM_AvgPoolBk(state, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL);              \
 | 
			
		||||
  }                                                                              \
 | 
			
		||||
  BENCHMARK(                                                                     \
 | 
			
		||||
      BM_AvgPoolBk_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH)
 | 
			
		||||
      BM_AvgPoolBk_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH) \
 | 
			
		||||
      ->UseRealTime()
 | 
			
		||||
// NOLINTEND
 | 
			
		||||
 | 
			
		||||
// Shapes taken from the 2015/05/16 inception model
 | 
			
		||||
BM_AvgPoolBkCPU(32, 35, 35, 192, 3, 3, 1, SAME, 1, "avgpool_grad0_SAME");
 | 
			
		||||
@ -1010,10 +1026,10 @@ BM_AvgPoolBkCPU(32, 8, 8, 2048, 8, 8, 1, VALID, 1, "avgpool_grad8_VALID");
 | 
			
		||||
/*
 | 
			
		||||
MaxPooling Op
 | 
			
		||||
*/
 | 
			
		||||
static void BM_MaxPool(int iters, int batch_size, int rows, int cols, int depth,
 | 
			
		||||
                       int kernel_rows, int kernel_cols, int stride,
 | 
			
		||||
                       Padding padding, int num_threads, const string& label) {
 | 
			
		||||
  tensorflow::testing::StopTiming();
 | 
			
		||||
static void BM_MaxPool(::testing::benchmark::State& state, int batch_size,
 | 
			
		||||
                       int rows, int cols, int depth, int kernel_rows,
 | 
			
		||||
                       int kernel_cols, int stride, Padding padding,
 | 
			
		||||
                       int num_threads, const string& label) {
 | 
			
		||||
  SessionOptions options;
 | 
			
		||||
  options.config.set_intra_op_parallelism_threads(num_threads);
 | 
			
		||||
 | 
			
		||||
@ -1057,16 +1073,13 @@ static void BM_MaxPool(int iters, int batch_size, int rows, int cols, int depth,
 | 
			
		||||
      new OpKernelContext(¶ms));
 | 
			
		||||
 | 
			
		||||
  op->Compute(maxpool_context.get());
 | 
			
		||||
  testing::UseRealTime();
 | 
			
		||||
  tensorflow::testing::StartTiming();
 | 
			
		||||
  for (int i = 0; i < iters; ++i) {
 | 
			
		||||
  for (auto s : state) {
 | 
			
		||||
    delete maxpool_context->release_output(0).tensor;
 | 
			
		||||
    op->Compute(maxpool_context.get());
 | 
			
		||||
  }
 | 
			
		||||
  tensorflow::testing::StopTiming();
 | 
			
		||||
  testing::ItemsProcessed(maxpool_context->mutable_output(0)->NumElements() *
 | 
			
		||||
                          iters);
 | 
			
		||||
  testing::SetLabel(label);
 | 
			
		||||
  state.SetItemsProcessed(maxpool_context->mutable_output(0)->NumElements() *
 | 
			
		||||
                          state.iterations());
 | 
			
		||||
  state.SetLabel(label);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// BS: batch_size
@ -1080,11 +1093,12 @@ static void BM_MaxPool(int iters, int batch_size, int rows, int cols, int depth,
#define BM_MaxPoolFwdCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL)            \
  static void                                                                  \
      BM_MaxPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH( \
          int iters) {                                                         \
    BM_MaxPool(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL);              \
          ::testing::benchmark::State& state) {                                \
    BM_MaxPool(state, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL);              \
  }                                                                            \
  BENCHMARK(                                                                   \
      BM_MaxPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH)
      BM_MaxPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH) \
      ->UseRealTime()

// Labels are taken from the 2014-July-24 version of imagenet
/* TODO XXX
@ -1106,10 +1120,10 @@ BM_MaxPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, SAME, 4, "maxpool1_SAME");
BM_MaxPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, SAME, 4, "maxpool4_SAME");
BM_MaxPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, SAME, 4, "maxpool10_SAME");

static void BM_MaxPoolBk(int iters, int batch_size, int rows, int cols,
                         int depth, int kernel_rows, int kernel_cols,
                         int stride, Padding padding, int num_threads,
                         bool use_gpu, const string& label) {
static void BM_MaxPoolBk(::testing::benchmark::State& state, int batch_size,
                         int rows, int cols, int depth, int kernel_rows,
                         int kernel_cols, int stride, Padding padding,
                         int num_threads, bool use_gpu, const string& label) {
  auto root = Scope::NewRootScope().ExitOnError();

  int64 out_height, out_width, pad_rows, pad_cols;
@ -1138,11 +1152,11 @@ static void BM_MaxPoolBk(int iters, int batch_size, int rows, int cols,
  Graph* g = new Graph(OpRegistry::Global());
  TF_CHECK_OK(root.ToGraph(g));
  string device = use_gpu ? "gpu" : "cpu";
  testing::UseRealTime();
  test::Benchmark(device, g).Run(iters);
  test::Benchmark(device, g, /*old_benchmark_api*/ false).Run(state);

  testing::ItemsProcessed(batch_size * rows * cols * depth * iters);
  testing::SetLabel(label);
  state.SetItemsProcessed(batch_size * rows * cols * depth *
                          state.iterations());
  state.SetLabel(label);
}

// BS: batch_size
@ -1159,23 +1173,23 @@ static void BM_MaxPoolBk(int iters, int batch_size, int rows, int cols,
  static void                                                                  \
 | 
			
		||||
      BM_MaxPoolBk_GPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_       \
 | 
			
		||||
          ##PT##_##TH(                                                         \
 | 
			
		||||
          int iters) {                                                         \
 | 
			
		||||
    BM_MaxPoolBk(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, true, LABEL);      \
 | 
			
		||||
          ::testing::benchmark::State& state) {                                \
 | 
			
		||||
    BM_MaxPoolBk(state, BS, IR, IC, ND, KR, KC, ST, PT, TH, true, LABEL);      \
 | 
			
		||||
  }                                                                            \
 | 
			
		||||
  BENCHMARK(                                                                   \
 | 
			
		||||
      BM_MaxPoolBk_GPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_       \
 | 
			
		||||
          ##PT##_##TH)                                                         \
 | 
			
		||||
          ##PT##_##TH)->UseRealTime()
 | 
			
		||||
 | 
			
		||||
#define BM_MaxPoolBkCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL)             \
 | 
			
		||||
  static void                                                                  \
 | 
			
		||||
      BM_MaxPoolBk_CPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_       \
 | 
			
		||||
          ##PT##_##TH(                                                         \
 | 
			
		||||
          int iters) {                                                         \
 | 
			
		||||
    BM_MaxPoolBk(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, false, LABEL);     \
 | 
			
		||||
          ::testing::benchmark::State& state) {                                \
 | 
			
		||||
    BM_MaxPoolBk(state, BS, IR, IC, ND, KR, KC, ST, PT, TH, false, LABEL);     \
 | 
			
		||||
  }                                                                            \
 | 
			
		||||
  BENCHMARK(                                                                   \
 | 
			
		||||
      BM_MaxPoolBk_CPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_       \
 | 
			
		||||
          ##PT##_##TH)
 | 
			
		||||
          ##PT##_##TH)->UseRealTime()
 | 
			
		||||
// clang-format on
 | 
			
		||||
 | 
			
		||||
// Shapes taken from the 2015/05/16 inception model
 | 
			
		||||
@ -1195,9 +1209,9 @@ BM_MaxPoolBkCPU(32, 8, 8, 2048, 3, 3, 2, VALID, 1, "maxpool_grad4_VALID");
 | 
			
		||||
Relu Op
 | 
			
		||||
Run benchmark with:
 | 
			
		||||
*/
 | 
			
		||||
static void BM_ReluFloat(int iters, int batch_size, int rows, int cols,
 | 
			
		||||
                         int depth, int num_threads, const string& label) {
 | 
			
		||||
  tensorflow::testing::StopTiming();
 | 
			
		||||
static void BM_ReluFloat(::testing::benchmark::State& state, int batch_size,
 | 
			
		||||
                         int rows, int cols, int depth, int num_threads,
 | 
			
		||||
                         const string& label) {
 | 
			
		||||
  std::unique_ptr<Device> device(
 | 
			
		||||
      DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
 | 
			
		||||
 | 
			
		||||
@ -1233,27 +1247,25 @@ static void BM_ReluFloat(int iters, int batch_size, int rows, int cols,
 | 
			
		||||
  std::unique_ptr<OpKernelContext> relu_context(new OpKernelContext(¶ms));
 | 
			
		||||
 | 
			
		||||
  op->Compute(relu_context.get());
 | 
			
		||||
  testing::UseRealTime();
 | 
			
		||||
  tensorflow::testing::StartTiming();
 | 
			
		||||
  for (int i = 0; i < iters; ++i) {
 | 
			
		||||
  for (auto s : state) {
 | 
			
		||||
    delete relu_context->release_output(0).tensor;
 | 
			
		||||
    op->Compute(relu_context.get());
 | 
			
		||||
  }
 | 
			
		||||
  tensorflow::testing::StopTiming();
 | 
			
		||||
  testing::ItemsProcessed(relu_context->mutable_output(0)->NumElements() *
 | 
			
		||||
                          iters);
 | 
			
		||||
  testing::SetLabel(label);
 | 
			
		||||
  state.SetItemsProcessed(relu_context->mutable_output(0)->NumElements() *
 | 
			
		||||
                          state.iterations());
 | 
			
		||||
  state.SetLabel(label);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// BS: batch_size
 | 
			
		||||
// IR: input_rows
 | 
			
		||||
// IC: input_cols
 | 
			
		||||
// ND: node_depth
 | 
			
		||||
#define BM_Relu(BS, IR, IC, ND, TH, LABEL)                               \
 | 
			
		||||
  static void BM_ReluFloat_##BS##_##IR##_##IC##_##ND##_##TH(int iters) { \
 | 
			
		||||
    BM_ReluFloat(iters, BS, IR, IC, ND, TH, LABEL);                      \
 | 
			
		||||
  }                                                                      \
 | 
			
		||||
  BENCHMARK(BM_ReluFloat_##BS##_##IR##_##IC##_##ND##_##TH)
 | 
			
		||||
#define BM_Relu(BS, IR, IC, ND, TH, LABEL)                   \
 | 
			
		||||
  static void BM_ReluFloat_##BS##_##IR##_##IC##_##ND##_##TH( \
 | 
			
		||||
      ::testing::benchmark::State& state) {                  \
 | 
			
		||||
    BM_ReluFloat(state, BS, IR, IC, ND, TH, LABEL);          \
 | 
			
		||||
  }                                                          \
 | 
			
		||||
  BENCHMARK(BM_ReluFloat_##BS##_##IR##_##IC##_##ND##_##TH)->UseRealTime()
 | 
			
		||||
 | 
			
		||||
BM_Relu(32, 112, 112, 64, 1, "relu0");
 | 
			
		||||
BM_Relu(32, 56, 56, 192, 1, "relu1");
 | 
			
		||||
@ -1268,9 +1280,9 @@ BM_Relu(32, 14, 14, 576, 4, "relu10");
 | 
			
		||||
Softplus Op
 | 
			
		||||
Run benchmark with:
 | 
			
		||||
*/
 | 
			
		||||
static void BM_SoftplusFloat(int iters, int batch_size, int rows, int cols,
 | 
			
		||||
                             int depth, int num_threads, const string& label) {
 | 
			
		||||
  tensorflow::testing::StopTiming();
 | 
			
		||||
static void BM_SoftplusFloat(::testing::benchmark::State& state, int batch_size,
 | 
			
		||||
                             int rows, int cols, int depth, int num_threads,
 | 
			
		||||
                             const string& label) {
 | 
			
		||||
  std::unique_ptr<Device> device(
 | 
			
		||||
      DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
 | 
			
		||||
 | 
			
		||||
@ -1307,27 +1319,25 @@ static void BM_SoftplusFloat(int iters, int batch_size, int rows, int cols,
 | 
			
		||||
      new OpKernelContext(¶ms));
 | 
			
		||||
 | 
			
		||||
  op->Compute(softplus_context.get());
 | 
			
		||||
  testing::UseRealTime();
 | 
			
		||||
  tensorflow::testing::StartTiming();
 | 
			
		||||
  for (int i = 0; i < iters; ++i) {
 | 
			
		||||
  for (auto s : state) {
 | 
			
		||||
    delete softplus_context->release_output(0).tensor;
 | 
			
		||||
    op->Compute(softplus_context.get());
 | 
			
		||||
  }
 | 
			
		||||
  tensorflow::testing::StopTiming();
 | 
			
		||||
  testing::ItemsProcessed(softplus_context->mutable_output(0)->NumElements() *
 | 
			
		||||
                          iters);
 | 
			
		||||
  testing::SetLabel(label);
 | 
			
		||||
  state.SetItemsProcessed(softplus_context->mutable_output(0)->NumElements() *
 | 
			
		||||
                          state.iterations());
 | 
			
		||||
  state.SetLabel(label);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// BS: batch_size
 | 
			
		||||
// IR: input_rows
 | 
			
		||||
// IC: input_cols
 | 
			
		||||
// ND: node_depth
 | 
			
		||||
#define BM_Softplus(BS, IR, IC, ND, TH, LABEL)                               \
 | 
			
		||||
  static void BM_SoftplusFloat_##BS##_##IR##_##IC##_##ND##_##TH(int iters) { \
 | 
			
		||||
    BM_SoftplusFloat(iters, BS, IR, IC, ND, TH, LABEL);                      \
 | 
			
		||||
  }                                                                          \
 | 
			
		||||
  BENCHMARK(BM_SoftplusFloat_##BS##_##IR##_##IC##_##ND##_##TH)
 | 
			
		||||
#define BM_Softplus(BS, IR, IC, ND, TH, LABEL)                   \
 | 
			
		||||
  static void BM_SoftplusFloat_##BS##_##IR##_##IC##_##ND##_##TH( \
 | 
			
		||||
      ::testing::benchmark::State& state) {                      \
 | 
			
		||||
    BM_SoftplusFloat(state, BS, IR, IC, ND, TH, LABEL);          \
 | 
			
		||||
  }                                                              \
 | 
			
		||||
  BENCHMARK(BM_SoftplusFloat_##BS##_##IR##_##IC##_##ND##_##TH)->UseRealTime()
 | 
			
		||||
 | 
			
		||||
BM_Softplus(32, 112, 112, 64, 1, "softplus0");
 | 
			
		||||
BM_Softplus(32, 56, 56, 192, 1, "softplus1");
 | 
			
		||||
@ -1338,7 +1348,8 @@ BM_Softplus(32, 56, 56, 192, 4, "softplus1");
 | 
			
		||||
BM_Softplus(32, 28, 28, 352, 4, "softplus4");
 | 
			
		||||
BM_Softplus(32, 14, 14, 576, 4, "softplus10");
 | 
			
		||||
 | 
			
		||||
static void BM_ImageNetSoftmaxFwd(int iters, int batch_size, int node_depth,
 | 
			
		||||
static void BM_ImageNetSoftmaxFwd(::testing::benchmark::State& state,
 | 
			
		||||
                                  int batch_size, int node_depth,
 | 
			
		||||
                                  int num_threads, bool use_gpu,
 | 
			
		||||
                                  const string& label) {
 | 
			
		||||
  auto root = Scope::NewRootScope().ExitOnError();
 | 
			
		||||
@ -1359,19 +1370,21 @@ static void BM_ImageNetSoftmaxFwd(int iters, int batch_size, int node_depth,
 | 
			
		||||
  opts.config.mutable_graph_options()
 | 
			
		||||
      ->mutable_optimizer_options()
 | 
			
		||||
      ->set_opt_level(OptimizerOptions::L0);
 | 
			
		||||
  testing::UseRealTime();
 | 
			
		||||
  test::Benchmark(device, g, &opts).Run(iters);
 | 
			
		||||
  testing::ItemsProcessed(batch_size * node_depth * iters);
 | 
			
		||||
  testing::SetLabel(label);
 | 
			
		||||
  test::Benchmark(device, g, &opts, nullptr, nullptr, "",
 | 
			
		||||
                  /*old_benchmark_api*/ false)
 | 
			
		||||
      .Run(state);
 | 
			
		||||
  state.SetItemsProcessed(batch_size * node_depth * state.iterations());
 | 
			
		||||
  state.SetLabel(label);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define BM_ImageNetSoftmaxFwd(BATCH_SIZE, NODE_DEPTH, TH, GPU, LABEL)     \
 | 
			
		||||
  static void                                                             \
 | 
			
		||||
      BM_ImageNetSoftmaxFwd_##BATCH_SIZE##_##NODE_DEPTH##_##TH##_##GPU(   \
 | 
			
		||||
          int iters) {                                                    \
 | 
			
		||||
    BM_ImageNetSoftmaxFwd(iters, BATCH_SIZE, NODE_DEPTH, TH, GPU, LABEL); \
 | 
			
		||||
  }                                                                       \
 | 
			
		||||
  BENCHMARK(BM_ImageNetSoftmaxFwd_##BATCH_SIZE##_##NODE_DEPTH##_##TH##_##GPU)
 | 
			
		||||
#define BM_ImageNetSoftmaxFwd(BATCH_SIZE, NODE_DEPTH, TH, GPU, LABEL)         \
 | 
			
		||||
  static void                                                                 \
 | 
			
		||||
      BM_ImageNetSoftmaxFwd_##BATCH_SIZE##_##NODE_DEPTH##_##TH##_##GPU(       \
 | 
			
		||||
          ::testing::benchmark::State& state) {                               \
 | 
			
		||||
    BM_ImageNetSoftmaxFwd(state, BATCH_SIZE, NODE_DEPTH, TH, GPU, LABEL);     \
 | 
			
		||||
  }                                                                           \
 | 
			
		||||
  BENCHMARK(BM_ImageNetSoftmaxFwd_##BATCH_SIZE##_##NODE_DEPTH##_##TH##_##GPU) \
 | 
			
		||||
      ->UseRealTime()
 | 
			
		||||
 | 
			
		||||
// Labels are taken from the 2014-July-24 version of imagenet
 | 
			
		||||
BM_ImageNetSoftmaxFwd(32, 1008, 1, false, "softmax32");
 | 
			
		||||
@ -1383,9 +1396,8 @@ BM_ImageNetSoftmaxFwd(128, 1008, 1, true, "softmax128");
 | 
			
		||||
BM_ImageNetSoftmaxFwd(8192, 1024, 1, true, "softmax32");
 | 
			
		||||
BM_ImageNetSoftmaxFwd(8192, 32768, 1, true, "softmax128");
 | 
			
		||||
 | 
			
		||||
static void BM_TopK(int iters, int rows, int cols, int k, int num_threads,
 | 
			
		||||
                    bool use_gpu, const string& label) {
 | 
			
		||||
  testing::StopTiming();
 | 
			
		||||
static void BM_TopK(::testing::benchmark::State& state, int rows, int cols,
 | 
			
		||||
                    int k, int num_threads, bool use_gpu, const string& label) {
 | 
			
		||||
  auto root = Scope::NewRootScope().ExitOnError();
 | 
			
		||||
 | 
			
		||||
  Tensor input(DT_FLOAT, TensorShape({rows, cols}));
 | 
			
		||||
@ -1407,28 +1419,30 @@ static void BM_TopK(int iters, int rows, int cols, int k, int num_threads,
 | 
			
		||||
  opts.config.mutable_graph_options()
 | 
			
		||||
      ->mutable_optimizer_options()
 | 
			
		||||
      ->set_opt_level(OptimizerOptions::L0);
 | 
			
		||||
  testing::UseRealTime();
 | 
			
		||||
  testing::StartTiming();
 | 
			
		||||
  test::Benchmark(device, g, &opts).Run(iters);
 | 
			
		||||
  testing::ItemsProcessed(rows * cols * iters);
 | 
			
		||||
  testing::SetLabel(label);
 | 
			
		||||
  test::Benchmark(device, g, &opts, nullptr, nullptr, "",
 | 
			
		||||
                  /*old_benchmark_api=*/false)
 | 
			
		||||
      .Run(state);
 | 
			
		||||
  state.SetItemsProcessed(rows * cols * state.iterations());
 | 
			
		||||
  state.SetLabel(label);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IR: input_rows
 | 
			
		||||
// IC: input_cols
 | 
			
		||||
// IK: k
 | 
			
		||||
// TH: number of threads
 | 
			
		||||
#define BM_TopKGPU(IR, IC, IK, TH, LABEL)                        \
 | 
			
		||||
  static void BM_TopK_GPU_##IR##_##IC##_##IK##_##TH(int iters) { \
 | 
			
		||||
    BM_TopK(iters, IR, IC, IK, TH, true, LABEL);                 \
 | 
			
		||||
  }                                                              \
 | 
			
		||||
  BENCHMARK(BM_TopK_GPU_##IR##_##IC##_##IK##_##TH)
 | 
			
		||||
#define BM_TopKGPU(IR, IC, IK, TH, LABEL)            \
 | 
			
		||||
  static void BM_TopK_GPU_##IR##_##IC##_##IK##_##TH( \
 | 
			
		||||
      ::testing::benchmark::State& state) {          \
 | 
			
		||||
    BM_TopK(state, IR, IC, IK, TH, true, LABEL);     \
 | 
			
		||||
  }                                                  \
 | 
			
		||||
  BENCHMARK(BM_TopK_GPU_##IR##_##IC##_##IK##_##TH)->UseRealTime()
 | 
			
		||||
 | 
			
		||||
#define BM_TopKCPU(IR, IC, IK, TH, LABEL)                        \
 | 
			
		||||
  static void BM_TopK_CPU_##IR##_##IC##_##IK##_##TH(int iters) { \
 | 
			
		||||
    BM_TopK(iters, IR, IC, IK, TH, false, LABEL);                \
 | 
			
		||||
  }                                                              \
 | 
			
		||||
  BENCHMARK(BM_TopK_CPU_##IR##_##IC##_##IK##_##TH)
 | 
			
		||||
#define BM_TopKCPU(IR, IC, IK, TH, LABEL)            \
 | 
			
		||||
  static void BM_TopK_CPU_##IR##_##IC##_##IK##_##TH( \
 | 
			
		||||
      ::testing::benchmark::State& state) {          \
 | 
			
		||||
    BM_TopK(state, IR, IC, IK, TH, false, LABEL);    \
 | 
			
		||||
  }                                                  \
 | 
			
		||||
  BENCHMARK(BM_TopK_CPU_##IR##_##IC##_##IK##_##TH)->UseRealTime()
 | 
			
		||||
 | 
			
		||||
// clang-format on
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -56,9 +56,13 @@ static Graph* OneHot(int batch_size, int num_classes, int axis) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define BM_OneHot(BATCH, CLASS, AXIS, DEVICE)                                \
 | 
			
		||||
  static void BM_OneHot##_##BATCH##_##CLASS##_##AXIS##_##DEVICE(int iters) { \
 | 
			
		||||
    testing::ItemsProcessed(static_cast<int64>(iters) * BATCH * CLASS);      \
 | 
			
		||||
    test::Benchmark(#DEVICE, OneHot(BATCH, CLASS, AXIS)).Run(iters);         \
 | 
			
		||||
  static void BM_OneHot##_##BATCH##_##CLASS##_##AXIS##_##DEVICE(             \
 | 
			
		||||
      ::testing::benchmark::State& state) {                                  \
 | 
			
		||||
    test::Benchmark(#DEVICE, OneHot(BATCH, CLASS, AXIS),                     \
 | 
			
		||||
                    /*old_benchmark_api*/ false)                             \
 | 
			
		||||
        .Run(state);                                                         \
 | 
			
		||||
    state.SetItemsProcessed(static_cast<int64>(state.iterations()) * BATCH * \
 | 
			
		||||
                            CLASS);                                          \
 | 
			
		||||
  }                                                                          \
 | 
			
		||||
  BENCHMARK(BM_OneHot##_##BATCH##_##CLASS##_##AXIS##_##DEVICE);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -107,25 +107,34 @@ static Graph* PTruncatedNormalOneTail(int num_batches, int samples_per_batch) {
 | 
			
		||||
  return g;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define BM_PTruncatedNormalDev(DEVICE, B, S)                        \
 | 
			
		||||
  static void BM_PTruncatedNormal_##DEVICE##_##B##_##S(int iters) { \
 | 
			
		||||
    test::Benchmark(#DEVICE, PTruncatedNormal(B, S)).Run(iters);    \
 | 
			
		||||
    testing::ItemsProcessed(static_cast<int64>(B) * S * iters);     \
 | 
			
		||||
  }                                                                 \
 | 
			
		||||
#define BM_PTruncatedNormalDev(DEVICE, B, S)                                 \
 | 
			
		||||
  static void BM_PTruncatedNormal_##DEVICE##_##B##_##S(                      \
 | 
			
		||||
      ::testing::benchmark::State& state) {                                  \
 | 
			
		||||
    test::Benchmark(#DEVICE, PTruncatedNormal(B, S),                         \
 | 
			
		||||
                    /*old_benchmark_api*/ false)                             \
 | 
			
		||||
        .Run(state);                                                         \
 | 
			
		||||
    state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
 | 
			
		||||
  }                                                                          \
 | 
			
		||||
  BENCHMARK(BM_PTruncatedNormal_##DEVICE##_##B##_##S);
 | 
			
		||||
 | 
			
		||||
#define BM_PTruncatedNormalDev_2SD(DEVICE, B, S)                        \
 | 
			
		||||
  static void BM_PTruncatedNormal_2SD_##DEVICE##_##B##_##S(int iters) { \
 | 
			
		||||
    test::Benchmark(#DEVICE, PTruncatedNormal2SD(B, S)).Run(iters);     \
 | 
			
		||||
    testing::ItemsProcessed(static_cast<int64>(B) * S * iters);         \
 | 
			
		||||
  }                                                                     \
 | 
			
		||||
#define BM_PTruncatedNormalDev_2SD(DEVICE, B, S)                             \
 | 
			
		||||
  static void BM_PTruncatedNormal_2SD_##DEVICE##_##B##_##S(                  \
 | 
			
		||||
      ::testing::benchmark::State& state) {                                  \
 | 
			
		||||
    test::Benchmark(#DEVICE, PTruncatedNormal2SD(B, S),                      \
 | 
			
		||||
                    /*old_benchmark_api*/ false)                             \
 | 
			
		||||
        .Run(state);                                                         \
 | 
			
		||||
    state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
 | 
			
		||||
  }                                                                          \
 | 
			
		||||
  BENCHMARK(BM_PTruncatedNormal_2SD_##DEVICE##_##B##_##S);
 | 
			
		||||
 | 
			
		||||
#define BM_PTruncatedNormalDev_OneTail(DEVICE, B, S)                        \
 | 
			
		||||
  static void BM_PTruncatedNormal_OneTail_##DEVICE##_##B##_##S(int iters) { \
 | 
			
		||||
    test::Benchmark(#DEVICE, PTruncatedNormalOneTail(B, S)).Run(iters);     \
 | 
			
		||||
    testing::ItemsProcessed(static_cast<int64>(B) * S * iters);             \
 | 
			
		||||
  }                                                                         \
 | 
			
		||||
#define BM_PTruncatedNormalDev_OneTail(DEVICE, B, S)                         \
 | 
			
		||||
  static void BM_PTruncatedNormal_OneTail_##DEVICE##_##B##_##S(              \
 | 
			
		||||
      ::testing::benchmark::State& state) {                                  \
 | 
			
		||||
    test::Benchmark(#DEVICE, PTruncatedNormalOneTail(B, S),                  \
 | 
			
		||||
                    /*old_benchmark_api*/ false)                             \
 | 
			
		||||
        .Run(state);                                                         \
 | 
			
		||||
    state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
 | 
			
		||||
  }                                                                          \
 | 
			
		||||
  BENCHMARK(BM_PTruncatedNormal_OneTail_##DEVICE##_##B##_##S);
 | 
			
		||||
 | 
			
		||||
BM_PTruncatedNormalDev(cpu, 1000, 1000);
 | 
			
		||||
 | 
			
		||||
@ -759,15 +759,16 @@ TEST_F(QuantizeAndDequantizeTest, Invalid_range_given_V3) {
 | 
			
		||||
      << s;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define BM_SIMPLE_QUAN_DEQUAN(DEVICE)                     \
 | 
			
		||||
  static void BM_SIMPLE_QUAN_DEQUAN_##DEVICE(int iters) { \
 | 
			
		||||
    auto root = Scope::NewRootScope().ExitOnError();      \
 | 
			
		||||
    ops::QuantizeAndDequantizeV2(root, -3.5, -3.5, -3.5); \
 | 
			
		||||
    TF_CHECK_OK(root.status());                           \
 | 
			
		||||
    Graph* g = new Graph(OpRegistry::Global());           \
 | 
			
		||||
    TF_CHECK_OK(root.ToGraph(g));                         \
 | 
			
		||||
    test::Benchmark(#DEVICE, g).Run(iters);               \
 | 
			
		||||
  }                                                       \
 | 
			
		||||
#define BM_SIMPLE_QUAN_DEQUAN(DEVICE)                                    \
 | 
			
		||||
  static void BM_SIMPLE_QUAN_DEQUAN_##DEVICE(                            \
 | 
			
		||||
      ::testing::benchmark::State& state) {                              \
 | 
			
		||||
    auto root = Scope::NewRootScope().ExitOnError();                     \
 | 
			
		||||
    ops::QuantizeAndDequantizeV2(root, -3.5, -3.5, -3.5);                \
 | 
			
		||||
    TF_CHECK_OK(root.status());                                          \
 | 
			
		||||
    Graph* g = new Graph(OpRegistry::Global());                          \
 | 
			
		||||
    TF_CHECK_OK(root.ToGraph(g));                                        \
 | 
			
		||||
    test::Benchmark(#DEVICE, g, /*old_benchmark_api*/ false).Run(state); \
 | 
			
		||||
  }                                                                      \
 | 
			
		||||
  BENCHMARK(BM_SIMPLE_QUAN_DEQUAN_##DEVICE);
 | 
			
		||||
 | 
			
		||||
BM_SIMPLE_QUAN_DEQUAN(cpu);
 | 
			
		||||
 | 
			
		||||
@ -248,9 +248,8 @@ void QuantizedConcatTest::TestSecondDim8Bit(float first_min, float first_max,
 | 
			
		||||
// If <same_limits> is true, then both concatenated dimensions have the same
 | 
			
		||||
// quantized range; otherwise, they are set to different values.
 | 
			
		||||
template <typename T>
 | 
			
		||||
static void ConcatHelper(int iters, int concat_dimension, bool same_limits,
 | 
			
		||||
                         int dim2) {
 | 
			
		||||
  testing::StopTiming();
 | 
			
		||||
static void ConcatHelper(::testing::benchmark::State& state,
 | 
			
		||||
                         int concat_dimension, bool same_limits, int dim2) {
 | 
			
		||||
  Graph* g = new Graph(OpRegistry::Global());
 | 
			
		||||
 | 
			
		||||
  DataType dt = DataTypeToEnum<T>::v();
 | 
			
		||||
@ -278,61 +277,111 @@ static void ConcatHelper(int iters, int concat_dimension, bool same_limits,
 | 
			
		||||
                  .Attr("T", dt)
 | 
			
		||||
                  .Finalize(g, &node));
 | 
			
		||||
 | 
			
		||||
  testing::BytesProcessed(static_cast<int64>(iters) *
 | 
			
		||||
  test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state);
 | 
			
		||||
  state.SetBytesProcessed(static_cast<int64>(state.iterations()) *
 | 
			
		||||
                          ((kDim1 * dim2) + (kDim1 * dim2)) * sizeof(T));
 | 
			
		||||
  testing::StartTiming();
 | 
			
		||||
  test::Benchmark("cpu", g).Run(iters);
 | 
			
		||||
  testing::UseRealTime();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void BM_QConcatDim0SameLimitQInt32(int iters, int dim2) {
 | 
			
		||||
  ConcatHelper<qint32>(iters, 0 /* concat_dimension */, true /* same_limits */,
 | 
			
		||||
static void BM_QConcatDim0SameLimitQInt32(::testing::benchmark::State& state) {
 | 
			
		||||
  const int dim2 = state.range(0);
 | 
			
		||||
 | 
			
		||||
  ConcatHelper<qint32>(state, 0 /* concat_dimension */, true /* same_limits */,
 | 
			
		||||
                       dim2);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void BM_QConcatDim1SameLimitQInt32(int iters, int dim2) {
 | 
			
		||||
  ConcatHelper<qint32>(iters, 1 /* concat_dimension */, true /* same_limits */,
 | 
			
		||||
static void BM_QConcatDim1SameLimitQInt32(::testing::benchmark::State& state) {
 | 
			
		||||
  const int dim2 = state.range(0);
 | 
			
		||||
 | 
			
		||||
  ConcatHelper<qint32>(state, 1 /* concat_dimension */, true /* same_limits */,
 | 
			
		||||
                       dim2);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void BM_QConcatDim0DifferLimitQInt32(int iters, int dim2) {
 | 
			
		||||
  ConcatHelper<qint32>(iters, 0 /* concat_dimension */, false /* same_limits */,
 | 
			
		||||
static void BM_QConcatDim0DifferLimitQInt32(
 | 
			
		||||
    ::testing::benchmark::State& state) {
 | 
			
		||||
  const int dim2 = state.range(0);
 | 
			
		||||
 | 
			
		||||
  ConcatHelper<qint32>(state, 0 /* concat_dimension */, false /* same_limits */,
 | 
			
		||||
                       dim2);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void BM_QConcatDim1DifferLimitQInt32(int iters, int dim2) {
 | 
			
		||||
  ConcatHelper<qint32>(iters, 1 /* concat_dimension */, false /* same_limits */,
 | 
			
		||||
static void BM_QConcatDim1DifferLimitQInt32(
 | 
			
		||||
    ::testing::benchmark::State& state) {
 | 
			
		||||
  const int dim2 = state.range(0);
 | 
			
		||||
 | 
			
		||||
  ConcatHelper<qint32>(state, 1 /* concat_dimension */, false /* same_limits */,
 | 
			
		||||
                       dim2);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
BENCHMARK(BM_QConcatDim0SameLimitQInt32)->Arg(1000)->Arg(20000)->Arg(100000);
 | 
			
		||||
BENCHMARK(BM_QConcatDim1SameLimitQInt32)->Arg(1000)->Arg(20000)->Arg(100000);
 | 
			
		||||
BENCHMARK(BM_QConcatDim0DifferLimitQInt32)->Arg(1000)->Arg(20000)->Arg(100000);
 | 
			
		||||
BENCHMARK(BM_QConcatDim1DifferLimitQInt32)->Arg(1000)->Arg(20000)->Arg(100000);
 | 
			
		||||
BENCHMARK(BM_QConcatDim0SameLimitQInt32)
 | 
			
		||||
    ->UseRealTime()
 | 
			
		||||
    ->Arg(1000)
 | 
			
		||||
    ->Arg(20000)
 | 
			
		||||
    ->Arg(100000);
 | 
			
		||||
BENCHMARK(BM_QConcatDim1SameLimitQInt32)
 | 
			
		||||
    ->UseRealTime()
 | 
			
		||||
    ->Arg(1000)
 | 
			
		||||
    ->Arg(20000)
 | 
			
		||||
    ->Arg(100000);
 | 
			
		||||
BENCHMARK(BM_QConcatDim0DifferLimitQInt32)
 | 
			
		||||
    ->UseRealTime()
 | 
			
		||||
    ->Arg(1000)
 | 
			
		||||
    ->Arg(20000)
 | 
			
		||||
    ->Arg(100000);
 | 
			
		||||
BENCHMARK(BM_QConcatDim1DifferLimitQInt32)
 | 
			
		||||
    ->UseRealTime()
 | 
			
		||||
    ->Arg(1000)
 | 
			
		||||
    ->Arg(20000)
 | 
			
		||||
    ->Arg(100000);
 | 
			
		||||
 | 
			
		||||
static void BM_QConcatDim0SameLimitQUint8(int iters, int dim2) {
 | 
			
		||||
  ConcatHelper<qint32>(iters, 0 /* concat_dimension */, true /* same_limits */,
 | 
			
		||||
static void BM_QConcatDim0SameLimitQUint8(::testing::benchmark::State& state) {
 | 
			
		||||
  const int dim2 = state.range(0);
 | 
			
		||||
 | 
			
		||||
  ConcatHelper<qint32>(state, 0 /* concat_dimension */, true /* same_limits */,
 | 
			
		||||
                       dim2);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void BM_QConcatDim1SameLimitQUint8(int iters, int dim2) {
 | 
			
		||||
  ConcatHelper<qint32>(iters, 1 /* concat_dimension */, true /* same_limits */,
 | 
			
		||||
static void BM_QConcatDim1SameLimitQUint8(::testing::benchmark::State& state) {
 | 
			
		||||
  const int dim2 = state.range(0);
 | 
			
		||||
 | 
			
		||||
  ConcatHelper<qint32>(state, 1 /* concat_dimension */, true /* same_limits */,
 | 
			
		||||
                       dim2);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void BM_QConcatDim0DifferLimitQUint8(int iters, int dim2) {
 | 
			
		||||
  ConcatHelper<qint32>(iters, 0 /* concat_dimension */, false /* same_limits */,
 | 
			
		||||
static void BM_QConcatDim0DifferLimitQUint8(
 | 
			
		||||
    ::testing::benchmark::State& state) {
 | 
			
		||||
  const int dim2 = state.range(0);
 | 
			
		||||
 | 
			
		||||
  ConcatHelper<qint32>(state, 0 /* concat_dimension */, false /* same_limits */,
 | 
			
		||||
                       dim2);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void BM_QConcatDim1DifferLimitQUint8(int iters, int dim2) {
 | 
			
		||||
  ConcatHelper<qint32>(iters, 1 /* concat_dimension */, false /* same_limits */,
 | 
			
		||||
static void BM_QConcatDim1DifferLimitQUint8(
 | 
			
		||||
    ::testing::benchmark::State& state) {
 | 
			
		||||
  const int dim2 = state.range(0);
 | 
			
		||||
 | 
			
		||||
  ConcatHelper<qint32>(state, 1 /* concat_dimension */, false /* same_limits */,
 | 
			
		||||
                       dim2);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
BENCHMARK(BM_QConcatDim0SameLimitQUint8)->Arg(1000)->Arg(20000)->Arg(100000);
 | 
			
		||||
BENCHMARK(BM_QConcatDim1SameLimitQUint8)->Arg(1000)->Arg(20000)->Arg(100000);
 | 
			
		||||
BENCHMARK(BM_QConcatDim0DifferLimitQUint8)->Arg(1000)->Arg(20000)->Arg(100000);
 | 
			
		||||
BENCHMARK(BM_QConcatDim1DifferLimitQUint8)->Arg(1000)->Arg(20000)->Arg(100000);
 | 
			
		||||
BENCHMARK(BM_QConcatDim0SameLimitQUint8)
 | 
			
		||||
    ->UseRealTime()
 | 
			
		||||
    ->Arg(1000)
 | 
			
		||||
    ->Arg(20000)
 | 
			
		||||
    ->Arg(100000);
 | 
			
		||||
BENCHMARK(BM_QConcatDim1SameLimitQUint8)
 | 
			
		||||
    ->UseRealTime()
 | 
			
		||||
    ->Arg(1000)
 | 
			
		||||
    ->Arg(20000)
 | 
			
		||||
    ->Arg(100000);
 | 
			
		||||
BENCHMARK(BM_QConcatDim0DifferLimitQUint8)
 | 
			
		||||
    ->UseRealTime()
 | 
			
		||||
    ->Arg(1000)
 | 
			
		||||
    ->Arg(20000)
 | 
			
		||||
    ->Arg(100000);
 | 
			
		||||
BENCHMARK(BM_QConcatDim1DifferLimitQUint8)
 | 
			
		||||
    ->UseRealTime()
 | 
			
		||||
    ->Arg(1000)
 | 
			
		||||
    ->Arg(20000)
 | 
			
		||||
    ->Arg(100000);
 | 
			
		||||
 | 
			
		||||
}  // namespace tensorflow
 | 
			
		||||
 | 
			
		||||
@ -67,32 +67,44 @@ static Graph* RandomBinomialRejComplement(int num_batches,
 | 
			
		||||
  return RandomBinomialGraph(100., 0.2, num_batches, samples_per_batch);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define BM_RandomBinomialInv(DEVICE, B, S)                           \
 | 
			
		||||
  static void BM_RandomBinomialInv_##DEVICE##_##B##_##S(int iters) { \
 | 
			
		||||
    test::Benchmark(#DEVICE, RandomBinomialInv(B, S)).Run(iters);    \
 | 
			
		||||
    testing::ItemsProcessed(static_cast<int64>(B) * S * iters);      \
 | 
			
		||||
  }                                                                  \
 | 
			
		||||
#define BM_RandomBinomialInv(DEVICE, B, S)                                   \
 | 
			
		||||
  static void BM_RandomBinomialInv_##DEVICE##_##B##_##S(                     \
 | 
			
		||||
      ::testing::benchmark::State& state) {                                  \
 | 
			
		||||
    test::Benchmark(#DEVICE, RandomBinomialInv(B, S),                        \
 | 
			
		||||
                    /*old_benchmark_api*/ false)                             \
 | 
			
		||||
        .Run(state);                                                         \
 | 
			
		||||
    state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
 | 
			
		||||
  }                                                                          \
 | 
			
		||||
  BENCHMARK(BM_RandomBinomialInv_##DEVICE##_##B##_##S);
 | 
			
		||||
 | 
			
		||||
#define BM_RandomBinomialRej(DEVICE, B, S)                           \
 | 
			
		||||
  static void BM_RandomBinomialRej_##DEVICE##_##B##_##S(int iters) { \
 | 
			
		||||
    test::Benchmark(#DEVICE, RandomBinomialRej(B, S)).Run(iters);    \
 | 
			
		||||
    testing::ItemsProcessed(static_cast<int64>(B) * S * iters);      \
 | 
			
		||||
  }                                                                  \
 | 
			
		||||
#define BM_RandomBinomialRej(DEVICE, B, S)                                   \
 | 
			
		||||
  static void BM_RandomBinomialRej_##DEVICE##_##B##_##S(                     \
 | 
			
		||||
      ::testing::benchmark::State& state) {                                  \
 | 
			
		||||
    test::Benchmark(#DEVICE, RandomBinomialRej(B, S),                        \
 | 
			
		||||
                    /*old_benchmark_api*/ false)                             \
 | 
			
		||||
        .Run(state);                                                         \
 | 
			
		||||
    state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
 | 
			
		||||
  }                                                                          \
 | 
			
		||||
  BENCHMARK(BM_RandomBinomialRej_##DEVICE##_##B##_##S);
 | 
			
		||||
 | 
			
		||||
#define BM_RandomBinomialInvComplement(DEVICE, B, S)                           \
 | 
			
		||||
  static void BM_RandomBinomialInvComplement_##DEVICE##_##B##_##S(int iters) { \
 | 
			
		||||
    test::Benchmark(#DEVICE, RandomBinomialInvComplement(B, S)).Run(iters);    \
 | 
			
		||||
    testing::ItemsProcessed(static_cast<int64>(B) * S * iters);                \
 | 
			
		||||
  }                                                                            \
 | 
			
		||||
#define BM_RandomBinomialInvComplement(DEVICE, B, S)                         \
 | 
			
		||||
  static void BM_RandomBinomialInvComplement_##DEVICE##_##B##_##S(           \
 | 
			
		||||
      ::testing::benchmark::State& state) {                                  \
 | 
			
		||||
    test::Benchmark(#DEVICE, RandomBinomialInvComplement(B, S),              \
 | 
			
		||||
                    /*old_benchmark_api*/ false)                             \
 | 
			
		||||
        .Run(state);                                                         \
 | 
			
		||||
    state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
 | 
			
		||||
  }                                                                          \
 | 
			
		||||
  BENCHMARK(BM_RandomBinomialInvComplement_##DEVICE##_##B##_##S);
 | 
			
		||||
 | 
			
		||||
#define BM_RandomBinomialRejComplement(DEVICE, B, S)                           \
 | 
			
		||||
  static void BM_RandomBinomialRejComplement_##DEVICE##_##B##_##S(int iters) { \
 | 
			
		||||
    test::Benchmark(#DEVICE, RandomBinomialRejComplement(B, S)).Run(iters);    \
 | 
			
		||||
    testing::ItemsProcessed(static_cast<int64>(B) * S * iters);                \
 | 
			
		||||
  }                                                                            \
 | 
			
		||||
#define BM_RandomBinomialRejComplement(DEVICE, B, S)                         \
 | 
			
		||||
  static void BM_RandomBinomialRejComplement_##DEVICE##_##B##_##S(           \
 | 
			
		||||
      ::testing::benchmark::State& state) {                                  \
 | 
			
		||||
    test::Benchmark(#DEVICE, RandomBinomialRejComplement(B, S),              \
 | 
			
		||||
                    /*old_benchmark_api*/ false)                             \
 | 
			
		||||
        .Run(state);                                                         \
 | 
			
		||||
    state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \
 | 
			
		||||
  }                                                                          \
 | 
			
		||||
  BENCHMARK(BM_RandomBinomialRejComplement_##DEVICE##_##B##_##S);
 | 
			
		||||
 | 
			
		||||
BM_RandomBinomialInv(cpu, 1000, 1000);

@ -58,11 +58,14 @@ Graph* TruncatedNormal(int64 n) {
  return g;
}

#define BM_RNG(DEVICE, RNG)                                   \
  void BM_##DEVICE##_##RNG(int iters, int arg) {              \
    testing::ItemsProcessed(static_cast<int64>(iters) * arg); \
    test::Benchmark(#DEVICE, RNG(arg)).Run(iters);            \
  }                                                           \
#define BM_RNG(DEVICE, RNG)                                                \
  void BM_##DEVICE##_##RNG(::testing::benchmark::State& state) {           \
    const int arg = state.range(0);                                        \
                                                                           \
    test::Benchmark(#DEVICE, RNG(arg), /*old_benchmark_api*/ false)        \
        .Run(state);                                                       \
    state.SetItemsProcessed(static_cast<int64>(state.iterations()) * arg); \
  }                                                                        \
  BENCHMARK(BM_##DEVICE##_##RNG)->Range(1 << 20, 8 << 20);

BM_RNG(cpu, RandomUniform);
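To make the macro rewrite above easier to follow, a rough hand-expansion of the new-style BM_RNG for the BM_RNG(cpu, RandomUniform) instantiation is shown here; it is illustrative only (whitespace and token pasting are simplified) and depends on the surrounding TensorFlow test helpers rather than standing alone:

void BM_cpu_RandomUniform(::testing::benchmark::State& state) {
  const int arg = state.range(0);

  test::Benchmark("cpu", RandomUniform(arg), /*old_benchmark_api*/ false)
      .Run(state);
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * arg);
}
BENCHMARK(BM_cpu_RandomUniform)->Range(1 << 20, 8 << 20);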
@ -84,60 +87,48 @@ Tensor VecAlphas(int64 n) {
 | 
			
		||||
  return alphas;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void BM_cpu_RandomGamma(int iters, int nsamp, int nalpha) {
 | 
			
		||||
  testing::ItemsProcessed(static_cast<int64>(iters) * nsamp * nalpha);
 | 
			
		||||
void BM_cpu_RandomGamma(::testing::benchmark::State& state) {
 | 
			
		||||
  const int nsamp = state.range(0);
 | 
			
		||||
  const int nalpha = state.range(1);
 | 
			
		||||
 | 
			
		||||
  Graph* g = new Graph(OpRegistry::Global());
 | 
			
		||||
  test::graph::RandomGamma(g, test::graph::Constant(g, VecShape(nsamp)),
 | 
			
		||||
                           test::graph::Constant(g, VecAlphas(nalpha)));
 | 
			
		||||
  test::Benchmark("cpu", g).Run(iters);
 | 
			
		||||
  test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state);
 | 
			
		||||
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * nsamp *
 | 
			
		||||
                          nalpha);
 | 
			
		||||
}
 | 
			
		||||
BENCHMARK(BM_cpu_RandomGamma)->RangePair(1 << 14, 4 << 15, 2, 50);

void BM_PhiloxRandom(int iters) {
void BM_PhiloxRandom(::testing::benchmark::State& state) {
  // Fill 2M random numbers
  int count = 2 << 20;

  testing::ItemsProcessed(static_cast<int64>(iters) * count);

  random::PhiloxRandom gen(0x12345);

  int val = 1;
  for (int i = 0; i < iters; ++i) {
  for (auto s : state) {
    for (int j = 0; j < count; j += 4) {
      /// each invocation of gen() returns 128-bit samples
      auto samples = gen();

      // use the result trivially so the compiler does not optimize it away
      val ^= samples[0] ^ samples[1] ^ samples[2] ^ samples[3];
      tensorflow::testing::DoNotOptimize(samples);
    }
  }

  // A anchor point to make sure the compiler does not cut corners
  CHECK(val) << val;
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * count);
}
BENCHMARK(BM_PhiloxRandom);
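In the loop above, the old code kept the generator's output alive by folding the samples into `val` and CHECK-ing it at the end; the new code relies on DoNotOptimize so the compiler cannot discard the calls to gen() inside the timed region. A small self-contained illustration of the same idiom using the open-source Google Benchmark library follows; BM_KeepAlive and the std::mt19937 source are illustrative stand-ins, not part of this change:

#include <random>

#include "benchmark/benchmark.h"

static void BM_KeepAlive(benchmark::State& state) {
  std::mt19937 gen(0x12345);
  for (auto s : state) {
    // Without DoNotOptimize the compiler may see that `sample` is never used
    // and drop the call to gen() entirely, leaving nothing to measure.
    uint_fast32_t sample = gen();
    benchmark::DoNotOptimize(sample);
  }
  state.SetItemsProcessed(state.iterations());
}
BENCHMARK(BM_KeepAlive);

BENCHMARK_MAIN();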
 | 
			
		||||
void BM_StdMTRandom(int iters) {
 | 
			
		||||
void BM_StdMTRandom(::testing::benchmark::State& state) {
 | 
			
		||||
  // Fill 2M random numbers
 | 
			
		||||
  int count = 2 << 20;
 | 
			
		||||
 | 
			
		||||
  testing::ItemsProcessed(static_cast<int64>(iters) * count);
 | 
			
		||||
 | 
			
		||||
  std::mt19937 gen(0x12345);
 | 
			
		||||
 | 
			
		||||
  uint_fast32_t val = 1;
 | 
			
		||||
  for (int i = 0; i < iters; ++i) {
 | 
			
		||||
  for (auto s : state) {
 | 
			
		||||
    for (int j = 0; j < count; ++j) {
 | 
			
		||||
      /// each invocation of gen() returns 32-bit sample
 | 
			
		||||
      uint_fast32_t sample = gen();
 | 
			
		||||
 | 
			
		||||
      // use the result trivially so the compiler does not optimize it away
 | 
			
		||||
      val ^= sample;
 | 
			
		||||
      tensorflow::testing::DoNotOptimize(sample);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // A anchor point to make sure the compiler does not cut corners
 | 
			
		||||
  CHECK(val) << val;
 | 
			
		||||
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * count);
 | 
			
		||||
}
 | 
			
		||||
BENCHMARK(BM_StdMTRandom);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -84,108 +84,167 @@ static Graph* ThreeDXZReduce(const string& reduce, int num_y, int num_z) {
 | 
			
		||||
// Creates a bench which reduces a 3D tensor with total "num" floats
 | 
			
		||||
// into a scalar on a "device". Runs the bench for "iters" times.
 | 
			
		||||
template <typename T>
 | 
			
		||||
static void ReduceToScalar(int iters, const string& device,
 | 
			
		||||
                           const string& reduce, int num_x, int num_y) {
 | 
			
		||||
  testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
 | 
			
		||||
  testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
 | 
			
		||||
                          sizeof(T));
 | 
			
		||||
  test::Benchmark(device, ToScalar<T>(reduce, num_x, num_y)).Run(iters);
 | 
			
		||||
static void ReduceToScalar(::testing::benchmark::State& state,
 | 
			
		||||
                           const string& device, const string& reduce,
 | 
			
		||||
                           int num_x, int num_y) {
 | 
			
		||||
  test::Benchmark(device, ToScalar<T>(reduce, num_x, num_y),
 | 
			
		||||
                  /*old_benchmark_api*/ false)
 | 
			
		||||
      .Run(state);
 | 
			
		||||
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
 | 
			
		||||
                          num_y);
 | 
			
		||||
  state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
 | 
			
		||||
                          num_y * sizeof(T));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void DoRowReduce(int iters, const string& device, const string& reduce,
 | 
			
		||||
                        int num_x, int num_y) {
 | 
			
		||||
  testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
 | 
			
		||||
  testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
 | 
			
		||||
                          sizeof(float));
 | 
			
		||||
  test::Benchmark(device, RowReduce(reduce, num_x, num_y)).Run(iters);
 | 
			
		||||
static void DoRowReduce(::testing::benchmark::State& state,
 | 
			
		||||
                        const string& device, const string& reduce, int num_x,
 | 
			
		||||
                        int num_y) {
 | 
			
		||||
  test::Benchmark(device, RowReduce(reduce, num_x, num_y),
 | 
			
		||||
                  /*old_benchmark_api*/ false)
 | 
			
		||||
      .Run(state);
 | 
			
		||||
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
 | 
			
		||||
                          num_y);
 | 
			
		||||
  state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
 | 
			
		||||
                          num_y * sizeof(float));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void DoColReduce(int iters, const string& device, const string& reduce,
 | 
			
		||||
                        int num_x, int num_y) {
 | 
			
		||||
  testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
 | 
			
		||||
  testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
 | 
			
		||||
                          sizeof(float));
 | 
			
		||||
  test::Benchmark(device, ColReduce(reduce, num_x, num_y)).Run(iters);
 | 
			
		||||
static void DoColReduce(::testing::benchmark::State& state,
 | 
			
		||||
                        const string& device, const string& reduce, int num_x,
 | 
			
		||||
                        int num_y) {
 | 
			
		||||
  test::Benchmark(device, ColReduce(reduce, num_x, num_y),
 | 
			
		||||
                  /*old_benchmark_api*/ false)
 | 
			
		||||
      .Run(state);
 | 
			
		||||
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
 | 
			
		||||
                          num_y);
 | 
			
		||||
  state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
 | 
			
		||||
                          num_y * sizeof(float));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void Do3DYReduce(int iters, const string& device, const string& reduce,
 | 
			
		||||
                        int num_x, int num_y) {
 | 
			
		||||
  testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
 | 
			
		||||
  testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
 | 
			
		||||
                          sizeof(float));
 | 
			
		||||
  test::Benchmark(device, ThreeDYReduce(reduce, num_x, num_y)).Run(iters);
 | 
			
		||||
static void Do3DYReduce(::testing::benchmark::State& state,
 | 
			
		||||
                        const string& device, const string& reduce, int num_x,
 | 
			
		||||
                        int num_y) {
 | 
			
		||||
  test::Benchmark(device, ThreeDYReduce(reduce, num_x, num_y),
 | 
			
		||||
                  /*old_benchmark_api*/ false)
 | 
			
		||||
      .Run(state);
 | 
			
		||||
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
 | 
			
		||||
                          num_y);
 | 
			
		||||
  state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
 | 
			
		||||
                          num_y * sizeof(float));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void Do3DXZReduce(int iters, const string& device, const string& reduce,
 | 
			
		||||
                         int num_x, int num_y) {
 | 
			
		||||
  testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
 | 
			
		||||
  testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
 | 
			
		||||
                          sizeof(float));
 | 
			
		||||
  test::Benchmark(device, ThreeDXZReduce(reduce, num_x, num_y)).Run(iters);
 | 
			
		||||
static void Do3DXZReduce(::testing::benchmark::State& state,
 | 
			
		||||
                         const string& device, const string& reduce, int num_x,
 | 
			
		||||
                         int num_y) {
 | 
			
		||||
  test::Benchmark(device, ThreeDXZReduce(reduce, num_x, num_y),
 | 
			
		||||
                  /*old_benchmark_api*/ false)
 | 
			
		||||
      .Run(state);
 | 
			
		||||
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
 | 
			
		||||
                          num_y);
 | 
			
		||||
  state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
 | 
			
		||||
                          num_y * sizeof(float));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void BM_Sum2DToScalarGPU(int iters, int num_x, int num_y) {
 | 
			
		||||
  ReduceToScalar<float>(iters, "gpu", "Sum", num_x, num_y);
 | 
			
		||||
static void BM_Sum2DToScalarGPU(::testing::benchmark::State& state) {
 | 
			
		||||
  const int num_x = state.range(0);
 | 
			
		||||
  const int num_y = state.range(1);
 | 
			
		||||
 | 
			
		||||
  ReduceToScalar<float>(state, "gpu", "Sum", num_x, num_y);
 | 
			
		||||
}
 | 
			
		||||
BENCHMARK(BM_Sum2DToScalarGPU)->RangePair(1, 8192, 1, 8192);
 | 
			
		||||
 | 
			
		||||
static void BM_Sum2DToScalarGPUComplex(int iters, int num_x, int num_y) {
 | 
			
		||||
  ReduceToScalar<std::complex<float>>(iters, "gpu", "Sum", num_x, num_y);
 | 
			
		||||
static void BM_Sum2DToScalarGPUComplex(::testing::benchmark::State& state) {
 | 
			
		||||
  const int num_x = state.range(0);
 | 
			
		||||
  const int num_y = state.range(1);
 | 
			
		||||
 | 
			
		||||
  ReduceToScalar<std::complex<float>>(state, "gpu", "Sum", num_x, num_y);
 | 
			
		||||
}
 | 
			
		||||
BENCHMARK(BM_Sum2DToScalarGPUComplex)->RangePair(1, 8192, 1, 8192);
 | 
			
		||||
 | 
			
		||||
static void BM_Sum2DToScalarGPUHalf(int iters, int num_x, int num_y) {
 | 
			
		||||
  ReduceToScalar<Eigen::half>(iters, "gpu", "Sum", num_x, num_y);
 | 
			
		||||
static void BM_Sum2DToScalarGPUHalf(::testing::benchmark::State& state) {
 | 
			
		||||
  const int num_x = state.range(0);
 | 
			
		||||
  const int num_y = state.range(1);
 | 
			
		||||
 | 
			
		||||
  ReduceToScalar<Eigen::half>(state, "gpu", "Sum", num_x, num_y);
 | 
			
		||||
}
 | 
			
		||||
BENCHMARK(BM_Sum2DToScalarGPUHalf)->RangePair(1, 8192, 1, 8192);
 | 
			
		||||
 | 
			
		||||
static void BM_Sum2DRowReduceGPU(int iters, int num_x, int num_y) {
 | 
			
		||||
  DoRowReduce(iters, "gpu", "Sum", num_x, num_y);
 | 
			
		||||
static void BM_Sum2DRowReduceGPU(::testing::benchmark::State& state) {
 | 
			
		||||
  const int num_x = state.range(0);
 | 
			
		||||
  const int num_y = state.range(1);
 | 
			
		||||
 | 
			
		||||
  DoRowReduce(state, "gpu", "Sum", num_x, num_y);
 | 
			
		||||
}
 | 
			
		||||
BENCHMARK(BM_Sum2DRowReduceGPU)->RangePair(1, 8192, 1, 8192);
 | 
			
		||||
 | 
			
		||||
static void BM_Sum2DColumnReduceGPU(int iters, int num_x, int num_y) {
 | 
			
		||||
  DoColReduce(iters, "gpu", "Sum", num_x, num_y);
 | 
			
		||||
static void BM_Sum2DColumnReduceGPU(::testing::benchmark::State& state) {
 | 
			
		||||
  const int num_x = state.range(0);
 | 
			
		||||
  const int num_y = state.range(1);
 | 
			
		||||
 | 
			
		||||
  DoColReduce(state, "gpu", "Sum", num_x, num_y);
 | 
			
		||||
}
 | 
			
		||||
BENCHMARK(BM_Sum2DColumnReduceGPU)->RangePair(1, 8192, 1, 8192);
 | 
			
		||||
 | 
			
		||||
static void BM_Sum3DYReduceGPU(int iters, int num_x, int num_y) {
 | 
			
		||||
  Do3DYReduce(iters, "gpu", "Sum", num_x, num_y);
 | 
			
		||||
static void BM_Sum3DYReduceGPU(::testing::benchmark::State& state) {
 | 
			
		||||
  const int num_x = state.range(0);
 | 
			
		||||
  const int num_y = state.range(1);
 | 
			
		||||
 | 
			
		||||
  Do3DYReduce(state, "gpu", "Sum", num_x, num_y);
 | 
			
		||||
}
 | 
			
		||||
BENCHMARK(BM_Sum3DYReduceGPU)->RangePair(64, 4096, 64, 4096);
 | 
			
		||||
 | 
			
		||||
static void BM_Sum3DXZReduceGPU(int iters, int num_x, int num_y) {
 | 
			
		||||
  Do3DXZReduce(iters, "gpu", "Sum", num_x, num_y);
static void BM_Sum3DXZReduceGPU(::testing::benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);

  Do3DXZReduce(state, "gpu", "Sum", num_x, num_y);
}
BENCHMARK(BM_Sum3DXZReduceGPU)->RangePair(64, 4096, 64, 4096);

static void BM_Mean2DToScalarGPU(int iters, int num_x, int num_y) {
  ReduceToScalar<float>(iters, "gpu", "Mean", num_x, num_y);
static void BM_Mean2DToScalarGPU(::testing::benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);

  ReduceToScalar<float>(state, "gpu", "Mean", num_x, num_y);
}
BENCHMARK(BM_Mean2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);

static void BM_EuclideanNorm2DToScalarGPU(int iters, int num_x, int num_y) {
  ReduceToScalar<float>(iters, "gpu", "EuclideanNorm", num_x, num_y);
static void BM_EuclideanNorm2DToScalarGPU(::testing::benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);

  ReduceToScalar<float>(state, "gpu", "EuclideanNorm", num_x, num_y);
}
BENCHMARK(BM_EuclideanNorm2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);

static void BM_Max2DToScalarGPU(int iters, int num_x, int num_y) {
  ReduceToScalar<float>(iters, "gpu", "Max", num_x, num_y);
static void BM_Max2DToScalarGPU(::testing::benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);

  ReduceToScalar<float>(state, "gpu", "Max", num_x, num_y);
}
BENCHMARK(BM_Max2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);

static void BM_Min2DToScalarGPU(int iters, int num_x, int num_y) {
  ReduceToScalar<float>(iters, "gpu", "Min", num_x, num_y);
static void BM_Min2DToScalarGPU(::testing::benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);

  ReduceToScalar<float>(state, "gpu", "Min", num_x, num_y);
}
BENCHMARK(BM_Min2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);

static void BM_Min2DToScalarGPUHalf(int iters, int num_x, int num_y) {
  ReduceToScalar<Eigen::half>(iters, "gpu", "Min", num_x, num_y);
static void BM_Min2DToScalarGPUHalf(::testing::benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);

  ReduceToScalar<Eigen::half>(state, "gpu", "Min", num_x, num_y);
}
BENCHMARK(BM_Min2DToScalarGPUHalf)->RangePair(2048, 8192, 2048, 8192);

static void BM_Bool2DToScalarGPU(int iters, int num_x, int num_y) {
  ReduceToScalar<bool>(iters, "gpu", "All", num_x, num_y);
static void BM_Bool2DToScalarGPU(::testing::benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);

  ReduceToScalar<bool>(state, "gpu", "All", num_x, num_y);
}
BENCHMARK(BM_Bool2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);

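The hunks above share one pattern: a benchmark that took (int iters, int num_x, int num_y) now takes a single ::testing::benchmark::State& and reads its arguments with state.range(n). A minimal self-contained sketch of that pattern, written against the standard Google Benchmark API (which offers the same range/iterations surface); the FakeReduce workload is a placeholder, not code from this change:

#include "benchmark/benchmark.h"

// Placeholder workload standing in for Do3DXZReduce / ReduceToScalar.
static void FakeReduce(int num_x, int num_y) {
  volatile long sum = 0;
  for (int i = 0; i < num_x; ++i) sum += i % (num_y + 1);
}

static void BM_FakeReduce(benchmark::State& state) {
  const int num_x = state.range(0);  // first registered argument
  const int num_y = state.range(1);  // second registered argument
  for (auto _ : state) {
    FakeReduce(num_x, num_y);
  }
}
// RangePair(64, 4096, 64, 4096) in the diff sweeps both arguments; explicit
// pairs such as these are the simplest equivalent registration.
BENCHMARK(BM_FakeReduce)->Args({64, 64})->Args({4096, 4096});

BENCHMARK_MAIN();
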
@ -84,17 +84,17 @@ Graph* SetupRegexReplaceGraph(const Tensor& input, const string& input_pattern,
  return g;
}

void BM_RegexReplace(int iters, int batch_size) {
  testing::StopTiming();
  testing::ItemsProcessed(static_cast<int64>(iters));
  testing::UseRealTime();
static void BM_RegexReplace(::testing::benchmark::State& state) {
  const int batch_size = state.range(0);

  Tensor input = GetTestTensor(batch_size);
  Graph* g = SetupRegexReplaceGraph(input, kRegExPattern, kRewrite);
  testing::StartTiming();
  test::Benchmark("cpu", g).Run(iters);
  test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state);
  state.SetItemsProcessed(static_cast<int64>(state.iterations()));
}

BENCHMARK(BM_RegexReplace)
    ->UseRealTime()
    ->Arg(1)
    ->Arg(8)
    ->Arg(16)
@ -115,17 +115,17 @@ Graph* SetupStaticGraph(const Tensor& input, const string& input_pattern,
                  .Finalize(g, nullptr /* node */));
  return g;
}
void BM_StaticRegexReplace(int iters, int batch_size) {
  testing::StopTiming();
  testing::ItemsProcessed(static_cast<int64>(iters));
  testing::UseRealTime();
static void BM_StaticRegexReplace(::testing::benchmark::State& state) {
  const int batch_size = state.range(0);

  Tensor input = GetTestTensor(batch_size);
  Graph* g = SetupStaticGraph(input, kRegExPattern, kRewrite);
  testing::StartTiming();
  test::Benchmark("cpu", g).Run(iters);
  test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state);
  state.SetItemsProcessed(static_cast<int64>(state.iterations()));
}

BENCHMARK(BM_StaticRegexReplace)
    ->UseRealTime()
    ->Arg(1)
    ->Arg(8)
    ->Arg(16)

@ -67,56 +67,29 @@ TEST_F(RequantizationRangeTest, HandCrafted) {
  test::ExpectTensorEqual<float>(expected_max, *GetOutput(1));
}

static void BM_RequantizationRange(int iters, int size) {
  testing::StopTiming();
  testing::UseRealTime();
  testing::ItemsProcessed(static_cast<int64>(iters) * size);
  testing::ItemsProcessed(static_cast<int64>(iters) * size * 4);
static void BM_RequantizationRange(::testing::benchmark::State& state) {
  const int size = state.range(0);

  Tensor quantized_tensor(DT_QINT32, TensorShape({1, size}));
  test::FillFn<qint32>(&quantized_tensor, [](int n) { return qint32(n); });

  qint32 actual_min;
  qint32 actual_max;
  testing::StartTiming();
  for (int iter = 0; iter < iters; ++iter) {
  for (auto s : state) {
    CalculateUsedRange(quantized_tensor, &actual_min, &actual_max);
  }
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * size);
  state.SetBytesProcessed(static_cast<int64>(state.iterations()) * size * 4);
}

static void BM_RequantizationRange100(int iters) {
  BM_RequantizationRange(100, iters);
}
BENCHMARK(BM_RequantizationRange100);

static void BM_RequantizationRange1000(int iters) {
  BM_RequantizationRange(1000, iters);
}
BENCHMARK(BM_RequantizationRange1000);

static void BM_RequantizationRange10000(int iters) {
  BM_RequantizationRange(10000, iters);
}
BENCHMARK(BM_RequantizationRange10000);

static void BM_RequantizationRange100000(int iters) {
  BM_RequantizationRange(100000, iters);
}
BENCHMARK(BM_RequantizationRange100000);

static void BM_RequantizationRange1000000(int iters) {
  BM_RequantizationRange(1000000, iters);
}
BENCHMARK(BM_RequantizationRange1000000);

static void BM_RequantizationRange10000000(int iters) {
  BM_RequantizationRange(10000000, iters);
}
BENCHMARK(BM_RequantizationRange10000000);

static void BM_RequantizationRange100000000(int iters) {
  BM_RequantizationRange(100000000, iters);
}
BENCHMARK(BM_RequantizationRange100000000);
BENCHMARK(BM_RequantizationRange)
    ->UseRealTime()
    ->Arg(100)
    ->Arg(1000)
    ->Arg(10000)
    ->Arg(100000)
    ->Arg(1000000)
    ->Arg(10000000)
    ->Arg(100000000);

}  // end namespace tensorflow

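In the requantization hunk, the per-size BM_RequantizationRange* wrappers collapse into one registration with ->Arg(...) values, and the explicit StartTiming/StopTiming calls around a hand-written iteration loop become a for (auto s : state) loop whose trip count the framework decides. A rough stand-alone sketch of the same shape (plain Google Benchmark; Work is a placeholder, not the kernel from this change):

#include <cstdint>

#include "benchmark/benchmark.h"

// Placeholder for CalculateUsedRange over a size-element tensor.
static void Work(int size) {
  volatile int64_t acc = 0;
  for (int i = 0; i < size; ++i) acc += i;
}

static void BM_Work(benchmark::State& state) {
  const int size = state.range(0);
  for (auto s : state) {  // replaces StartTiming/StopTiming and the manual loop
    Work(size);
  }
  // Throughput counters are set once, after the loop, from state.iterations().
  state.SetItemsProcessed(static_cast<int64_t>(state.iterations()) * size);
  state.SetBytesProcessed(static_cast<int64_t>(state.iterations()) * size * 4);
}
// One registration with an Arg list replaces the per-size wrapper functions.
BENCHMARK(BM_Work)->UseRealTime()->Arg(100)->Arg(10000)->Arg(1000000);

BENCHMARK_MAIN();
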
@ -197,148 +197,187 @@ static Graph* Reverse(const TensorShape& shape, int reverse_axis) {
}

template <typename T>
static void RunReverseRowsBenchmark(int iters, int outer_dim, int middle_dim,
static void RunReverseRowsBenchmark(::testing::benchmark::State& state,
                                    int outer_dim, int middle_dim,
                                    int intra_threads, int channels) {
  SessionOptions opts = GetOptions(intra_threads);
  TensorShape shape{outer_dim, middle_dim, channels};
  const int64 num_items = static_cast<int64>(iters) * shape.num_elements();
  testing::ItemsProcessed(num_items);
  testing::BytesProcessed(num_items * sizeof(T));
  testing::UseRealTime();
  test::Benchmark("cpu", Reverse<T>(shape, 1), &opts).Run(iters);
  test::Benchmark("cpu", Reverse<T>(shape, 1), &opts, nullptr, nullptr, "",
                  /*old_benchmark_api*/ false)
      .Run(state);
  const int64 num_items =
      static_cast<int64>(state.iterations()) * shape.num_elements();
  state.SetItemsProcessed(num_items);
  state.SetBytesProcessed(num_items * sizeof(T));
}

static void BM_ReverseRowsOf1Channel_1T_float(int iters, int outer_dim,
                                              int middle_dim) {
  RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf1Channel_1T_float(::testing::benchmark::State& state) {
  const int outer_dim = state.range(0);
  const int middle_dim = state.range(1);

  RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim,
                                 1 /* intra_threads */, 1 /* channels */);
}

BENCHMARK(BM_ReverseRowsOf1Channel_1T_float)
    ->UseRealTime()
    ->ArgPair(288, 288)
    ->ArgPair(1024, 1024)
    ->ArgPair(10 * 1024, 1024);

static void BM_ReverseRowsOf1Channel_1T_uint8(int iters, int outer_dim,
                                              int middle_dim) {
  RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf1Channel_1T_uint8(::testing::benchmark::State& state) {
  const int outer_dim = state.range(0);
  const int middle_dim = state.range(1);

  RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim,
                                 1 /* intra_threads */, 1 /* channels */);
}

BENCHMARK(BM_ReverseRowsOf1Channel_1T_uint8)
    ->UseRealTime()
    ->ArgPair(288, 288)
    ->ArgPair(1024, 1024)
    ->ArgPair(10 * 1024, 1024);

static void BM_ReverseRowsOf1Channel_4T_float(int iters, int outer_dim,
                                              int middle_dim) {
  RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf1Channel_4T_float(::testing::benchmark::State& state) {
  const int outer_dim = state.range(0);
  const int middle_dim = state.range(1);

  RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim,
                                 4 /* intra_threads */, 1 /* channels */);
}

BENCHMARK(BM_ReverseRowsOf1Channel_4T_float)
    ->UseRealTime()
    ->ArgPair(288, 288)
    ->ArgPair(1024, 1024)
    ->ArgPair(10 * 1024, 1024);

static void BM_ReverseRowsOf1Channel_4T_uint8(int iters, int outer_dim,
                                              int middle_dim) {
  RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf1Channel_4T_uint8(::testing::benchmark::State& state) {
  const int outer_dim = state.range(0);
  const int middle_dim = state.range(1);

  RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim,
                                 4 /* intra_threads */, 1 /* channels */);
}

BENCHMARK(BM_ReverseRowsOf1Channel_4T_uint8)
    ->UseRealTime()
    ->ArgPair(288, 288)
    ->ArgPair(1024, 1024)
    ->ArgPair(10 * 1024, 1024);

static void BM_ReverseRowsOf3Channels_1T_float(int iters, int outer_dim,
                                               int middle_dim) {
  RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf3Channels_1T_float(::testing::benchmark::State& state) {
  const int outer_dim = state.range(0);
  const int middle_dim = state.range(1);

  RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim,
                                 1 /* intra_threads */, 3 /* channels */);
}

BENCHMARK(BM_ReverseRowsOf3Channels_1T_float)
    ->UseRealTime()
    ->ArgPair(288, 288)
    ->ArgPair(30, 30)
    ->ArgPair(1024, 1024)
    ->ArgPair(10 * 1024, 1024);

static void BM_ReverseRowsOf3Channels_1T_uint8(int iters, int outer_dim,
                                               int middle_dim) {
  RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf3Channels_1T_uint8(::testing::benchmark::State& state) {
  const int outer_dim = state.range(0);
  const int middle_dim = state.range(1);

  RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim,
                                 1 /* intra_threads */, 3 /* channels */);
}

BENCHMARK(BM_ReverseRowsOf3Channels_1T_uint8)
    ->UseRealTime()
    ->ArgPair(288, 288)
    ->ArgPair(30, 30)
    ->ArgPair(1024, 1024)
    ->ArgPair(10 * 1024, 1024);

static void BM_ReverseRowsOf3Channels_4T_float(int iters, int outer_dim,
                                               int middle_dim) {
  RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf3Channels_4T_float(::testing::benchmark::State& state) {
  const int outer_dim = state.range(0);
  const int middle_dim = state.range(1);

  RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim,
                                 4 /* intra_threads */, 3 /* channels */);
}

BENCHMARK(BM_ReverseRowsOf3Channels_4T_float)
    ->UseRealTime()
    ->ArgPair(288, 288)
    ->ArgPair(30, 30)
    ->ArgPair(1024, 1024)
    ->ArgPair(10 * 1024, 1024);

static void BM_ReverseRowsOf3Channels_4T_uint8(int iters, int outer_dim,
                                               int middle_dim) {
  RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf3Channels_4T_uint8(::testing::benchmark::State& state) {
  const int outer_dim = state.range(0);
  const int middle_dim = state.range(1);

  RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim,
                                 4 /* intra_threads */, 3 /* channels */);
}
BENCHMARK(BM_ReverseRowsOf3Channels_4T_uint8)
    ->UseRealTime()
    ->ArgPair(288, 288)
    ->ArgPair(30, 30)
    ->ArgPair(1024, 1024)
    ->ArgPair(10 * 1024, 1024);

static void BM_ReverseRowsOf4Channels_1T_float(int iters, int outer_dim,
                                               int middle_dim) {
  RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf4Channels_1T_float(::testing::benchmark::State& state) {
  const int outer_dim = state.range(0);
  const int middle_dim = state.range(1);

  RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim,
                                 1 /* intra_threads */, 4 /* channels */);
}

BENCHMARK(BM_ReverseRowsOf4Channels_1T_float)
    ->UseRealTime()
    ->ArgPair(288, 288)
    ->ArgPair(1024, 1024)
    ->ArgPair(10 * 1024, 1024);

static void BM_ReverseRowsOf4Channels_1T_uint8(int iters, int outer_dim,
                                               int middle_dim) {
  RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf4Channels_1T_uint8(::testing::benchmark::State& state) {
  const int outer_dim = state.range(0);
  const int middle_dim = state.range(1);

  RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim,
                                 1 /* intra_threads */, 4 /* channels */);
}

BENCHMARK(BM_ReverseRowsOf4Channels_1T_uint8)
    ->UseRealTime()
    ->ArgPair(288, 288)
    ->ArgPair(1024, 1024)
    ->ArgPair(10 * 1024, 1024);

static void BM_ReverseRowsOf4Channels_4T_float(int iters, int outer_dim,
                                               int middle_dim) {
  RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf4Channels_4T_float(::testing::benchmark::State& state) {
  const int outer_dim = state.range(0);
  const int middle_dim = state.range(1);

  RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim,
                                 4 /* intra_threads */, 4 /* channels */);
}

BENCHMARK(BM_ReverseRowsOf4Channels_4T_float)
    ->UseRealTime()
    ->ArgPair(288, 288)
    ->ArgPair(1024, 1024)
    ->ArgPair(10 * 1024, 1024);

static void BM_ReverseRowsOf4Channels_4T_uint8(int iters, int outer_dim,
                                               int middle_dim) {
  RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
void BM_ReverseRowsOf4Channels_4T_uint8(::testing::benchmark::State& state) {
  const int outer_dim = state.range(0);
  const int middle_dim = state.range(1);

  RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim,
                                 4 /* intra_threads */, 4 /* channels */);
}

BENCHMARK(BM_ReverseRowsOf4Channels_4T_uint8)
    ->UseRealTime()
    ->ArgPair(288, 288)
    ->ArgPair(1024, 1024)
    ->ArgPair(10 * 1024, 1024);

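The reverse-op hunks show the second common shape: a shared helper now takes the State& first and its configuration after it, each BM_* entry point only unpacks state.range(0) and state.range(1) and forwards, and the throughput counters are set once after the run from state.iterations(). A stand-alone sketch of that structure (helper name and workload are illustrative, not code from this change):

#include <cstdint>

#include "benchmark/benchmark.h"

// Helper mirrors RunReverseRowsBenchmark: State& first, configuration after.
static void RunHelper(benchmark::State& state, int outer_dim, int middle_dim,
                      int channels) {
  const int64_t elems =
      static_cast<int64_t>(outer_dim) * middle_dim * channels;
  for (auto _ : state) {
    benchmark::DoNotOptimize(elems);  // placeholder for the real graph run
  }
  const int64_t num_items = static_cast<int64_t>(state.iterations()) * elems;
  state.SetItemsProcessed(num_items);
  state.SetBytesProcessed(num_items * sizeof(float));
}

static void BM_ReverseLike(benchmark::State& state) {
  RunHelper(state, state.range(0), state.range(1), /*channels=*/3);
}
BENCHMARK(BM_ReverseLike)->UseRealTime()->Args({288, 288})->Args({1024, 1024});

BENCHMARK_MAIN();
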
@ -450,34 +450,44 @@ static Graph* RollGraph(const TensorShape& shape, int isd) {
  return g;
}

#define BM_ROLL_OUTER(DEVICE)                                                 \
  static void BM_##DEVICE##_roll_outer(int iters, int rows, int columns) {    \
    TensorShape shape{rows, columns};                                         \
    const int64 num_items = static_cast<int64>(iters) * shape.num_elements(); \
    testing::ItemsProcessed(num_items);                                       \
    testing::BytesProcessed(num_items * sizeof(float));                       \
    testing::UseRealTime();                                                   \
    test::Benchmark(#DEVICE, RollGraph(shape, 0)).Run(iters);                 \
  }                                                                           \
  BENCHMARK(BM_##DEVICE##_roll_outer)                                         \
      ->ArgPair(256, 256)                                                     \
      ->ArgPair(512, 512)                                                     \
      ->ArgPair(1024, 1024)                                                   \
#define BM_ROLL_OUTER(DEVICE)                                                  \
  static void BM_##DEVICE##_roll_outer(::testing::benchmark::State& state) {   \
    const int rows = state.range(0);                                           \
    const int columns = state.range(1);                                        \
                                                                               \
    TensorShape shape{rows, columns};                                          \
    test::Benchmark(#DEVICE, RollGraph(shape, 0), /*old_benchmark_api*/ false) \
        .Run(state);                                                           \
    const int64 num_items =                                                    \
        static_cast<int64>(state.iterations()) * shape.num_elements();         \
    state.SetItemsProcessed(num_items);                                        \
    state.SetBytesProcessed(num_items * sizeof(float));                        \
  }                                                                            \
  BENCHMARK(BM_##DEVICE##_roll_outer)                                          \
      ->UseRealTime()                                                          \
      ->ArgPair(256, 256)                                                      \
      ->ArgPair(512, 512)                                                      \
      ->ArgPair(1024, 1024)                                                    \
      ->ArgPair(2048, 2048)

#define BM_ROLL_ALL(DEVICE)                                                   \
  static void BM_##DEVICE##_roll_all(int iters, int rows, int columns) {      \
    TensorShape shape{rows, columns};                                         \
    const int64 num_items = static_cast<int64>(iters) * shape.num_elements(); \
    testing::ItemsProcessed(num_items);                                       \
    testing::BytesProcessed(num_items * sizeof(float));                       \
    testing::UseRealTime();                                                   \
    test::Benchmark(#DEVICE, RollGraph(shape, 1)).Run(iters);                 \
  }                                                                           \
  BENCHMARK(BM_##DEVICE##_roll_all)                                           \
      ->ArgPair(256, 256)                                                     \
      ->ArgPair(512, 512)                                                     \
      ->ArgPair(1024, 1024)                                                   \
#define BM_ROLL_ALL(DEVICE)                                                    \
  static void BM_##DEVICE##_roll_all(::testing::benchmark::State& state) {     \
    const int rows = state.range(0);                                           \
    const int columns = state.range(1);                                        \
                                                                               \
    TensorShape shape{rows, columns};                                          \
    test::Benchmark(#DEVICE, RollGraph(shape, 1), /*old_benchmark_api*/ false) \
        .Run(state);                                                           \
    const int64 num_items =                                                    \
        static_cast<int64>(state.iterations()) * shape.num_elements();         \
    state.SetItemsProcessed(num_items);                                        \
    state.SetBytesProcessed(num_items * sizeof(float));                        \
  }                                                                            \
  BENCHMARK(BM_##DEVICE##_roll_all)                                            \
      ->UseRealTime()                                                          \
      ->ArgPair(256, 256)                                                      \
      ->ArgPair(512, 512)                                                      \
      ->ArgPair(1024, 1024)                                                    \
      ->ArgPair(2048, 2048)

BM_ROLL_OUTER(cpu);

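The roll-op hunks apply the same rewrite inside a macro, and also move testing::UseRealTime() out of the benchmark body onto the BENCHMARK(...) registration chain as ->UseRealTime(). A compressed stand-alone sketch of a macro-defined benchmark in that style (macro name and fill workload are illustrative only):

#include <algorithm>
#include <cstdint>
#include <vector>

#include "benchmark/benchmark.h"

// One expansion defines the benchmark function and registers it, mirroring
// the BM_ROLL_* macros above; std::fill stands in for running the graph.
#define BM_FILL(NAME)                                                        \
  static void BM_##NAME(benchmark::State& state) {                          \
    const int rows = state.range(0);                                        \
    const int columns = state.range(1);                                     \
    std::vector<float> buf(static_cast<size_t>(rows) * columns);            \
    for (auto _ : state) {                                                   \
      std::fill(buf.begin(), buf.end(), 1.0f);                               \
    }                                                                        \
    const int64_t n =                                                        \
        static_cast<int64_t>(state.iterations()) * rows * columns;           \
    state.SetItemsProcessed(n);                                              \
    state.SetBytesProcessed(n * sizeof(float));                              \
  }                                                                          \
  BENCHMARK(BM_##NAME)->UseRealTime()->Args({256, 256})->Args({1024, 1024})

BM_FILL(cpu_fill);

BENCHMARK_MAIN();
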
@ -663,8 +663,8 @@ TEST_F(SaveOpSlices2Test, TwoSlices) {

// Benchmark-related code below.

static void BM_LargeTensorWrite(int iters, int num_elements) {
  testing::StopTiming();
void BM_LargeTensorWrite(::testing::benchmark::State& state) {
  const int num_elements = state.range(0);

  // 4 * num_elements bytes total , since sizeof(float) == 4.
  Tensor tensor(DT_FLOAT, TensorShape({num_elements}));
@ -689,8 +689,9 @@ static void BM_LargeTensorWrite(int iters, int num_elements) {
  VLOG(1) << "Save op's output path: " << temp_filename;
  VLOG(1) << "# nodes in Graph: " << g->num_nodes();

  testing::StartTiming();
  test::Benchmark("cpu", g, &session_options).Run(iters);
  test::Benchmark("cpu", g, &session_options, nullptr, nullptr, "",
                  /*old_benchmark_api*/ false)
      .Run(state);
}
BENCHMARK(BM_LargeTensorWrite)->Arg((1 << 30) / 4 /* 1GB float tensor */);

@ -67,79 +67,120 @@ static Graph* ThreeDYCumsum(int num_y, int num_z, bool reverse = false) {
}

template <typename T>
static void LargeOneDimensional(int iters, const string& device, int num_x,
static void LargeOneDimensional(::testing::benchmark::State& state,
                                const string& device, int num_x,
                                bool reverse = false) {
  testing::ItemsProcessed(static_cast<int64>(iters) * num_x);
  testing::BytesProcessed(static_cast<int64>(iters) * num_x * sizeof(T));
  test::Benchmark(device, LargeOneDCumsum<T>(num_x, reverse)).Run(iters);
  test::Benchmark(device, LargeOneDCumsum<T>(num_x, reverse),
                  /*old_benchmark_api*/ false)
      .Run(state);
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x);
  state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
                          sizeof(T));
}

static void DoRowCumsum(int iters, const string& device, int num_x, int num_y,
static void DoRowCumsum(::testing::benchmark::State& state,
                        const string& device, int num_x, int num_y,
                        bool reverse = false) {
  testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
  testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
                          sizeof(float));
  test::Benchmark(device, RowCumsum(num_x, num_y, reverse)).Run(iters);
  test::Benchmark(device, RowCumsum(num_x, num_y, reverse),
                  /*old_benchmark_api*/ false)
      .Run(state);
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
                          num_y);
  state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
                          num_y * sizeof(float));
}

static void DoColCumsum(int iters, const string& device, int num_x, int num_y,
static void DoColCumsum(::testing::benchmark::State& state,
                        const string& device, int num_x, int num_y,
                        bool reverse = false) {
  testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
  testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
                          sizeof(float));
  test::Benchmark(device, ColCumsum(num_x, num_y, reverse)).Run(iters);
  test::Benchmark(device, ColCumsum(num_x, num_y, reverse),
                  /*old_benchmark_api*/ false)
      .Run(state);
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
                          num_y);
  state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
                          num_y * sizeof(float));
}

static void Do3DYCumsum(int iters, const string& device, int num_x, int num_y,
static void Do3DYCumsum(::testing::benchmark::State& state,
                        const string& device, int num_x, int num_y,
                        bool reverse = false) {
  testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y);
  testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y *
                          sizeof(float));
  test::Benchmark(device, ThreeDYCumsum(num_x, num_y, reverse)).Run(iters);
  test::Benchmark(device, ThreeDYCumsum(num_x, num_y, reverse),
                  /*old_benchmark_api*/ false)
      .Run(state);
  state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x *
                          num_y);
  state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x *
                          num_y * sizeof(float));
}

static void BM_OneDCumsumGPU(int iters, int num_x) {
  LargeOneDimensional<float>(iters, "gpu", num_x);
static void BM_OneDCumsumGPU(::testing::benchmark::State& state) {
  const int num_x = state.range(0);

  LargeOneDimensional<float>(state, "gpu", num_x);
}
BENCHMARK(BM_OneDCumsumGPU)->Range(1, 1 << 21);

static void BM_OneDCumsumGPUHalf(int iters, int num_x) {
  LargeOneDimensional<Eigen::half>(iters, "gpu", num_x);
static void BM_OneDCumsumGPUHalf(::testing::benchmark::State& state) {
  const int num_x = state.range(0);

  LargeOneDimensional<Eigen::half>(state, "gpu", num_x);
}
BENCHMARK(BM_OneDCumsumGPUHalf)->Range(1, 1 << 21);

static void BM_Sum2DRowCumsumGPU(int iters, int num_x, int num_y) {
  DoRowCumsum(iters, "gpu", num_x, num_y);
static void BM_Sum2DRowCumsumGPU(::testing::benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);

  DoRowCumsum(state, "gpu", num_x, num_y);
}
BENCHMARK(BM_Sum2DRowCumsumGPU)->RangePair(1, 8192, 1, 8192);

static void BM_Sum2DColumnCumsumGPU(int iters, int num_x, int num_y) {
  DoColCumsum(iters, "gpu", num_x, num_y);
static void BM_Sum2DColumnCumsumGPU(::testing::benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);

  DoColCumsum(state, "gpu", num_x, num_y);
}
BENCHMARK(BM_Sum2DColumnCumsumGPU)->RangePair(1, 8192, 1, 8192);

static void BM_Sum3DYCumsumGPU(int iters, int num_x, int num_y) {
  Do3DYCumsum(iters, "gpu", num_x, num_y);
static void BM_Sum3DYCumsumGPU(::testing::benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);

  Do3DYCumsum(state, "gpu", num_x, num_y);
}
BENCHMARK(BM_Sum3DYCumsumGPU)->RangePair(64, 4096, 64, 4096);

static void BM_OneDCumsumGPU_reverse(int iters, int num_x) {
  LargeOneDimensional<float>(iters, "gpu", num_x, true);
static void BM_OneDCumsumGPU_reverse(::testing::benchmark::State& state) {
  const int num_x = state.range(0);

  LargeOneDimensional<float>(state, "gpu", num_x, true);
}
BENCHMARK(BM_OneDCumsumGPU_reverse)->Range(1, 1 << 21);

static void BM_Sum2DRowCumsumGPU_reverse(int iters, int num_x, int num_y) {
  DoRowCumsum(iters, "gpu", num_x, num_y, true);
static void BM_Sum2DRowCumsumGPU_reverse(::testing::benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);

  DoRowCumsum(state, "gpu", num_x, num_y, true);
}
BENCHMARK(BM_Sum2DRowCumsumGPU_reverse)->RangePair(1, 8192, 1, 8192);

static void BM_Sum2DColumnCumsumGPU_reverse(int iters, int num_x, int num_y) {
  DoColCumsum(iters, "gpu", num_x, num_y, true);
static void BM_Sum2DColumnCumsumGPU_reverse(
    ::testing::benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);

  DoColCumsum(state, "gpu", num_x, num_y, true);
}
BENCHMARK(BM_Sum2DColumnCumsumGPU_reverse)->RangePair(1, 8192, 1, 8192);

static void BM_Sum3DYCumsumGPU_reverse(int iters, int num_x, int num_y) {
  Do3DYCumsum(iters, "gpu", num_x, num_y, true);
static void BM_Sum3DYCumsumGPU_reverse(::testing::benchmark::State& state) {
  const int num_x = state.range(0);
  const int num_y = state.range(1);

  Do3DYCumsum(state, "gpu", num_x, num_y, true);
}
BENCHMARK(BM_Sum3DYCumsumGPU_reverse)->RangePair(32, 2048, 32, 2048);

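The cumsum hunks keep their exponential argument sweeps: ->Range(lo, hi) registers the benchmark at multiplicative steps between the bounds, and ->RangePair(...) does the same over two arguments at once. A stand-alone sketch of the single-argument form (benchmark name is illustrative; with Google Benchmark's default range multiplier of 8, Range(1, 1 << 21) yields 1, 8, 64, and so on up to 1 << 21):

#include <cstdint>

#include "benchmark/benchmark.h"

static void BM_CumsumLike(benchmark::State& state) {
  const int num_x = state.range(0);
  for (auto _ : state) {
    benchmark::DoNotOptimize(num_x);  // placeholder for the real graph run
  }
  state.SetItemsProcessed(static_cast<int64_t>(state.iterations()) * num_x);
}
// Exponential sweep over the single argument, as in the hunks above.
BENCHMARK(BM_CumsumLike)->Range(1, 1 << 21);

BENCHMARK_MAIN();
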
@ -254,8 +254,8 @@ class ScatterNdUpdateBM : public ScatterNdUpdateOpTest {
};

template <typename Index>
static void BM_ScatterNdHelper(int iters, int embedding_size, const char* op) {
  testing::StopTiming();
void BM_ScatterNdHelper(::testing::benchmark::State& state, int embedding_size,
                        const char* op) {
  const int kRows = 10000000 / embedding_size;
  std::vector<float> values;
  values.reserve(kRows);
@ -280,27 +280,33 @@ static void BM_ScatterNdHelper(int iters, int embedding_size, const char* op) {
  bm.AddInputFromArray<Index>(TensorShape({kNumUpdates}), indices);
  bm.AddInputFromArray<float>(TensorShape({kNumUpdates, embedding_size}),
                              updates);
  testing::ItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) *
                          iters);
  testing::StartTiming();
  while (iters-- > 0) {
  for (auto i : state) {
    Status s = bm.RunOpKernel();
  }
  testing::StopTiming();
  state.SetItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) *
                          state.iterations());
}

static void BM_ScatterNdUpdateInt32(int iters, int embedding_size) {
  BM_ScatterNdHelper<int32>(iters, embedding_size, "ScatterNdUpdate");
void BM_ScatterNdUpdateInt32(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterNdHelper<int32>(state, embedding_size, "ScatterNdUpdate");
}
static void BM_ScatterNdUpdateInt64(int iters, int embedding_size) {
  BM_ScatterNdHelper<int64>(iters, embedding_size, "ScatterNdUpdate");
void BM_ScatterNdUpdateInt64(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterNdHelper<int64>(state, embedding_size, "ScatterNdUpdate");
}

static void BM_ScatterNdAddInt32(int iters, int embedding_size) {
  BM_ScatterNdHelper<int32>(iters, embedding_size, "ScatterNdAdd");
void BM_ScatterNdAddInt32(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterNdHelper<int32>(state, embedding_size, "ScatterNdAdd");
}
static void BM_ScatterNdAddInt64(int iters, int embedding_size) {
  BM_ScatterNdHelper<int64>(iters, embedding_size, "ScatterNdAdd");
void BM_ScatterNdAddInt64(::testing::benchmark::State& state) {
  const int embedding_size = state.range(0);

  BM_ScatterNdHelper<int64>(state, embedding_size, "ScatterNdAdd");
}

BENCHMARK(BM_ScatterNdUpdateInt32)

Some files were not shown because too many files have changed in this diff