diff --git a/RELEASE.md b/RELEASE.md index 1011610350d..af3acb177ce 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -42,6 +42,10 @@ * Removed deprecated `Interpreter::UseNNAPI(bool)` C++ API. * Use `NnApiDelegate()` and related delegate configuration methods directly. +* TF Core: + * Corrected higher-order gradients of control flow constructs (`tf.cond`, + `tf.while_loop`, and compositions like `tf.foldl`) computed with + `tf.GradientTape` inside a `tf.function`. ## Thanks to our Contributors diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index fd208c6770d..0f5f494e5e2 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -769,7 +769,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) { TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); EXPECT_NE(TF_OK, TF_GetCode(status)); EXPECT_EQ(nullptr, t); - const char* msg = "Matrix size-incompatible: In[0]: [2,2], In[1]: [3,2]"; + const char* msg = "In[0] mismatch In[1] shape: 2 vs. 3: [2,2] [3,2]"; EXPECT_TRUE(strstr(TF_Message(status), msg) != nullptr) << TF_Message(status); // Since error is not cleared, the following copy with correct device will diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index a0e60b1eafe..66922f901a1 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -583,7 +583,11 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments( XlaCompiler::Argument& arg = out[input_num]; if (absl::c_binary_search(must_be_constant_idxs, input_num)) { // Handles compile-time constants. - TF_RET_CHECK(input->dtype() != DT_RESOURCE); + + // TODO(b/157241314): Support constants located in resource variables. + TF_RET_CHECK(input->dtype() != DT_RESOURCE) + << "tf2xla bridge does not support must-be-constants located in " + "resource variables; try moving them to a tensor"; arg.kind = XlaCompiler::Argument::kConstant; arg.type = input->dtype(); arg.shape = input->shape(); diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD index a719f303d3d..e1b81133724 100644 --- a/tensorflow/compiler/mlir/hlo/BUILD +++ b/tensorflow/compiler/mlir/hlo/BUILD @@ -517,6 +517,15 @@ cc_library( ], ) +cc_library( + name = "map_chlo_to_hlo_op", + hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"], + deps = [ + ":hlo", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "map_hlo_to_lhlo_op", hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"], @@ -606,9 +615,11 @@ cc_library( ], deps = [ ":hlo", + ":map_chlo_to_hlo_op", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", @@ -893,6 +904,7 @@ cc_library( deps = [ ":chlo_legalize_to_hlo_inc_gen", ":hlo", + ":map_chlo_to_hlo_op", "@llvm-project//mlir:IR", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Shape", diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h new file mode 100644 index 00000000000..316e65076ae --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h @@ -0,0 +1,97 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_MHLO_OP_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_MHLO_OP_H_ + +#include <type_traits> + +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace chlo { + +struct HloComplexAdaptor { + static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type, + Value broadcasted_lhs, Value broadcasted_rhs, + OpBuilder &builder) { + return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type, + broadcasted_lhs, broadcasted_rhs); + } +}; +template <typename FromOpTy, typename ToOpTy> +struct HloBinaryElementwiseAdaptor { + static ToOpTy CreateOp(FromOpTy from_op, Type result_type, + Value broadcasted_lhs, Value broadcasted_rhs, + OpBuilder &builder) { + return builder.create<ToOpTy>(from_op.getLoc(), result_type, + broadcasted_lhs, broadcasted_rhs); + } +}; +struct HloCompareAdaptor { + static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type, + Value broadcasted_lhs, Value broadcasted_rhs, + OpBuilder &builder) { + return builder.create<mhlo::CompareOp>( + from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs, + from_op.comparison_direction(), from_op.compare_typeAttr()); + } +}; + +// Populate a pattern for each Broadcasting CHlo op. This requires the pattern +// to take a ChloOpTy, MhloOpTy, and an Adaptor as templated values. +template <template <typename, typename, typename> class Pattern, + typename... ConstructorArgs> +void PopulateForBroadcastingBinaryOp(MLIRContext *context, + OwningRewritePatternList *patterns, + ConstructorArgs &&...args) { +#define POPULATE_BCAST(ChloOp, HloOp) \ + patterns->insert< \ + Pattern<ChloOp, HloOp, HloBinaryElementwiseAdaptor<ChloOp, HloOp>>>( \ + context, args...); + + POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp); + POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp); + POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op); + POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp); + POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp); + POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp); + POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp); + POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp); + POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp); + POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp); + POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp); + POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp); + POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp); + POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp); + POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp); + + // Broadcasting ops requiring special construction. + patterns + ->insert<Pattern<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>>( + context, args...); + patterns + ->insert<Pattern<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>>( + context, args...); + +#undef POPULATE_BCAST +} + +} // namespace chlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H_ diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index f261d6adb23..b232ca4e6cb 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -17,6 +17,7 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/utils/broadcast_utils.h" #include "mlir/Dialect/SCF/SCF.h" @@ -69,13 +70,18 @@ struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> { // Converts binary ops that statically are determined to not broadcast directly // to the corresponding mhlo non-broadcasting op. template <typename ChloOpTy, typename HloOpTy, typename Adaptor> -struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> { - using OpRewritePattern<ChloOpTy>::OpRewritePattern; - LogicalResult matchAndRewrite(ChloOpTy op, - PatternRewriter &rewriter) const override { +struct ConvertTrivialNonBroadcastBinaryOp + : public OpConversionPattern<ChloOpTy> { + using OpConversionPattern<ChloOpTy>::OpConversionPattern; + LogicalResult matchAndRewrite( + ChloOpTy op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { // Only rewrite for statically determinable non-broadcasting cases. - auto lhs_type = op.lhs().getType().template dyn_cast<RankedTensorType>(); - auto rhs_type = op.rhs().getType().template dyn_cast<RankedTensorType>(); + typename ChloOpTy::Adaptor transformed(operands); + auto lhs_type = + transformed.lhs().getType().template dyn_cast<RankedTensorType>(); + auto rhs_type = + transformed.rhs().getType().template dyn_cast<RankedTensorType>(); if (!lhs_type || !rhs_type) return failure(); // Requires rank broadcast. @@ -93,8 +99,9 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> { } } - rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(), - op.lhs(), op.rhs(), rewriter)}); + rewriter.replaceOp( + op, {Adaptor::CreateOp(op, op.getResult().getType(), operands[0], + operands[1], rewriter)}); return success(); } }; @@ -113,13 +120,15 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> { // `shape.broadcast` op, which only supports prefix-padding. template <typename ChloOpTy, typename HloOpTy, typename Adaptor> struct ConvertRankedDynamicBroadcastBinaryOp - : public OpRewritePattern<ChloOpTy> { - using OpRewritePattern<ChloOpTy>::OpRewritePattern; - LogicalResult matchAndRewrite(ChloOpTy op, - PatternRewriter &rewriter) const override { + : public OpConversionPattern<ChloOpTy> { + using OpConversionPattern<ChloOpTy>::OpConversionPattern; + LogicalResult matchAndRewrite( + ChloOpTy op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { // Only support ranked operands. - Value lhs = op.lhs(); - Value rhs = op.rhs(); + typename ChloOpTy::Adaptor transformed(operands); + Value lhs = transformed.lhs(); + Value rhs = transformed.rhs(); auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>(); auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>(); auto result_type = @@ -193,324 +202,6 @@ struct ConvertRankedDynamicBroadcastBinaryOp } }; -// Converts a broadcasting binary operation with a scalar operand and an -// unranked operand to a ranked broadcasting operation by dynamically reshaping -// the unranked operand to a 1D tensor. This will always be safe because -// broadcasting from a scalar to another shape always works. -template <typename ChloOpTy, typename HloOpTy> -struct ConvertUnrankedScalarDynamicBroadcastBinaryOp - : public OpRewritePattern<ChloOpTy> { - using OpRewritePattern<ChloOpTy>::OpRewritePattern; - LogicalResult matchAndRewrite(ChloOpTy op, - PatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value lhs = op.lhs(); - Value rhs = op.rhs(); - - auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>(); - auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>(); - - auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>(); - auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>(); - - bool lhs_is_scalar = lhs_ranked_type && - lhs_ranked_type.getShape().empty() && - rhs_unranked_type; - bool rhs_is_scalar = rhs_ranked_type && - rhs_ranked_type.getShape().empty() && - lhs_unranked_type; - - // Only support the case where exactly one operand is scalar and the other - // is unranked. Other patterns in this file will create more efficient - // lowerings for cases where both ranks are known or will handle the more - // generic case of both inputs being unranked. - if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure(); - - auto result_type = op.getResult().getType().template dyn_cast<TensorType>(); - - // Reshape the non-scalar value into a dynamically sized, rank-1 tensor - Value shape = - rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs); - Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape); - Value size_tensor = - rewriter.create<TensorFromElementsOp>(loc, num_elements); - Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>( - loc, RankedTensorType::get({-1}, result_type.getElementType()), - lhs_is_scalar ? rhs : lhs, size_tensor); - - // Create a new ranked Chlo op that will be further lowered by other - // patterns into Mhlo. - SmallVector<Value, 2> operands{lhs_is_scalar ? lhs : reshaped, - rhs_is_scalar ? rhs : reshaped}; - Value computed = rewriter.create<ChloOpTy>( - loc, SmallVector<Type, 1>{reshaped.getType()}, operands, op.getAttrs()); - - // Reshape the result back into an unranked tensor. - rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type, - computed, shape); - - return success(); - } -}; - -// Handles lowering of the following pattern to patterns that will be further -// matched by other patterns until they result in LHLO: -// %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy> -// -// The sequence of specializations this handles is: -// - Either operand being scalar -// - Operands having equal shapes -// - The resulting value being any of ranks [2,6] -template <typename ChloOpTy, typename HloOpTy, typename Adaptor> -struct ConvertUnrankedDynamicBroadcastBinaryOp - : public OpRewritePattern<ChloOpTy> { - using OpRewritePattern<ChloOpTy>::OpRewritePattern; - - LogicalResult matchAndRewrite(ChloOpTy op, - PatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value lhs = op.lhs(); - Value rhs = op.rhs(); - auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>(); - auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>(); - auto result_type = op.getResult().getType().template dyn_cast<TensorType>(); - - // Only support unranked operands. If either operand is ranked, another - // pattern will handle the lowering. - if (!lhs_type || !rhs_type) return failure(); - - // If lhs is scalar - auto if_op = rewriter.create<scf::IfOp>( - loc, result_type, IsScalarTensor(rewriter, op, lhs), true); - OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder(); - Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>( - loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs); - Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>( - loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs}, - op.getAttrs()); - if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result); - - // If lhs is NOT scalar - // - // See if rhs is scalar - OpBuilder else_lhs_scalar_builder = if_op.getElseBodyBuilder(); - auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>( - loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs), - true); - else_lhs_scalar_builder.create<scf::YieldOp>(loc, - if_rhs_scalar_op.getResult(0)); - OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder(); - Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>( - loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs); - Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>( - loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs}, - op.getAttrs()); - if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result); - - // If NEITHER shape is scalar - // - // See if shapes are equal. - OpBuilder else_no_scalars_builder = if_rhs_scalar_op.getElseBodyBuilder(); - Value shape_of_lhs = - else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs); - Value shape_of_rhs = - else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs); - Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>( - loc, shape_of_lhs, shape_of_rhs); - - auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>( - loc, result_type, equal_shapes, true); - else_no_scalars_builder.create<scf::YieldOp>(loc, - if_eq_shapes_op.getResult(0)); - - OpBuilder if_eq_shapes_builder = if_eq_shapes_op.getThenBodyBuilder(); - Value non_broadcast_op = - Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder); - if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op); - - // If shapes are not scalar, nor equal - // - // See if values are of a rank that we support. - OpBuilder if_neq_shapes_builder = if_eq_shapes_op.getElseBodyBuilder(); - if_neq_shapes_builder.create<scf::YieldOp>( - loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs)); - - rewriter.replaceOp(op, {if_op.getResult(0)}); - return success(); - } - - private: - // Returns the dyanamic result of checking the given value is a scalar - // tensor. - Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const { - auto loc = op.getLoc(); - - Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor); - Value rank_tensor = rewriter.create<shape::RankOp>( - loc, rewriter.getIndexType(), shape_of_tensor); - return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq, - rank_tensor, - rewriter.create<ConstantIndexOp>(loc, 0)); - } - - // Create the if statement and code for a broadcasting op with a result of a - // given rank. - scf::IfOp createRankSpecializedBroadcastAndOp(OpBuilder &builder, ChloOpTy op, - Value lhs, Value rhs, - Value actual_rank, - int targeted_rank) const { - auto loc = op.getLoc(); - - // Create the if block to place the current specialized logic in. - Value greater_rank_is_n = builder.create<CmpIOp>( - loc, CmpIPredicate::eq, actual_rank, - builder.create<ConstantIndexOp>(loc, targeted_rank)); - auto if_op = - builder.create<scf::IfOp>(loc, lhs.getType(), greater_rank_is_n, true); - OpBuilder if_builder = if_op.getThenBodyBuilder(); - - // Handle shape broadcasting and inferrence. - Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs); - Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs); - SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1); - auto unknown_rank_extent_tensor_type = RankedTensorType::get( - {RankedTensorType::kDynamicSize}, builder.getIndexType()); - auto known_rank_extent_tensor_type = - RankedTensorType::get({targeted_rank}, builder.getIndexType()); - auto reshaped_type = RankedTensorType::get( - llvm::SmallVector<int64_t, 6>(targeted_rank, - RankedTensorType::kDynamicSize), - lhs.getType().template dyn_cast<TensorType>().getElementType()); - Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>( - loc, known_rank_extent_tensor_type, - mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type, - ranked_shape)); - Value extended_lhs = if_builder.create<shape::BroadcastOp>( - loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val, - nullptr); - Value extended_lhs_casted = if_builder.create<TensorCastOp>( - loc, known_rank_extent_tensor_type, extended_lhs); - Value extended_rhs = if_builder.create<shape::BroadcastOp>( - loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val, - nullptr); - Value extended_rhs_casted = if_builder.create<TensorCastOp>( - loc, known_rank_extent_tensor_type, extended_rhs); - - // 1. Reshape operands to the given rank (with the same number of elements) - // 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops - // can be broadcasted and do the actual broadcasting) - // 3. Type erase the output back to unranked - Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>( - loc, reshaped_type, lhs, extended_lhs_casted); - Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>( - loc, reshaped_type, rhs, extended_rhs_casted); - Value result = if_builder.create<ChloOpTy>( - loc, ArrayRef<Type>{reshaped_type}, - ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs()); - Value reshaped_result = if_builder.create<TensorCastOp>( - loc, UnrankedTensorType::get(reshaped_type.getElementType()), result); - if_builder.create<scf::YieldOp>(loc, reshaped_result); - - // Return the if_op, so the result can be used and the else block can be - // used for the next rank specialized step. - return if_op; - } - - // Iterates over the desired ranks to be specialized and generates the code - // snippet for each case. - Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs, - Value rhs) const { - constexpr int max_rank_specialization = 7; - auto loc = op.getLoc(); - - // Find the larger rank of the 2 operands. - auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize}, - rewriter.getIndexType()); - Value lhs_shape = - rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs); - Value rhs_shape = - rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs); - Value lhs_rank = - rewriter.create<RankOp>(loc, rewriter.getIndexType(), lhs_shape); - Value rhs_rank = - rewriter.create<RankOp>(loc, rewriter.getIndexType(), rhs_shape); - Value greater_rank_lhs = - rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank); - Value greater_rank = - rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank); - - // Generate a list of nested if/else statements to handle rank - // specializations from 2-6. - scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs, - rhs, greater_rank, 2); - - // Put each subsequent rank specialization inside the else statement of the - // previous one. - OpBuilder else_builder = if_op.getElseBodyBuilder(); - for (int i = 3; i < max_rank_specialization; i++) { - auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs, - rhs, greater_rank, i); - - else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0)); - else_builder = inner_if.getElseBodyBuilder(); - } - - // Fire an assertion if none of the rank specializations applied (one of the - // ranks was greater than 6). - else_builder.create<AssertOp>( - loc, else_builder.create<ConstantIntOp>(loc, 0, 1), - "Input for dynamic binary op lowering was of a rank greater than 6"); - else_builder.create<scf::YieldOp>(loc, lhs); - - // Return the result of the outermost if statement. - return if_op.getResult(0); - } -}; - -template <typename ChloOpTy, typename HloOpTy, typename Adaptor> -void PopulateForBinaryOp(MLIRContext *context, - OwningRewritePatternList *patterns) { - patterns - ->insert<ConvertTrivialNonBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>( - context, 10); - patterns->insert< - ConvertRankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>( - context, 5); - patterns->insert< - ConvertUnrankedScalarDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy>, - ConvertUnrankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>( - context); -} - -template <typename FromOpTy, typename ToOpTy> -struct HloBinaryElementwiseAdaptor { - static ToOpTy CreateOp(FromOpTy from_op, Type result_type, - Value broadcasted_lhs, Value broadcasted_rhs, - OpBuilder &builder) { - return builder.create<ToOpTy>(from_op.getLoc(), result_type, - broadcasted_lhs, broadcasted_rhs); - } -}; - -struct HloComplexAdaptor { - static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type, - Value broadcasted_lhs, Value broadcasted_rhs, - OpBuilder &builder) { - return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type, - broadcasted_lhs, broadcasted_rhs); - } -}; - -struct HloCompareAdaptor { - static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type, - Value broadcasted_lhs, Value broadcasted_rhs, - OpBuilder &builder) { - return builder.create<mhlo::CompareOp>( - from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs, - from_op.comparison_direction(), from_op.compare_typeAttr()); - } -}; - #include "generated_chlo_legalize_to_hlo.inc" } // namespace @@ -521,32 +212,10 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context, // Instantiate conversion templates for conforming binary elementwise ops // that do not have different dtypes between operands and results and do // not have special attributes that need to be preserved. -#define POPULATE_BCAST(ChloOp, HloOp) \ - PopulateForBinaryOp<ChloOp, HloOp, \ - HloBinaryElementwiseAdaptor<ChloOp, HloOp>>(context, \ - patterns); - - POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp); - POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp); - POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op); - POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp); - POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp); - POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp); - POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp); - POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp); - POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp); - POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp); - POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp); - POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp); - POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp); - POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp); - POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp); - - // Broadcasting ops requiring special construction. - PopulateForBinaryOp<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>( - context, patterns); - PopulateForBinaryOp<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>( - context, patterns); + PopulateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>( + context, patterns, 10); + PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>( + context, patterns, 5); // Other patterns. patterns->insert<ConvertConstantLikeOp>(context); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 0dc9f1be929..358cb5c4349 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -16,7 +16,9 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Function.h" @@ -126,6 +128,291 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> { } }; +// Converts a broadcasting binary operation with a scalar operand and an +// unranked operand to a ranked broadcasting operation by dynamically reshaping +// the unranked operand to a 1D tensor. This will always be safe because +// broadcasting from a scalar to another shape always works. +template <typename ChloOpTy, typename HloOpTy, typename Adaptor> +struct ConvertUnrankedScalarDynamicBroadcastBinaryOp + : public OpConversionPattern<ChloOpTy> { + using OpConversionPattern<ChloOpTy>::OpConversionPattern; + LogicalResult matchAndRewrite( + ChloOpTy op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + typename ChloOpTy::Adaptor transformed(operands); + Value lhs = transformed.lhs(); + Value rhs = transformed.rhs(); + + auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>(); + auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>(); + + auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>(); + auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>(); + + bool lhs_is_scalar = lhs_ranked_type && + lhs_ranked_type.getShape().empty() && + rhs_unranked_type; + bool rhs_is_scalar = rhs_ranked_type && + rhs_ranked_type.getShape().empty() && + lhs_unranked_type; + + // Only support the case where exactly one operand is scalar and the other + // is unranked. Other patterns in chlo-to-hlo legalization will create more + // efficient lowerings for cases where both ranks are known or will handle + // the more generic case of both inputs being unranked. + if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure(); + + auto result_type = op.getResult().getType().template dyn_cast<TensorType>(); + + // Reshape the non-scalar value into a dynamically sized, rank-1 tensor + Value shape = + rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs); + Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape); + Value size_tensor = + rewriter.create<TensorFromElementsOp>(loc, num_elements); + Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>( + loc, RankedTensorType::get({-1}, result_type.getElementType()), + lhs_is_scalar ? rhs : lhs, size_tensor); + + // Create a new ranked Chlo op that will be further lowered by other + // patterns into Mhlo. + SmallVector<Value, 2> new_operands{lhs_is_scalar ? lhs : reshaped, + rhs_is_scalar ? rhs : reshaped}; + Value computed = + rewriter.create<ChloOpTy>(loc, SmallVector<Type, 1>{reshaped.getType()}, + new_operands, op.getAttrs()); + + // Reshape the result back into an unranked tensor. + rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type, + computed, shape); + + return success(); + } +}; + +// Handles lowering of the following pattern to patterns that will be further +// matched by other patterns until they result in LHLO: +// %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy> +// +// The sequence of specializations this handles is: +// - Either operand being scalar +// - Operands having equal shapes +// - The resulting value being any of ranks [2,6] +template <typename ChloOpTy, typename HloOpTy, typename Adaptor> +struct ConvertUnrankedDynamicBroadcastBinaryOp + : public OpConversionPattern<ChloOpTy> { + using OpConversionPattern<ChloOpTy>::OpConversionPattern; + + LogicalResult matchAndRewrite( + ChloOpTy op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + typename ChloOpTy::Adaptor transformed(operands); + Value lhs = transformed.lhs(); + Value rhs = transformed.rhs(); + auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>(); + auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>(); + auto result_type = op.getResult().getType().template dyn_cast<TensorType>(); + + // Only support unranked operands. If either operand is ranked, another + // pattern will handle the lowering. + if (!lhs_type || !rhs_type) return failure(); + + // If lhs is scalar + auto if_op = rewriter.create<scf::IfOp>( + loc, result_type, IsScalarTensor(rewriter, op, lhs), true); + OpBuilder if_lhs_scalar_builder = + if_op.getThenBodyBuilder(rewriter.getListener()); + Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>( + loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs); + Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>( + loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs}, + op.getAttrs()); + if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result); + + // If lhs is NOT scalar + // + // See if rhs is scalar + OpBuilder else_lhs_scalar_builder = + if_op.getElseBodyBuilder(rewriter.getListener()); + auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>( + loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs), + true); + else_lhs_scalar_builder.create<scf::YieldOp>(loc, + if_rhs_scalar_op.getResult(0)); + OpBuilder if_rhs_scalar_builder = + if_rhs_scalar_op.getThenBodyBuilder(rewriter.getListener()); + Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>( + loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs); + Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>( + loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs}, + op.getAttrs()); + if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result); + + // If NEITHER shape is scalar + // + // See if shapes are equal. + OpBuilder else_no_scalars_builder = + if_rhs_scalar_op.getElseBodyBuilder(rewriter.getListener()); + Value shape_of_lhs = + else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs); + Value shape_of_rhs = + else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs); + Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>( + loc, shape_of_lhs, shape_of_rhs); + + auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>( + loc, result_type, equal_shapes, true); + else_no_scalars_builder.create<scf::YieldOp>(loc, + if_eq_shapes_op.getResult(0)); + + OpBuilder if_eq_shapes_builder = + if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener()); + Value non_broadcast_op = + Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder); + if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op); + + // If shapes are not scalar, nor equal + // + // See if values are of a rank that we support. + OpBuilder if_neq_shapes_builder = + if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener()); + if_neq_shapes_builder.create<scf::YieldOp>( + loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs)); + + rewriter.replaceOp(op, {if_op.getResult(0)}); + return success(); + } + + private: + // Returns the dyanamic result of checking the given value is a scalar + // tensor. + Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const { + auto loc = op.getLoc(); + + Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor); + Value rank_tensor = rewriter.create<shape::RankOp>( + loc, rewriter.getIndexType(), shape_of_tensor); + return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq, + rank_tensor, + rewriter.create<ConstantIndexOp>(loc, 0)); + } + + // Create the if statement and code for a broadcasting op with a result of a + // given rank. + scf::IfOp createRankSpecializedBroadcastAndOp(OpBuilder &builder, ChloOpTy op, + Value lhs, Value rhs, + Value actual_rank, + int targeted_rank) const { + auto loc = op.getLoc(); + + // Create the if block to place the current specialized logic in. + Value greater_rank_is_n = builder.create<CmpIOp>( + loc, CmpIPredicate::eq, actual_rank, + builder.create<ConstantIndexOp>(loc, targeted_rank)); + auto if_op = + builder.create<scf::IfOp>(loc, lhs.getType(), greater_rank_is_n, true); + OpBuilder if_builder = if_op.getThenBodyBuilder(builder.getListener()); + + // Handle shape broadcasting and inferrence. + Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs); + Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs); + SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1); + auto unknown_rank_extent_tensor_type = RankedTensorType::get( + {RankedTensorType::kDynamicSize}, builder.getIndexType()); + auto known_rank_extent_tensor_type = + RankedTensorType::get({targeted_rank}, builder.getIndexType()); + auto reshaped_type = RankedTensorType::get( + llvm::SmallVector<int64_t, 6>(targeted_rank, + RankedTensorType::kDynamicSize), + lhs.getType().template dyn_cast<TensorType>().getElementType()); + Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>( + loc, known_rank_extent_tensor_type, + mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type, + ranked_shape)); + Value extended_lhs = if_builder.create<shape::BroadcastOp>( + loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val, + nullptr); + Value extended_lhs_casted = if_builder.create<TensorCastOp>( + loc, known_rank_extent_tensor_type, extended_lhs); + Value extended_rhs = if_builder.create<shape::BroadcastOp>( + loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val, + nullptr); + Value extended_rhs_casted = if_builder.create<TensorCastOp>( + loc, known_rank_extent_tensor_type, extended_rhs); + + // 1. Reshape operands to the given rank (with the same number of elements) + // 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops + // can be broadcasted and do the actual broadcasting) + // 3. Type erase the output back to unranked + Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>( + loc, reshaped_type, lhs, extended_lhs_casted); + Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>( + loc, reshaped_type, rhs, extended_rhs_casted); + Value result = if_builder.create<ChloOpTy>( + loc, ArrayRef<Type>{reshaped_type}, + ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs()); + Value reshaped_result = if_builder.create<TensorCastOp>( + loc, UnrankedTensorType::get(reshaped_type.getElementType()), result); + if_builder.create<scf::YieldOp>(loc, reshaped_result); + + // Return the if_op, so the result can be used and the else block can be + // used for the next rank specialized step. + return if_op; + } + + // Iterates over the desired ranks to be specialized and generates the code + // snippet for each case. + Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs, + Value rhs) const { + constexpr int max_rank_specialization = 7; + auto loc = op.getLoc(); + + // Find the larger rank of the 2 operands. + auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize}, + rewriter.getIndexType()); + Value lhs_shape = + rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs); + Value rhs_shape = + rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs); + Value lhs_rank = + rewriter.create<RankOp>(loc, rewriter.getIndexType(), lhs_shape); + Value rhs_rank = + rewriter.create<RankOp>(loc, rewriter.getIndexType(), rhs_shape); + Value greater_rank_lhs = + rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank); + Value greater_rank = + rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank); + + // Generate a list of nested if/else statements to handle rank + // specializations from 2-6. + scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs, + rhs, greater_rank, 2); + + // Put each subsequent rank specialization inside the else statement of the + // previous one. + OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener()); + for (int i = 3; i < max_rank_specialization; i++) { + auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs, + rhs, greater_rank, i); + + else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0)); + else_builder = inner_if.getElseBodyBuilder(rewriter.getListener()); + } + + // Fire an assertion if none of the rank specializations applied (one of the + // ranks was greater than 6). + else_builder.create<AssertOp>( + loc, else_builder.create<ConstantIntOp>(loc, 0, 1), + "Input for dynamic binary op lowering was of a rank greater than 6"); + else_builder.create<scf::YieldOp>(loc, lhs); + + // Return the result of the outermost if statement. + return if_op.getResult(0); + } +}; + struct TransformUnrankedHloPass : public PassWrapper<TransformUnrankedHloPass, FunctionPass> { void getDependentDialects(DialectRegistry ®istry) const override { @@ -137,7 +424,7 @@ struct TransformUnrankedHloPass MLIRContext &ctx = getContext(); ConversionTarget target(ctx); target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect, - shape::ShapeDialect>(); + shape::ShapeDialect, scf::SCFDialect>(); target.addLegalOp<FuncOp>(); #define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target) #define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target) @@ -148,6 +435,12 @@ struct TransformUnrankedHloPass #undef ADD_LEGAL_CHLO AddLegalOpOnRankedTensor<mhlo::CompareOp>(&target); AddLegalOpOnRankedTensor<mhlo::SelectOp>(&target); + target.addDynamicallyLegalDialect<chlo::HloClientDialect>( + [](Operation *op) { + return !llvm::any_of(op->getOperandTypes(), [](Type type) { + return type.isa<UnrankedTensorType>(); + }); + }); // Populate rewrite patterns. OwningRewritePatternList patterns; @@ -180,6 +473,10 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context, #undef MAP_BINARY #undef MAP_CHLO_UNARY #undef COMMA + chlo::PopulateForBroadcastingBinaryOp< + ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns); + chlo::PopulateForBroadcastingBinaryOp< + ConvertUnrankedScalarDynamicBroadcastBinaryOp>(context, patterns); } std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() { diff --git a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir index 60ec26f48a1..a83a29ff96a 100644 --- a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir @@ -237,209 +237,3 @@ func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x %0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> return %0 : tensor<4xi1> } - -// ----- -func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>) - -> tensor<*xf32> - return %0 : tensor<*xf32> -} - -// CHECK-LABEL: func @addScalarUnranked( -// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>, -// CHECK-SAME: %[[ARG_1:.*]]: tensor<*xf32> -// CHECK-SAME: ) -> tensor<*xf32> { -// First handle the dynamic reshaping of the unranked operand -// to a 1D tensor. -// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32> -// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index -// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> -// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> -// The assuming region is part of the second stage of lowering -// with ranked broadcasting logic. -// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<f32> -// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32> -// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_0]], %[[SHAPE_RESHAPED]] -// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) { -// CHECK: %[[SCALAR_SHAPE:.*]] = shape.const_shape [] -// CHECK: %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]] -// CHECK: %[[SHAPE_TENSOR:.*]] = tensor_cast %[[BROADCASTED_SHAPE]] : tensor<?xindex> to tensor<1xindex> -// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32> -// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32> -// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32> -// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32> -// CHECK: } -// As part of the unranked logic, the result is reshaped back -// to an unranked tensor. -// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> -// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32> -// CHECK: } - -// ----- -func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> { - %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>) - -> tensor<*xf32> - return %0 : tensor<*xf32> -} -// CHECK-LABEL: func @addUnrankedScalar( -// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>, -// CHECK-SAME: %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> { -// First handle the dynamic reshaping of the unranked operand -// to a 1D tensor. -// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> -// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index -// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> -// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> -// The assuming region is part of the second stage of lowering -// with ranked broadcasting logic. -// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32> -// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<f32> -// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]] -// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) { -// CHECK: %[[ASTENSOR:.*]] = tensor_cast %[[SHAPE_RESHAPED]] -// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[ASTENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32> -// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[ASTENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32> -// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32> -// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32> -// CHECK: } -// As part of the unranked logic, the result is reshaped back -// to an unranked tensor. -// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_0]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> -// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32> -// CHECK: } - -// ----- -func @addUnrankedUnranked( - %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>) - -> tensor<*xf32> - return %0 : tensor<*xf32> -} - -// CHECK-LABEL: func @addUnrankedUnranked( -// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>, -// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex> -// CHECK: %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index -// CHECK: %[[C0:.*]] = constant 0 : index -// CHECK: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index -// Handle scalar LHS case -// CHECK: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) { -// CHECK: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32> -// CHECK: %[[VAL_10:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RHS]] : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32> -// CHECK: scf.yield %[[VAL_10]] : tensor<*xf32> -// CHECK: } else { -// CHECK: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex> -// CHECK: %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index -// CHECK: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index - // Handle scalar RHS case -// CHECK: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) { -// CHECK: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32> -// CHECK: %[[VAL_16:.*]] = chlo.broadcast_add %[[LHS]], %[[SCALAR_RHS]] : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32> -// CHECK: scf.yield %[[VAL_16]] : tensor<*xf32> -// CHECK: } else { -// CHECK: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> - // Handle scalar RHS case -// CHECK: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) { -// CHECK: %[[VAL_19:.*]] = mhlo.add %[[LHS]], %[[RHS]] : tensor<*xf32> -// CHECK: scf.yield %[[VAL_19]] : tensor<*xf32> -// CHECK: } else { -// CHECK: %[[LHS_RANK:.*]] = rank %[[LHS_SHAPE]] : tensor<?xindex> -// CHECK: %[[RHS_RANK:.*]] = rank %[[RHS_SHAPE]] : tensor<?xindex> -// CHECK: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index -// CHECK: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index -// CHECK: %[[C2:.*]] = constant 2 : index -// CHECK: %[[GREATEST_RANK_IS_2:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C2]] : index -// Handle rank 2 specialization -// CHECK: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) { -// CHECK: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1] -// CHECK: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex> -// CHECK: %[[CASTED_LHS_2:.*]] = tensor_cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex> -// CHECK: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex> -// CHECK: %[[CASTED_RHS_2:.*]] = tensor_cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex> -// CHECK: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32> -// CHECK: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32> -// CHECK: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> -// CHECK: %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32> -// CHECK: scf.yield %[[RESULT_2]] : tensor<*xf32> -// CHECK: } else { -// CHECK: %[[C3:.*]] = constant 3 : index -// CHECK: %[[GREATEST_RANK_IS_3:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C3]] : index -// Handle rank 3 specialization -// CHECK: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) { -// CHECK: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1] -// CHECK: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex> -// CHECK: %[[CASTED_LHS_3:.*]] = tensor_cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex> -// CHECK: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex> -// CHECK: %[[CASTED_RHS_3:.*]] = tensor_cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex> -// CHECK: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32> -// CHECK: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32> -// CHECK: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32> -// CHECK: %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32> -// CHECK: scf.yield %[[RESULT_3]] : tensor<*xf32> -// CHECK: } else { -// CHECK: %[[C4:.*]] = constant 4 : index -// CHECK: %[[GREATEST_RANK_IS_4:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C4]] : index -// Handle rank 4 specialization -// CHECK: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) { -// CHECK: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1] -// CHECK: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex> -// CHECK: %[[CASTED_LHS_4:.*]] = tensor_cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex> -// CHECK: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex> -// CHECK: %[[CASTED_RHS_4:.*]] = tensor_cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex> -// CHECK: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> -// CHECK: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> -// CHECK: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> -// CHECK: %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32> -// CHECK: scf.yield %[[RESULT_4]] : tensor<*xf32> -// CHECK: } else { -// CHECK: %[[C5:.*]] = constant 5 : index -// CHECK: %[[GREATEST_RANK_IS_5:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C5]] : index -// Handle rank 5 specialization -// CHECK: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) { -// CHECK: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1] -// CHECK: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex> -// CHECK: %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex> -// CHECK: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex> -// CHECK: %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex> -// CHECK: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32> -// CHECK: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32> -// CHECK: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> -// CHECK: %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32> -// CHECK: scf.yield %[[RESULT_5]] : tensor<*xf32> -// CHECK: } else { -// CHECK: %[[C6:.*]] = constant 6 : index -// CHECK: %[[GREATEST_RANK_IS_6:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index -// Handle rank 6 specialization -// CHECK: %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) { -// CHECK: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1] -// CHECK: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex> -// CHECK: %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex> -// CHECK: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex> -// CHECK: %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex> -// CHECK: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32> -// CHECK: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32> -// CHECK: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32> -// CHECK: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32> -// CHECK: scf.yield %[[RESULT_6]] : tensor<*xf32> -// CHECK: } else { -// CHECK: %false = constant false -// CHECK: assert %false -// CHECK: scf.yield %[[LHS]] : tensor<*xf32> -// CHECK: } -// CHECK: scf.yield %[[VAL_64:.*]] : tensor<*xf32> -// CHECK: } -// CHECK: scf.yield %[[VAL_65:.*]] : tensor<*xf32> -// CHECK: } -// CHECK: scf.yield %[[VAL_66:.*]] : tensor<*xf32> -// CHECK: } -// CHECK: scf.yield %[[VAL_67:.*]] : tensor<*xf32> -// CHECK: } -// CHECK: scf.yield %[[VAL_68:.*]] : tensor<*xf32> -// CHECK: } -// CHECK: scf.yield %[[VAL_69:.*]] : tensor<*xf32> -// CHECK: } -// CHECK: scf.yield %[[VAL_70:.*]] : tensor<*xf32> -// CHECK: } -// CHECK: return %[[VAL_71:.*]] : tensor<*xf32> -// CHECK: } diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir index ae61fc8477e..c0baecdd954 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt --transform-unranked-hlo --split-input-file %s | FileCheck %s +// RUN: mlir-hlo-opt --transform-unranked-hlo --cse --split-input-file %s | FileCheck %s // Check the validity of expected IR. // CHECK-LABEL: @sqr_transform_result @@ -96,3 +96,203 @@ func @tan(%a : tensor<*xf32>) -> tensor<*xf32> { %result = chlo.tan %a : tensor<*xf32> return %result : tensor<*xf32> } + +// ----- + +func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>) + -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// CHECK-LABEL: func @addScalarUnranked( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<*xf32> +// CHECK-SAME: ) -> tensor<*xf32> { +// First handle the dynamic reshaping of the unranked operand +// to a 1D tensor. +// CHECK-NEXT: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32> +// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index +// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> +// CHECK-NEXT: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> +// CHECK-NEXT: %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[ARG_0]], %[[RESHAPED]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32> +// As part of the unranked logic, the result is reshaped back +// to an unranked tensor. +// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[BROADCASTED_RESULT:.*]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> +// CHECK-NEXT: return %[[RESHAPED_RESULT]] : tensor<*xf32> +// CHECK-NEXT: } + +// ----- +func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> { + %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>) + -> tensor<*xf32> + return %0 : tensor<*xf32> +} +// CHECK-LABEL: func @addUnrankedScalar( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> { +// First handle the dynamic reshaping of the unranked operand +// to a 1D tensor. +// CHECK-NEXT: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> +// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index +// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> +// CHECK-NEXT: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> +// The assuming region is part of the second stage of lowering +// with ranked broadcasting logic. +// CHECK-NEXT: %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED]], %[[ARG_1]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32> +// As part of the unranked logic, the result is reshaped back +// to an unranked tensor. +// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[BROADCASTED_RESULT:.*]], %[[SHAPE_0]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> +// CHECK-NEXT: return %[[RESHAPED_RESULT]] : tensor<*xf32> +// CHECK-NEXT: } + +// ----- +func @addUnrankedUnranked( + %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>) + -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// CHECK-LABEL: func @addUnrankedUnranked( +// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>, +// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex> +// CHECK-NEXT: %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index +// CHECK-NEXT: %[[C0:.*]] = constant 0 : index +// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index +// Handle scalar LHS case +// CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) { +// CHECK-NEXT: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32> +// CHECK-NEXT: %[[RHS_SHAPE_1:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex> +// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE_1]] : tensor<?xindex> -> index +// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor_from_elements %[[NUM_RHS]] : tensor<1xindex> +// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> +// CHECK-NEXT: %[[LHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RESHAPED_RHS]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32> +// CHECK-NEXT: %[[RESHAPED_LHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[LHS_SCALAR_RESULT]], %[[RHS_SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> +// CHECK-NEXT: scf.yield %[[RESHAPED_LHS_SCALAR_RESULT]] : tensor<*xf32> +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex> +// CHECK-NEXT: %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index +// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index +// Handle scalar RHS case +// CHECK-NEXT: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) { +// CHECK-NEXT: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32> +// CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index +// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor_from_elements %[[NUM_LHS]] : tensor<1xindex> +// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> +// CHECK-NEXT: %[[RHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[SCALAR_RHS]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32> +// CHECK-NEXT: %[[RESHAPED_RHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RHS_SCALAR_RESULT:.*]], %[[LHS_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> +// CHECK-NEXT: scf.yield %[[RESHAPED_RHS_SCALAR_RESULT]] : tensor<*xf32> +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> +// Handle equal shapes case +// CHECK-NEXT: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) { +// CHECK-NEXT: %[[ANY_SHAPE:.*]] = shape.any %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex> +// CHECK-NEXT: %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor<?xindex> -> index +// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor_from_elements %[[ANY_NUM]] : tensor<1xindex> +// CHECK-NEXT: %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> +// CHECK-NEXT: %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> +// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = mhlo.add %[[FLATTENED_LHS]], %[[FLATTENED_RHS]] : tensor<?xf32> +// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> +// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32> +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[LHS_RANK:.*]] = rank %[[LHS_SHAPE]] : tensor<?xindex> +// CHECK-NEXT: %[[RHS_RANK:.*]] = rank %[[RHS_SHAPE]] : tensor<?xindex> +// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index +// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index +// CHECK-NEXT: %[[C2:.*]] = constant 2 : index +// CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C2]] : index +// Handle rank 2 specialization +// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) { +// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1] +// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex> +// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor_cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex> +// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex> +// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor_cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex> +// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32> +// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32> +// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> +// CHECK-NEXT: %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32> +// CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32> +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[C3:.*]] = constant 3 : index +// CHECK-NEXT: %[[GREATEST_RANK_IS_3:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C3]] : index +// Handle rank 3 specialization +// CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) { +// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1] +// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex> +// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor_cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex> +// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex> +// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor_cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex> +// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32> +// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32> +// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32> +// CHECK-NEXT: %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32> +// CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32> +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[C4:.*]] = constant 4 : index +// CHECK-NEXT: %[[GREATEST_RANK_IS_4:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C4]] : index +// Handle rank 4 specialization +// CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) { +// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1] +// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex> +// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor_cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex> +// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex> +// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor_cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex> +// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> +// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> +// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> +// CHECK-NEXT: %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32> +// CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32> +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[C5:.*]] = constant 5 : index +// CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C5]] : index +// Handle rank 5 specialization +// CHECK-NEXT: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) { +// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1] +// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex> +// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex> +// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex> +// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex> +// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32> +// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32> +// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> +// CHECK-NEXT: %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32> +// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32> +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[C6:.*]] = constant 6 : index +// CHECK-NEXT: %[[GREATEST_RANK_IS_6:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index +// Handle rank 6 specialization +// CHECK-NEXT: %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) { +// CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1] +// CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex> +// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex> +// CHECK-NEXT: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex> +// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex> +// CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32> +// CHECK-NEXT: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32> +// CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32> +// CHECK-NEXT: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32> +// CHECK-NEXT: scf.yield %[[RESULT_6]] : tensor<*xf32> +// CHECK-NEXT: } else { +// CHECK-NEXT: %false = constant false +// CHECK-NEXT: assert %false +// CHECK-NEXT: scf.yield %[[LHS]] : tensor<*xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[VAL_64:.*]] : tensor<*xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[VAL_65:.*]] : tensor<*xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[VAL_67:.*]] : tensor<*xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[VAL_69:.*]] : tensor<*xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[VAL_70:.*]] : tensor<*xf32> +// CHECK-NEXT: } +// CHECK-NEXT: return %[[VAL_71:.*]] : tensor<*xf32> +// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 861048a9a93..d593c0ec836 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -306,6 +306,14 @@ inline bool IsF32ShapedType(Type t) { return false; } +// Returns true if it is a shaped type of bf16 elements. +inline bool IsBF16ShapedType(Type t) { + if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) { + return shaped_type.getElementType().isBF16(); + } + return false; +} + // Performs const folding `calculate` with broadcast behavior on the two // attributes `operand1` and `operand2` and returns the result if possible. // The two operands are expected to both be scalar values. @@ -498,7 +506,7 @@ Attribute ConstFoldBinaryOp( /// "tfl.logical_not". Attribute ConstFoldUnaryOp(Type result_type, Attribute operand, llvm::function_ref<APFloat(APFloat)> calculate) { - assert(IsF32ShapedType(result_type)); + assert(IsF32ShapedType(result_type) || IsBF16ShapedType(result_type)); auto result_shape_type = result_type.cast<ShapedType>(); if (auto dense_elements = operand.dyn_cast_or_null<DenseElementsAttr>()) { @@ -1911,13 +1919,20 @@ OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) { OpFoldResult RsqrtOp::fold(ArrayRef<Attribute> operands) { Type result_type = getType(); - // Only constant fold for tensor of f32 is implemented. - if (!IsF32ShapedType(result_type)) return nullptr; + // Only constant fold for tensor of f32/bf16 is implemented. + if (!IsF32ShapedType(result_type) && !IsBF16ShapedType(result_type)) + return nullptr; auto compute = [](APFloat value) -> APFloat { + bool loseInfo; + const llvm::fltSemantics &original_float_semantics = value.getSemantics(); + value.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, + &loseInfo); float f = value.convertToFloat(); - float result = 1.f / std::sqrt(f); - return APFloat(result); + APFloat result(1.f / std::sqrt(f)); + result.convert(original_float_semantics, APFloat::rmNearestTiesToEven, + &loseInfo); + return result; }; return ConstFoldUnaryOp(result_type, operands[0], compute); } diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir index ff7c47fb621..69009ae594b 100644 --- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir +++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir @@ -577,3 +577,13 @@ func @div_dense_different_rank() -> tensor<1x2x2xf32> { // CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32> // CHECK: return %[[CST]] } + +// CHECK-LABEL: @rsqrt_bf16 +func @rsqrt_bf16() -> tensor<bf16> { + %cst = constant dense<4.0> : tensor<bf16> + %0 = "tfl.rsqrt"(%cst) : (tensor<bf16>) -> tensor<bf16> + return %0 : tensor<bf16> + +// CHECK: %[[CST:.*]] = constant dense<5.000000e-01> : tensor<bf16> +// CHECK: return %[[CST]] +} diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index e09e63219e5..336a4653748 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -1358,3 +1358,27 @@ func @fuseScalarAddIntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3 // CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf32> // CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst) } + +// CHECK-LABEL: fuseScalarAddIntoConv2dBf16 +func @fuseScalarAddIntoConv2dBf16(%arg0: tensor<256x32x32x3xbf16>, %arg1: tensor<16x3x3x3xbf16>) -> tensor<256x30x30x16xbf16> { + %cst = constant dense<1.5> : tensor<bf16> + %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xbf16> + %0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xbf16>, tensor<16x3x3x3xbf16>, tensor<16xbf16>) -> tensor<256x30x30x16xbf16> + %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xbf16>, tensor<bf16>) -> tensor<256x30x30x16xbf16> + return %1 : tensor<256x30x30x16xbf16> + + // CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xbf16> + // CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst) +} + +// CHECK-LABEL: fuseScalarAddIntoConv2dHalf +func @fuseScalarAddIntoConv2dHalf(%arg0: tensor<256x32x32x3xf16>, %arg1: tensor<16x3x3x3xf16>) -> tensor<256x30x30x16xf16> { + %cst = constant dense<1.5> : tensor<f16> + %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf16> + %0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf16>, tensor<16x3x3x3xf16>, tensor<16xf16>) -> tensor<256x30x30x16xf16> + %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf16>, tensor<f16>) -> tensor<256x30x30x16xf16> + return %1 : tensor<256x30x30x16xf16> + + // CHECK: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> : tensor<16xf16> + // CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst) +} diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 5e4a986a885..f667749f69d 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -211,20 +211,21 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, pass_manager->addPass(mlir::createSymbolDCEPass()); pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass()); pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass()); - // This pass should be always at the end of the floating point model - // conversion. Some TFL ops like unidirectional - // sequence lstm will have stateful operands and some optimization passes - // will merge those operands if they have identical values & types. However, - // it's not desired by TFL. This pass serves as a "fix" pass to split the - // merged inputs until we have 1st class variable support or reuse - // tf.variable to model this. - pass_manager->addPass(mlir::TFL::CreateSplitMergedOperandsPass()); // Run quantization after all the floating point model conversion is // completed. if (pass_config.quant_specs.RunPropagationAndRewriteQuantizationPasses()) { AddQuantizationPasses(pass_config.quant_specs, pass_manager); } + + // This pass should be always at the end of the model + // conversion (even after quantization). Some TFL ops like unidirectional + // sequence lstm will have stateful operands and some optimization passes + // will merge those operands if they have identical values & types. However, + // it's not desired by TFL. This pass serves as a "fix" pass to split the + // merged inputs until we have 1st class variable support or reuse + // tf.variable to model this. + pass_manager->addPass(mlir::TFL::CreateSplitMergedOperandsPass()); } } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index fa266c5e44e..ff750590a19 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -24,6 +24,11 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" // Checks if the param passed is a F32 ElementsAttr. def F32ElementsAttr : ElementsAttrBase< CPred<"$_self.isa<ElementsAttr>() && $_self.cast<ElementsAttr>().getType().getElementType().isF32()">, + "32 bit float constant tensor">; + +// Checks if the param passed is a float ElementsAttr. +def FloatElementsAttr : ElementsAttrBase< + CPred<"$_self.isa<ElementsAttr>() && $_self.cast<ElementsAttr>().getType().getElementType().isa<FloatType>()">, "float constant tensor">; // Checks if the param passed is of NoneType. @@ -93,9 +98,9 @@ class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint< multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> { def FuseBinaryOpWithConv#binaryOp : Pat< (binaryOp (TFL_Conv2DOp:$output $input, $filter, - (ConstantOp F32ElementsAttr:$bias), $h_factor, $w_factor, + (ConstantOp FloatElementsAttr:$bias), $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h, $stride_w), - (ConstantOp F32ElementsAttr:$value), $act_fn), + (ConstantOp FloatElementsAttr:$value), $act_fn), (TFL_Conv2DOp $input, $filter, (binaryOp (ConstantOp $bias), (ConstantOp $value), TFL_AF_None), @@ -104,10 +109,10 @@ multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> { (HasOneUse $output)]>; def FuseBinaryOpWithDepthwiseConv#binaryOp : Pat< (binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter, - (ConstantOp F32ElementsAttr:$bias), + (ConstantOp FloatElementsAttr:$bias), $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h, $stride_w, $multiplier), - (ConstantOp F32ElementsAttr:$value), $act_fn), + (ConstantOp FloatElementsAttr:$value), $act_fn), (TFL_DepthwiseConv2DOp $input, $filter, (binaryOp (ConstantOp $bias), (ConstantOp $value), TFL_AF_None), $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w, @@ -116,9 +121,9 @@ multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> { (HasOneUse $output)]>; def FuseBinaryOpWithTransposeConv#binaryOp : Pat< (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs, - (ConstantOp F32ElementsAttr:$bias), $padding, + (ConstantOp FloatElementsAttr:$bias), $padding, $stride_h, $stride_w), - (ConstantOp F32ElementsAttr:$value), TFL_AF_None), + (ConstantOp FloatElementsAttr:$value), TFL_AF_None), (TFL_TransposeConvOp $output_shape, $weights, $inputs, (binaryOp (ConstantOp $bias), (ConstantOp $value), TFL_AF_None), @@ -130,7 +135,7 @@ multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> { (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs, (ConstantOp $bias), $padding, $stride_h, $stride_w), - (ConstantOp F32ElementsAttr:$value), TFL_AF_None), + (ConstantOp FloatElementsAttr:$value), TFL_AF_None), (TFL_TransposeConvOp $output_shape, $weights, $inputs, (ConstantOp $value), $padding, $stride_h, $stride_w), @@ -155,11 +160,11 @@ def ExpandTo4DForDepthwiseConv: NativeCodeCall< multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> { def FuseMulOrDivWithDepthwiseConv#BinaryOp : Pat< (BinaryOp (TFL_DepthwiseConv2DOp:$output $input, - (ConstantOp F32ElementsAttr:$filter), - (ConstantOp F32ElementsAttr:$bias), + (ConstantOp FloatElementsAttr:$filter), + (ConstantOp FloatElementsAttr:$bias), $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h, $stride_w, $multiplier), - (ConstantOp F32ElementsAttr:$value), $act_fn), + (ConstantOp FloatElementsAttr:$value), $act_fn), (TFL_DepthwiseConv2DOp $input, (BinaryOp (ConstantOp $filter), @@ -175,11 +180,11 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> { (HasOneUse $output)]>; def FuseMulOrDivWithConv#BinaryOp : Pat< (BinaryOp (TFL_Conv2DOp:$conv_output $input, - (ConstantOp F32ElementsAttr:$filter), - (ConstantOp F32ElementsAttr:$bias), + (ConstantOp FloatElementsAttr:$filter), + (ConstantOp FloatElementsAttr:$bias), $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h, $stride_w), - (ConstantOp F32ElementsAttr:$value), $act_fn), + (ConstantOp FloatElementsAttr:$value), $act_fn), (TFL_Conv2DOp $input, (BinaryOp (ConstantOp $filter), (ConstantOp (ExpandTo4DForConv $value)), @@ -192,8 +197,8 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> { (HasOneUse $conv_output)]>; def FuseMulOrDivWithTransposeConv#BinaryOp : Pat< (BinaryOp (TFL_TransposeConvOp:$output $output_shape, - (ConstantOp F32ElementsAttr:$weights), $input, - (ConstantOp F32ElementsAttr:$bias), + (ConstantOp FloatElementsAttr:$weights), $input, + (ConstantOp FloatElementsAttr:$bias), $padding, $stride_h, $stride_w), (ConstantOp $value), TFL_AF_None), (TFL_TransposeConvOp $output_shape, @@ -209,7 +214,7 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> { (HasOneUse $output)]>; def FuseMulOrDivWithTransposeConvWithNoneBias#BinaryOp : Pat< (BinaryOp (TFL_TransposeConvOp:$output $output_shape, - (ConstantOp F32ElementsAttr:$weights), $input, + (ConstantOp FloatElementsAttr:$weights), $input, (ConstantOp $bias), $padding, $stride_h, $stride_w), (ConstantOp $value), TFL_AF_None), diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 1530ddd2f1d..af7dba56daa 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -1651,10 +1651,12 @@ Mutually reduces multiple tensors of identical type and shape. TF_Int32Tensor:$group_size, TF_Int32Tensor:$group_key, TF_Int32Tensor:$instance_key, + Variadic<TF_ResourceTensor>:$ordering_token, TF_AnyStrAttrOf<["Min", "Max", "Mul", "Add"]>:$merge_op, TF_AnyStrAttrOf<["Id", "Div"]>:$final_op, - DefaultValuedAttr<StrAttr, "auto">:$communication_hint + DefaultValuedAttr<StrAttr, "auto">:$communication_hint, + DefaultValuedAttr<F32Attr, "0.0f">:$timeout_seconds ); let results = (outs @@ -1662,6 +1664,7 @@ Mutually reduces multiple tensors of identical type and shape. ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandSizeAttr Nordering_token = TF_DerivedOperandSizeAttr<4>; } def TF_ComplexOp : TF_Op<"Complex", [NoSideEffect, ResultsBroadcastableShape]> { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 635346af08c..e135a1a9854 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -77,66 +77,6 @@ namespace TF { namespace { -// Returns true of the given function has a single uses (within the scope -// of the module containing it and all parent modules). -bool HasSingleUse(FuncOp func) { - // Public function can have any number of external uses. - if (func.isPublic()) return false; - - // Return false if unexpected IR structure seen. - ModuleOp module = func.getParentOfType<ModuleOp>(); - if (!module) return false; - - // Inspect function uses in the containing module and all parent - // modules. - bool use_seen = false; - for (; module; module = func.isPrivate() - ? nullptr - : module.getParentOfType<ModuleOp>()) { - auto func_uses_optional = - SymbolTable::getSymbolUses(func, &module.getBodyRegion()); - // Found an unknown use. - if (!func_uses_optional) return false; - - // If no uses in this scope, continue looking in parent module - SymbolTable::UseRange func_uses = func_uses_optional.getValue(); - if (func_uses.empty()) continue; - - // Check if multiple uses at this scope or another use already seen. - if (!llvm::hasSingleElement(func_uses) || use_seen) return false; - - // This is the first use seen. - use_seen = true; - } - - // No multiple uses seen. - return true; -} - -// Returns true if the caller ops can be inlined. -bool HasInlinableUsers(FuncOp func) { - // Return false if unexpected IR structure seen. - ModuleOp module = func.getParentOfType<ModuleOp>(); - if (!module) return false; - - // Inspect function uses in the containing module and all parent - // modules. - for (; module; module = func.isPrivate() - ? nullptr - : module.getParentOfType<ModuleOp>()) { - auto func_uses_optional = - SymbolTable::getSymbolUses(func, &module.getBodyRegion()); - // Found an unknown use. - if (!func_uses_optional) return false; - - for (auto &use : func_uses_optional.getValue()) - if (isa<TPUPartitionedCallOp>(use.getUser())) return false; - } - - // All caller ops that can be inlined. - return true; -} - struct TFConstantFoldInterface : public DialectFoldInterface { TFConstantFoldInterface(Dialect *dialect) : DialectFoldInterface(dialect) {} LogicalResult fold(Operation *op, ArrayRef<Attribute> operands, @@ -160,10 +100,12 @@ struct TFInlinerInterface : public DialectInlinerInterface { // Analysis Hooks //===--------------------------------------------------------------------===// - // Allow all call operations to be inlined. + // Returns if it's legal to inline 'callable' into the 'call', where 'call' is + // a TF operation. bool isLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned) const final { - return true; + // Check that the TF call operation is one that is legal to inline. + return !isa<TPUPartitionedCallOp>(call); } // Returns if its legal to inline 'src' region into the 'dest' region @@ -186,10 +128,7 @@ struct TFInlinerInterface : public DialectInlinerInterface { // post inlining, the function will be dead and eliminated from the IR. // So there won't be any code duplication. // plus the function caller op can be replaced by inlined ops. - FuncOp func = op->getParentOfType<FuncOp>(); - if (!func) return true; - if (!HasInlinableUsers(func)) return false; - return TensorFlowDialect::CanDuplicate(op) || HasSingleUse(func); + return !wouldBeCloned || TensorFlowDialect::CanDuplicate(op); } //===--------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir index 487234ce958..eab6e8be986 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir @@ -361,8 +361,7 @@ func @send_recv(%arg0: tensor<2x!tf.string>) { // ----- // Tests functional control flow functions with replica variant ops reachable -// from a replicate region is cloned and remapped. Only the branches reachable -// with replica variant ops are cloned. +// from a replicate region is cloned and remapped. // CHECK-LABEL: func @control_flow_with_replicate_variant_ops func @control_flow_with_replicate_variant_ops(%arg0: tensor<i1>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<2x!tf.string>) { @@ -380,30 +379,32 @@ func @control_flow_with_replicate_variant_ops(%arg0: tensor<i1>, %arg1: tensor<f } // CHECK: "tf.If" -// CHECK-SAME: else_branch = @cond_false +// CHECK-SAME: else_branch = [[COND_FALSE_REPLICA_0:@[a-z0-9_]+]] // CHECK-SAME: then_branch = [[COND_TRUE_REPLICA_0:@[a-z0-9_]+]] // CHECK: "tf.If" -// CHECK-SAME: else_branch = @cond_false +// CHECK-SAME: else_branch = [[COND_FALSE_REPLICA_1:@[a-z0-9_]+]] // CHECK-SAME: then_branch = [[COND_TRUE_REPLICA_1:@[a-z0-9_]+]] func @cond_false(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<2x!tf.string>) -> tensor<f32> { return %arg0 : tensor<f32> } -// CHECK-NOT: func @cond_false.+( - func @cond_true(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<2x!tf.string>) -> tensor<f32> { "tf._XlaSendFromHost"(%arg1, %arg2) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_recv_0"} : (tensor<f32>, tensor<2x!tf.string>) -> () %0 = "tf._XlaRecvAtHost"(%arg2) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_send_0"} : (tensor<2x!tf.string>) -> tensor<f32> return %0 : tensor<f32> } +// CHECK: func [[COND_FALSE_REPLICA_0]] + // CHECK: func [[COND_TRUE_REPLICA_0]] // CHECK: "tf._XlaSendFromHost" // CHECK-SAME: device_ordinal = 1 // CHECK: "tf._XlaRecvAtHost" // CHECK-SAME: device_ordinal = 1 +// CHECK: func [[COND_FALSE_REPLICA_1]] + // CHECK: func [[COND_TRUE_REPLICA_1]] // CHECK: "tf._XlaSendFromHost" // CHECK-SAME: device_ordinal = 2 @@ -413,7 +414,7 @@ func @cond_true(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<2x!tf.stri // ----- // Tests function with no replica variant ops reachable from a replicate region -// is not cloned. +// is cloned. // CHECK-LABEL: func @no_replicate_variant_ops func @no_replicate_variant_ops(%arg0: tensor<f32>, %arg1: tensor<2x!tf.string>) { @@ -431,11 +432,17 @@ func @no_replicate_variant_ops(%arg0: tensor<f32>, %arg1: tensor<2x!tf.string>) } // CHECK: "tf.StatefulPartitionedCall" -// CHECK-SAME: f = @send_recv +// CHECK-SAME: f = [[CALLEE_REPLICA_0:@[a-z0-9_]+]] +// CHECK: "tf.StatefulPartitionedCall" +// CHECK-SAME: f = [[CALLEE_REPLICA_1:@[a-z0-9_]+]] func @send_recv(%arg0: tensor<2x!tf.string>) { "tf.NoOp"() : () -> () return } -// CHECK-NOT: @send_recv.+( +// CHECK: func [[CALLEE_REPLICA_0]] +// CHECK: "tf.NoOp" + +// CHECK: func [[CALLEE_REPLICA_1]] +// CHECK: "tf.NoOp" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index 67832d74af6..51465a4b3b2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -53,7 +53,6 @@ void AddGraphExportLoweringPasses(OpPassManager &pm) { add_pass(TFDevice::CreateParallelExecuteToIslandsPass()); pm.addPass(TFDevice::CreateLaunchToDeviceAttributePass()); pm.addPass(CreateBreakUpIslandsPass()); - pm.addNestedPass<FuncOp>(CreateTPUDevicePropagationPass()); pm.addPass(createSymbolDCEPass()); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc index 5b70729ee80..88ba1cee1b8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc @@ -120,46 +120,6 @@ llvm::SmallPtrSet<FuncOp, 4> GetReachableFunctionsFromRegion(ModuleOp module, return visited_functions; } -// Collects all functions and transitive functions reachable from region that -// contain replicate variant ops. -llvm::SmallDenseMap<llvm::StringRef, FuncOp> GetReachableFunctionsToClone( - ModuleOp module, Region& region, - const llvm::Optional<DictionaryAttr>& devices) { - llvm::SmallPtrSet<FuncOp, 4> reachable_functions = - GetReachableFunctionsFromRegion(module, region); - - llvm::SmallDenseMap<llvm::StringRef, FuncOp> functions_to_clone; - llvm::SmallVector<FuncOp, 4> functions_to_visit; - for (FuncOp func : reachable_functions) { - if (!func.getCallableRegion()) continue; - if (HasReplicaVariantOps(*func.getCallableRegion(), devices)) { - functions_to_clone.insert({func.getName(), func}); - functions_to_visit.push_back(func); - } - } - - while (!functions_to_visit.empty()) { - llvm::SmallVector<FuncOp, 4> new_functions_to_visit; - - for (FuncOp func_to_visit : functions_to_visit) { - auto func_uses = func_to_visit.getSymbolUses(module); - if (!func_uses) continue; - for (auto use : *func_uses) { - auto parent_func = use.getUser()->getParentOfType<FuncOp>(); - if (!parent_func || !reachable_functions.contains(parent_func) || - !functions_to_clone.insert({parent_func.getName(), parent_func}) - .second) - continue; - new_functions_to_visit.push_back(parent_func); - } - } - - functions_to_visit.swap(new_functions_to_visit); - } - - return functions_to_clone; -} - struct FuncOldNameAndClone { StringRef old_name; FuncOp clone; @@ -276,20 +236,19 @@ LogicalResult ExpandReplicateIntoReplicas( terminator.erase(); auto funcs_to_clone = - GetReachableFunctionsToClone(module, replicate_op.body(), devices); + GetReachableFunctionsFromRegion(module, replicate_op.body()); SymbolTable symbol_table(module); builder.setInsertionPoint(island_op); BlockAndValueMapping mapping; for (int i : llvm::seq<int>(0, num_replicas)) { - // Clone reachable functions with replica variant ops. + // Clone reachable functions from region. llvm::SmallVector<FuncOldNameAndClone, 4> cloned_functions; cloned_functions.reserve(funcs_to_clone.size()); - for (auto& func_to_clone : funcs_to_clone) { - auto cloned_function = func_to_clone.getSecond().clone(); + for (FuncOp func_to_clone : funcs_to_clone) { + auto cloned_function = func_to_clone.clone(); symbol_table.insert(cloned_function, module.end()); - cloned_functions.push_back( - {func_to_clone.getSecond().getName(), cloned_function}); + cloned_functions.push_back({func_to_clone.getName(), cloned_function}); } // Create new island for replica. diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc index e5408cef828..ee0241945f6 100644 --- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc @@ -31,7 +31,6 @@ int main(int argc, char **argv) { mlir::registerAllPasses(); mlir::mhlo::registerAllMhloPasses(); mlir::lmhlo::registerAllLmhloPasses(); - mlir::mhlo::registerAllMhloPasses(); mlir::DialectRegistry registry; mlir::registerAllDialects(registry); diff --git a/tensorflow/compiler/mlir/tfr/README.md b/tensorflow/compiler/mlir/tfr/README.md new file mode 100644 index 00000000000..aa45ba02e94 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/README.md @@ -0,0 +1,158 @@ +# Composable Tensorflow + +## Composable Tensorflow + +Composable TensorFlow (TF) is the framework for defining portable TF ops with +composition in the authoring language. + +The set of standard TF ops is currently open. New ops are defined for special +purposes but it is hard to make them work end-to-end: The op +needs to be handled separately by a several backends (tf2xla bridge, tflite +converter, CPU kernels, etc.). Writing shape functions and gradients for these +ops is extremely difficult. `tf.function` makes some parts of the implementation +simpler, but it introduces runtime overhead and it cannot easily be used to +apply dedicated optimizations to op kernels. + +The composable TF framework allows the user to define portable TF ops as +ompositions of other TF ops. It translates a Python function used to define the +composition directly into a portable IR at build time, and uses it to expand the +composite op in the TF program during compilation / execution. By using this +expansion mechanism, new op are readily available on different platforms without +extra work. Moreover, since the expansion is optional, the backend can easily +treat it as a monolithic op when needed, for instance to apply optimizations or +associate it with a custom kernel. + +### Benefits + +Using the Composable TF API to define a new op and its composition can bring the +following benefits: + +* *Automatic backend support*: As long as it is composed of ops supported by the +backend, the new op is automatcally supported (as a `tf.function` alternative); +* *Reduced tracing overhead*: Unlike `tf.function`, the composition function is +compiled at build time, hence TF only needs to trace a single op to build the +`graph`; +* *Easy fused op/kernel optimization*: Even if it has complex +semantics, the new op is presented as a single node in the graph, thus +optimization passes and kernels can easily be specialized to this op for better +performance. +* *Automatic shape/type inference support*: No shape functions are required for +the new op; +* *Automatic gradient support (WIP)*: The user doesn't need to author +gradient a function of the op for training. + +### Use Cases + +* (Portablity) User wants to add a new op and run this op on different +platforms (CPU, TPU, TFLite, etc.) to be portable. + * *Solution*: The user should define the new op as a composition. The ops used + inside the composition should have support for these platforms. These ops can + also be composite ops. + +* (Performance) User defines a custom kernel for a regular structure +(i.e. LSTM), but it is hard to add the logic to fuse the individual ops to +target this kernel in the inference graph. + * *Solution*: The user should define a new TF op, which corresponds to the + fused kernel, with composition, and use this op to build the model for both + training and inference. For the platforms where a fused kernel is not + available, the execution will use the composition instead. + +## Gradient +(TODO) + +## Authoring Op Composition in Python + +The composable TF provides a single API to define a new op with its composition +at the same time. For example, the following code defines a new +`FusedFullyConnected` op, which have `MatMul`, `Add` and some +`activation function` (specified by an op attribute) fused. + + +```python +import tensorflow as tf + +@Composite( + 'FusedFullyConnected', + inputs=['input_: T', 'filter_: T', 'bias: T'], + attrs=['act: {"", "RELU", "RELU6", "TANH"} = ""'], + derived_attrs=['T: {float, int8}'], + outputs=['o: T']) +def _composite_fully_connected(input_, filter_, bias, act): + res = tf.raw_ops.MatMul( + a=input_, b=filter_, transpose_a=False, transpose_b=True) + res = tf.raw_ops.Add(x=res, y=bias) + if act == 'RELU': + return tf.raw_ops.Relu(features=res) + elif act == 'RELU6': + return tf.raw_ops.Relu6(features=res) + elif act == 'TANH': + return tf.raw_ops.Tanh(x=res) + else: + return res + +``` + +Besides defining new ops, composition can be specified for an existing op +for portability. The following code defines the semantics of `AddNOp`: + +```python +@Composite('AddNOp') +def _my_op_c(ins): + N = len(ins) + if N == 1: + return ins[0] + sum = ins[0] + for i in range(1, N): + sum += ins[i] + return sum +``` + +Utilities have been built to compile the Python composition functions down to +the backend IR. The project also provides a set of graph optimization passes to +expand the composite ops in the graph by using the input backend IR. These +passes have been added to the TF [common runtime] +(https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/common_runtime) +for graph execution and [eager_runtime] +(https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/common_runtime/eager) +for eager execution. + +## Compiling Op Composition + +### Ahead-Of-Time (AOT) mode + +Like the op kernels, the op composition can be pre-compiled to the backend IR +so the decomposition can be invoked at runtime. A Python [define_op_template.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tfr/define_op_template.py) +file is provided as an example to build composite ops in the users project +directory. All the targets required to build the new ops are created by the +following target: + + +```BUILD +load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries") + +gen_op_libraries( + name = "test_ops", + src = "define_op_template.py", + deps = [ + "//third_party/py/tensorflow", + ], +) +``` + +More composite op definitions and usages are here included in the +[examples](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir/tfr/examples) +directory. + +### Just-In-Time (JIT) mode +(TODO) + +## Known Limitations + +* `while` statement +* condition of `if` statement couldn't be a tensor + +## Team + +* Feng Liu +* Dan Moldovan + diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_cuda_runtime_wrappers.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_cuda_runtime_wrappers.cc index 06d613e0599..d2a64a26597 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_cuda_runtime_wrappers.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_cuda_runtime_wrappers.cc @@ -44,6 +44,10 @@ extern "C" CUmodule mgpuModuleLoad(void *data) { return module; } +extern "C" void mgpuModuleUnload(CUmodule module) { + CUDA_REPORT_IF_ERROR(cuModuleUnload(module)); +} + extern "C" CUfunction mgpuModuleGetFunction(CUmodule module, const char *name) { CUfunction function = nullptr; CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name)); @@ -64,16 +68,15 @@ extern "C" void mgpuLaunchKernel(CUfunction function, intptr_t gridX, } extern "C" CUstream mgpuStreamCreate() { - static CUstream stream = []() { - // TODO(b/170649852): This is neither thread-safe nor handles - // creation/descruction of one stream per context. - CUstream stream = nullptr; - CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING)); - return stream; - }(); + CUstream stream = nullptr; + CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING)); return stream; } +extern "C" void mgpuStreamDestroy(CUstream stream) { + CUDA_REPORT_IF_ERROR(cuStreamDestroy(stream)); +} + extern "C" void mgpuStreamSynchronize(CUstream stream) { CUDA_REPORT_IF_ERROR(cuStreamSynchronize(stream)); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc index 33276fd37af..064a4581b58 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc @@ -19,6 +19,7 @@ limitations under the License. #include <memory> #include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/Shape/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project @@ -107,6 +108,8 @@ struct BufferizePass : public BufferizePassBase<BufferizePass> { populateStandardBufferizePattern(&context, &converter, &patterns); populateShapeStructuralTypeConversionsAndLegality(&context, converter, patterns, target); + scf::populateSCFStructuralTypeConversionsAndLegality(&context, converter, + patterns, target); patterns.insert<UnrankedTensorStoreTestOnlyPattern>(&context); auto module = getOperation(); diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index 81facea3857..62e8dc3bd0b 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -218,7 +218,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl( auto attr = CreateDenseElementsAttrFromLiteral(literal, *builder_); if (!attr.ok()) return attr.status(); mlir::Operation* new_operation = - func_builder->create<mlir::ConstantOp>(loc, attr.ValueOrDie()); + func_builder->create<mlir::mhlo::ConstOp>(loc, attr.ValueOrDie()); for (auto attr : attributes) { new_operation->setAttr(attr.first, attr.second); } diff --git a/tensorflow/compiler/mlir/xla/tests/BUILD b/tensorflow/compiler/mlir/xla/tests/BUILD index 047d305a8a0..11507b11ea0 100644 --- a/tensorflow/compiler/mlir/xla/tests/BUILD +++ b/tensorflow/compiler/mlir/xla/tests/BUILD @@ -12,7 +12,11 @@ glob_lit_tests( data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", tags_override = { - "hlo_to_lhlo_with_xla/gpu_ops.mlir": tf_cuda_tests_tags() + ["noasan"], # TODO(b/171751580) + "hlo_to_lhlo_with_xla/gpu_ops.mlir": tf_cuda_tests_tags() + [ + "noasan", + "nomsan", + "noubsan", + ], # b/171751580 }, test_file_exts = [ "mlir", diff --git a/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt index 1fa7367763e..ff205d7d510 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt @@ -26,10 +26,10 @@ ENTRY %indexed_conditional () -> f32[] { } // CHECK-LABEL: func @main() -> tensor<f32> -// CHECK: %[[INDEX:.*]] = constant dense<1> : tensor<i32> -// CHECK: %[[OPERAND_1:.*]] = constant dense<5.600000e+01> : tensor<f32> -// CHECK: %[[OPERAND_2:.*]] = constant dense<1.200000e+01> : tensor<f32> -// CHECK: %[[OPERAND_3:.*]] = constant dense<1.300000e+01> : tensor<f32> +// CHECK: %[[INDEX:.*]] = mhlo.constant dense<1> : tensor<i32> +// CHECK: %[[OPERAND_1:.*]] = mhlo.constant dense<5.600000e+01> : tensor<f32> +// CHECK: %[[OPERAND_2:.*]] = mhlo.constant dense<1.200000e+01> : tensor<f32> +// CHECK: %[[OPERAND_3:.*]] = mhlo.constant dense<1.300000e+01> : tensor<f32> // CHECK: %[[RESULT:.*]] = "mhlo.case"(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]) ( { // CHECK: ^bb0(%[[ARG_1:.*]]: tensor<f32>): // CHECK: %[[RES_1:.*]] = "mhlo.negate"(%[[ARG_1]]) : (tensor<f32>) -> tensor<f32> diff --git a/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt index 4cc70be0965..d5e100f9a4c 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt @@ -4,100 +4,102 @@ HloModule tfcompile.48 -// CHECK-LABEL: func @main(%arg0: tensor<1x300xf32>, %arg1: tensor<1x300x3x1xf32>) -> tuple<tensor<300x1x5xf32>> { +// CHECK-LABEL: func @main( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x300xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x300x3x1xf32>) -> tuple<tensor<300x1x5xf32>> { ENTRY %tfcompile.48 { %arg0.1 = f32[1,300] parameter(0) %arg1.2 = f32[1,300,3,1] parameter(1) - // CHECK-NEXT: %0 = "mhlo.reshape"(%arg0) : (tensor<1x300xf32>) -> tensor<1x300xf32> + // CHECK-NEXT: %[[VAL_2:.*]] = "mhlo.reshape"(%[[VAL_0]]) : (tensor<1x300xf32>) -> tensor<1x300xf32> %reshape.3 = f32[1,300] reshape(%arg0.1) - // CHECK-NEXT: %1 = "mhlo.transpose"(%0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32> + // CHECK-NEXT: %[[VAL_3:.*]] = "mhlo.transpose"(%[[VAL_2]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32> %transpose.27 = f32[300,1] transpose(%reshape.3), dimensions={1,0} - // CHECK-NEXT: %2 = "mhlo.reshape"(%1) : (tensor<300x1xf32>) -> tensor<300x1x1xf32> + // CHECK-NEXT: %[[VAL_4:.*]] = "mhlo.reshape"(%[[VAL_3]]) : (tensor<300x1xf32>) -> tensor<300x1x1xf32> %reshape.28 = f32[300,1,1] reshape(%transpose.27) - // CHECK-NEXT: %3 = "mhlo.reshape"(%2) : (tensor<300x1x1xf32>) -> tensor<300x1xf32> + // CHECK-NEXT: %[[VAL_5:.*]] = "mhlo.reshape"(%[[VAL_4]]) : (tensor<300x1x1xf32>) -> tensor<300x1xf32> %reshape.29 = f32[300,1] reshape(%reshape.28) - // CHECK-NEXT: %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<300x1xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %[[VAL_6:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_5]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<300x1xf32>) -> tensor<300x1x5xf32> %broadcast.30 = f32[300,1,5] broadcast(%reshape.29), dimensions={0,1} - // CHECK-NEXT: %cst = constant dense<1.000000e+00> : tensor<f32> + // CHECK-NEXT: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> %constant.8 = f32[] constant(1) - // CHECK-NEXT: %5 = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %[[VAL_8:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_7]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32> %broadcast.9 = f32[300,1,5] broadcast(%constant.8), dimensions={} - // CHECK-NEXT: %6 = mhlo.multiply %4, %5 : tensor<300x1x5xf32> + // CHECK-NEXT: %[[VAL_9:.*]] = mhlo.multiply %[[VAL_6]], %[[VAL_8]] : tensor<300x1x5xf32> %multiply.31 = f32[300,1,5] multiply(%broadcast.30, %broadcast.9) - // CHECK-NEXT: %cst_0 = constant dense<0.000000e+00> : tensor<f32> + // CHECK-NEXT: %[[VAL_10:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> %constant.32 = f32[] constant(0) - // CHECK-NEXT: %7 = "mhlo.broadcast_in_dim"(%cst_0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %[[VAL_11:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_10]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32> %broadcast.33 = f32[300,1,5] broadcast(%constant.32), dimensions={} - // CHECK-NEXT: %8 = "mhlo.compare"(%6, %7) {comparison_direction = "GT"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1> + // CHECK-NEXT: %[[VAL_12:.*]] = "mhlo.compare"(%[[VAL_9]], %[[VAL_11]]) {comparison_direction = "GT"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1> %compare.34 = pred[300,1,5] compare(%multiply.31, %broadcast.33), direction=GT - // CHECK-NEXT: %cst_1 = constant dense<0.000000e+00> : tensor<f32> + // CHECK-NEXT: %[[VAL_13:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> %constant.10 = f32[] constant(0) - // CHECK-NEXT: %9 = "mhlo.broadcast_in_dim"(%cst_1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %[[VAL_14:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_13]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32> %broadcast.11 = f32[300,1,5] broadcast(%constant.10), dimensions={} - // CHECK-NEXT: %cst_2 = constant dense<0.000000e+00> : tensor<f32> + // CHECK-NEXT: %[[VAL_15:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> %constant.40 = f32[] constant(0) - // CHECK-NEXT: %10 = "mhlo.broadcast_in_dim"(%cst_2) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x5xf32> + // CHECK-NEXT: %[[VAL_16:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_15]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x5xf32> %broadcast.41 = f32[300,5] broadcast(%constant.40), dimensions={} - // CHECK-NEXT: %11 = "mhlo.copy"(%arg1) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> + // CHECK-NEXT: %[[VAL_17:.*]] = "mhlo.copy"(%[[VAL_1]]) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> %copy.1 = f32[1,300,3,1] copy(%arg1.2) - // CHECK-NEXT: %12 = "mhlo.reshape"(%11) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> + // CHECK-NEXT: %[[VAL_18:.*]] = "mhlo.reshape"(%[[VAL_17]]) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> %reshape.4 = f32[1,300,3,1] reshape(%copy.1) - // CHECK-NEXT: %13 = "mhlo.reshape"(%12) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32> + // CHECK-NEXT: %[[VAL_19:.*]] = "mhlo.reshape"(%[[VAL_18]]) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32> %reshape.24 = f32[1,300,3] reshape(%reshape.4) - // CHECK-NEXT: %14 = "mhlo.transpose"(%13) {permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32> + // CHECK-NEXT: %[[VAL_20:.*]] = "mhlo.transpose"(%[[VAL_19]]) {permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32> %transpose.25 = f32[300,1,3] transpose(%reshape.24), dimensions={1,0,2} - // CHECK-NEXT: %15 = "mhlo.reshape"(%14) : (tensor<300x1x3xf32>) -> tensor<300x3xf32> + // CHECK-NEXT: %[[VAL_21:.*]] = "mhlo.reshape"(%[[VAL_20]]) : (tensor<300x1x3xf32>) -> tensor<300x3xf32> %reshape.26 = f32[300,3] reshape(%transpose.25) - // CHECK-NEXT: %cst_3 = constant dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32> + // CHECK-NEXT: %[[VAL_22:.*]] = mhlo.constant dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32> %constant.35 = f32[3,5] constant({ { -0.106023, 0.121505, 0.800239, -0.768885, 0.0966113 }, { 0.689014, -0.407056, -0.797853, 0.00378925, -0.208881 }, { -0.608529, 0.0276617, 0.268557, 0.577401, -0.428437 } }) // TODO(b/129709049) consider making this default precision config implied. - // CHECK-NEXT: %16 = "mhlo.dot"(%15, %cst_3) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32> + // CHECK-NEXT: %[[VAL_23:.*]] = "mhlo.dot"(%[[VAL_21]], %[[VAL_22]]) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32> %dot.36 = f32[300,5] dot(%reshape.26, %constant.35), lhs_contracting_dims={1}, rhs_contracting_dims={0} - // CHECK-NEXT: %cst_4 = constant dense<0.000000e+00> : tensor<5xf32> + // CHECK-NEXT: %[[VAL_24:.*]] = mhlo.constant dense<0.000000e+00> : tensor<5xf32> %constant.37 = f32[5]{0} constant({0, 0, 0, 0, 0}) - // CHECK-NEXT: %17 = "mhlo.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<300x5xf32> + // CHECK-NEXT: %[[VAL_25:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_24]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<300x5xf32> %broadcast.38 = f32[300,5] broadcast(%constant.37), dimensions={1} - // CHECK-NEXT: %18 = mhlo.add %16, %17 : tensor<300x5xf32> + // CHECK-NEXT: %[[VAL_26:.*]] = mhlo.add %[[VAL_23]], %[[VAL_25]] : tensor<300x5xf32> %add.39 = f32[300,5] add(%dot.36, %broadcast.38) - // CHECK-NEXT: %19 = mhlo.maximum %10, %18 : tensor<300x5xf32> + // CHECK-NEXT: %[[VAL_27:.*]] = mhlo.maximum %[[VAL_16]], %[[VAL_26]] : tensor<300x5xf32> %maximum.42 = f32[300,5] maximum(%broadcast.41, %add.39) - // CHECK-NEXT: %20 = "mhlo.reshape"(%19) : (tensor<300x5xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %[[VAL_28:.*]] = "mhlo.reshape"(%[[VAL_27]]) : (tensor<300x5xf32>) -> tensor<300x1x5xf32> %reshape.44 = f32[300,1,5] reshape(%maximum.42) - // CHECK-NEXT: %21 = "mhlo.select"(%8, %9, %20) : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %[[VAL_29:.*]] = "mhlo.select"(%[[VAL_12]], %[[VAL_14]], %[[VAL_28]]) : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32> %select.45 = f32[300,1,5] select(%compare.34, %broadcast.11, %reshape.44) - // CHECK-NEXT: %22 = "mhlo.reshape"(%21) : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %[[VAL_30:.*]] = "mhlo.reshape"(%[[VAL_29]]) : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32> %reshape.46 = f32[300,1,5] reshape(%select.45) - // CHECK-NEXT: %23 = "mhlo.tuple"(%22) : (tensor<300x1x5xf32>) -> tuple<tensor<300x1x5xf32>> - // CHECK-NEXT: return %23 : tuple<tensor<300x1x5xf32>> + // CHECK-NEXT: %[[VAL_31:.*]] = "mhlo.tuple"(%[[VAL_30]]) : (tensor<300x1x5xf32>) -> tuple<tensor<300x1x5xf32>> + // CHECK-NEXT: return %[[VAL_31]] : tuple<tensor<300x1x5xf32>> ROOT %tuple.47 = (f32[300,1,5]) tuple(%reshape.46) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/if_conditional.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/if_conditional.hlotxt index 28e98c1376a..327e5107e4a 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/if_conditional.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/if_conditional.hlotxt @@ -20,7 +20,7 @@ HloModule tfcompile.20 ENTRY %tfcompile.20 { %arg0.1 = f32[] parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"} - // CHECK: [[C0:%.+]] = constant + // CHECK: [[C0:%.+]] = mhlo.constant %constant.3 = f32[] constant(10), metadata={op_type="Less" op_name="Less"} // CHECK: [[R1:%.+]] = "mhlo.compare"([[A0]], [[C0]]) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index b1a54af2c6e..2a3d42f30ff 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -176,48 +176,49 @@ add { %test_constant { // Scalar/0D tensor constant - // CHECK-NEXT: %cst = constant dense<1> : tensor<i64> + // CHECK-NEXT: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<i64> %constant.0 = s64[] constant(1) // Note that double brackets "[[" have to be escaped as they denote variables // in FileCheck. The only way to do so is to drop into regex with "{{" - // CHECK-NEXT: constant dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32> + // CHECK-NEXT: %[[VAL_1:.*]] = mhlo.constant dense<{{\[\[}}{{\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[}}[3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32> %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({{{{1.0}},{{2.0}}},{{{3.0}},{{4.0}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} - // CHECK: dense<[1, 2, 4, 8]> : tensor<4xui64> + // CHECK: %[[VAL_2:.*]] = mhlo.constant dense<[1, 2, 4, 8]> : tensor<4xui64> %constant.2 = u64[4] constant({ 1, 2, 4, 8 }) - // CHECK: dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16> + // CHECK: %[[VAL_3:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16> %constant.3 = bf16[4] constant({1, 2, 3, 4}) - // CHECK: dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>> + // CHECK: %[[VAL_4:.*]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>> %constant.4 = c64[] constant((1, 0)) - // CHECK: dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f64>> + // CHECK: %[[VAL_5:.*]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f64>> %constant.5 = c128[] constant((1, 0)) - // CHECK: dense<[1.000000e+00, -4.000000e+00, -6.550400e+04, 1.562500e-02]> : tensor<4xf16> + // CHECK: %[[VAL_6:.*]] = mhlo.constant dense<[1.000000e+00, -4.000000e+00, -6.550400e+04, 1.562500e-02]> : tensor<4xf16> ROOT %constant.6 = f16[4] constant({1, -4, -65504, 0.015625}) } // TODO(b/129422361) Potentially update when copy, reshape, and conv have actual // implementations with attributes, etc. -// CHECK-LABEL: func @test_conv(%arg0: tensor<256x32x32x6xf32>) -> tuple<tensor<256x30x30x16xf32>> +// CHECK-LABEL: func @test_conv( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<256x32x32x6xf32>) -> tuple<tensor<256x30x30x16xf32>> attributes {sym_visibility = "private"} { %test_conv { %arg0.1 = f32[256,32,32,6]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: %0 = "mhlo.copy"(%arg0) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> + // CHECK-NEXT: %[[VAL_1:.*]] = "mhlo.copy"(%[[VAL_0]]) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> %copy.1 = f32[256,32,32,6]{2,1,3,0} copy(%arg0.1), metadata={op_name="HLO_Args"} - // CHECK-NEXT: %1 = "mhlo.reshape"(%0) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> + // CHECK-NEXT: %[[VAL_2:.*]] = "mhlo.reshape"(%[[VAL_1]]) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> %reshape.2 = f32[256,32,32,6]{2,1,3,0} reshape(%copy.1) // Note that double brackets "[[" have to be escaped as they denote variables // in FileCheck. The only way to do so is to drop into regex with "{{" - // CHECK-NEXT: %cst = constant dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32> + // CHECK-NEXT: %[[VAL_3:.*]] = mhlo.constant dense<{{\[\[}}{{\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[}}[3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32> %constant.3 = f32[2,2,1,1]{3,2,1,0} constant({{{{0.5}}, {{-0.6}}}, {{{0.3}}, {{-0.1}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} - // CHECK-NEXT: %2 = "mhlo.convolution"(%1, %cst) { + // CHECK-NEXT: %[[VAL_4:.*]] = "mhlo.convolution"(%[[VAL_2]], %[[VAL_3]]) { // CHECK-SAME: batch_group_count = 1 : i64 // CHECK-SAME: dimension_numbers = { // CHECK-SAME: input_batch_dimension = 0 : i64 @@ -241,10 +242,10 @@ add { %convolution.4 = f32[16,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} - // CHECK-NEXT: %3 = "mhlo.reshape"(%2) : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32> + // CHECK-NEXT: %[[VAL_5:.*]] = "mhlo.reshape"(%[[VAL_4]]) : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32> %reshape.5 = f32[256,30,30,16]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"} - // CHECK-NEXT: "mhlo.tuple"(%3) : (tensor<256x30x30x16xf32>) -> tuple<tensor<256x30x30x16xf32>> + // CHECK-NEXT: %[[VAL_6:.*]] = "mhlo.tuple"(%[[VAL_5]]) : (tensor<256x30x30x16xf32>) -> tuple<tensor<256x30x30x16xf32>> ROOT %tuple.6 = (f32[256,30,30,16]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="HLO_Retvals"} } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/types.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/types.hlotxt index f7e1ba9ff15..892fca73b6d 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/types.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/types.hlotxt @@ -4,25 +4,25 @@ HloModule tfcompile.1 // CHECK-LABEL: func @main() -> tensor<i1> { ENTRY %tfcompile.1 { - // CHECK-NEXT: %cst = constant dense<1.000000e+00> : tensor<f32> + // CHECK-NEXT: %[[VAL_0:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> %constant.0 = f32[] constant(1) - // CHECK-NEXT: %cst_0 = constant dense<1.000000e+00> : tensor<f64> + // CHECK-NEXT: %[[VAL_1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f64> %constant.1 = f64[] constant(1) - // CHECK-NEXT: %cst_1 = constant dense<1> : tensor<i8> + // CHECK-NEXT: %[[VAL_2:.*]] = mhlo.constant dense<1> : tensor<i8> %constant.2 = s8[] constant(1) - // CHECK-NEXT: %cst_2 = constant dense<1> : tensor<i16> + // CHECK-NEXT: %[[VAL_3:.*]] = mhlo.constant dense<1> : tensor<i16> %constant.3 = s16[] constant(1) - // CHECK-NEXT: %cst_3 = constant dense<1> : tensor<i32> + // CHECK-NEXT: %[[VAL_4:.*]] = mhlo.constant dense<1> : tensor<i32> %constant.4 = s32[] constant(1) - // CHECK-NEXT: %cst_4 = constant dense<1> : tensor<i64> + // CHECK-NEXT: %[[VAL_5:.*]] = mhlo.constant dense<1> : tensor<i64> %constant.5 = s64[] constant(1) - // CHECK-NEXT: %cst_5 = constant dense<true> : tensor<i1> - // CHECK-NEXT: return %cst_5 : tensor<i1> + // CHECK-NEXT: %[[VAL_6:.*]] = mhlo.constant dense<true> : tensor<i1> + // CHECK-NEXT: return %[[VAL_6]] : tensor<i1> ROOT %constant.6 = pred[] constant(1) } diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 3adb7dacd93..df1b7f86956 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -351,6 +351,7 @@ cc_library( ":xla_op_registry", ":xla_resource", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", diff --git a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc index a62d15f7904..3c9c0997fcd 100644 --- a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc @@ -91,15 +91,15 @@ class DataFormatVecPermuteOp : public XlaOpKernel { : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("src_format", &src_format_)); OP_REQUIRES( - ctx, src_format_.size() == 4, - errors::InvalidArgument("Data format should have 4 characters")); + ctx, src_format_.size() == 4 || src_format_.size() == 5, + errors::InvalidArgument("Data format should have 4 or 5 characters")); TensorFormat data_format; OP_REQUIRES(ctx, FormatFromString(src_format_, &data_format), errors::InvalidArgument("Invalid data format")); OP_REQUIRES_OK(ctx, ctx->GetAttr("dst_format", &dst_format_)); OP_REQUIRES( - ctx, dst_format_.size() == 4, - errors::InvalidArgument("Data format should have 4 characters")); + ctx, dst_format_.size() == 4 || dst_format_.size() == 5, + errors::InvalidArgument("Data format should have 4 or 5 characters")); OP_REQUIRES(ctx, FormatFromString(dst_format_, &data_format), errors::InvalidArgument("Invalid data format")); } @@ -113,9 +113,10 @@ class DataFormatVecPermuteOp : public XlaOpKernel { input_tensor_shape.DebugString())); const int dim0 = input_tensor_shape.dim_size(0); OP_REQUIRES( - ctx, dim0 == 2 || dim0 == 4, + ctx, dim0 == 2 || dim0 == 4 || dim0 == 5, errors::InvalidArgument( - "First dimension of input must be of size 4, but got shape ", + "First dimension of input must be of size 2, 4 or 5, but got " + "shape ", input_tensor_shape.DebugString())); if (input_rank == 2) { OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 5c8cfdde9e4..3d6a66c6ebc 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -18,6 +18,7 @@ limitations under the License. #include <numeric> #include <vector> +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/types/variant.h" #include "tensorflow/compiler/jit/defs.h" @@ -675,6 +676,38 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) { return graph; } +// Collects all control rets from `orig_control_ret_nodes` that are still valid, +// keeping the same order. +std::vector<std::string> GetValidControlRets( + absl::Span<Node* const> orig_control_ret_nodes, const Graph& graph) { + // Build map from control ret node to index. + absl::flat_hash_map<const Node*, int> control_ret_nodes_map; + for (int i = 0; i < orig_control_ret_nodes.size(); ++i) { + const Node* n = orig_control_ret_nodes[i]; + control_ret_nodes_map[n] = i; + } + // Check which control rets are still valid. + std::vector<bool> is_valid_control_ret(orig_control_ret_nodes.size(), false); + int num_valid_control_rets = 0; + for (const Node* n : graph.nodes()) { + auto iter = control_ret_nodes_map.find(n); + if (iter != control_ret_nodes_map.end()) { + ++num_valid_control_rets; + is_valid_control_ret[iter->second] = true; + } + } + // Return valid control rets in same order as they appear in + // `orig_control_ret_nodes`. + std::vector<std::string> valid_control_rets; + valid_control_rets.reserve(num_valid_control_rets); + for (int i = 0; i < orig_control_ret_nodes.size(); ++i) { + if (is_valid_control_ret[i]) { + valid_control_rets.push_back(orig_control_ret_nodes[i]->name()); + } + } + return valid_control_rets; +} + Status XlaCompiler::CompileFunction( const XlaCompiler::CompileOptions& options, const NameAttrList& fn_name_attrs, @@ -765,15 +798,15 @@ Status XlaCompiler::CompileFunction( ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) { VLOG(1) << "Using MLIR bridge"; GraphDebugInfo debug_info; - std::vector<std::string> control_rets; - for (const auto* control_ret_node : fbody->control_ret_nodes) { - control_rets.push_back(control_ret_node->name()); - } + + std::vector<std::string> valid_control_rets = + GetValidControlRets(fbody->control_ret_nodes, *graph); + TF_RETURN_IF_ERROR(CompileGraphToXlaHlo( std::move(*graph), mlir::SpanToArrayRef<XlaCompiler::Argument>(args), - control_rets, options_.device_type.type_string(), options.use_tuple_arg, - *options_.flib_def, debug_info, options_.shape_representation_fn, - result)); + valid_control_rets, options_.device_type.type_string(), + options.use_tuple_arg, *options_.flib_def, debug_info, + options_.shape_representation_fn, result)); } else { TF_RETURN_IF_ERROR( CompileGraph(options, function_id, std::move(graph), args, result)); diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc index b2e02a0450f..16de0718d64 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -1994,7 +1994,7 @@ PjRtExecutable::ExecuteOnLocalDevices( if (!statusor.ok()) { return AppendStatus( statusor.status(), - absl::StrFormat("while running replica %d and partition %d of a" + absl::StrFormat("while running replica %d and partition %d of a " "replicated computation (other " "replicas may have failed as well).", replica, partition)); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 27c308319f7..046701c564f 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -5220,13 +5220,31 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands( for (int64 spatial_dim = 0; spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) { const int64 kernel_size = window_dims[spatial_dim].size(); - const int64 dilated_kernel_size = - 1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation(); - + const bool can_be_group_or_contraction = + !window_dims[spatial_dim].window_reversal() && + window_dims[spatial_dim].padding_low() == 0 && + window_dims[spatial_dim].padding_high() == 0 && + window_dims[spatial_dim].window_dilation() == 1; + const bool is_group_dim = + can_be_group_or_contraction && + window_dims[spatial_dim].base_dilation() == kernel_size && + window_dims[spatial_dim].stride() == kernel_size - 1; const int64 input_size = input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim)); + const bool is_pure_contraction_dim = + kernel_size == input_size && can_be_group_or_contraction && + window_dims[spatial_dim].base_dilation() == 1 && + window_dims[spatial_dim].stride() == 1; + if (is_group_dim || is_pure_contraction_dim) { + *(swapped_window.add_dimensions()) = window_dims[spatial_dim]; + continue; + } + + const int64 dilated_kernel_size = + 1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation(); const int64 dilated_input_size = 1 + (input_size - 1) * window_dims[spatial_dim].base_dilation(); + // Don't decide to swap if the input size is one, since many convolution // implementations can easily hand that special case efficiently. kernel_product *= kernel_size; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index c4f3ea4087b..91c6c29ee80 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -6654,6 +6654,32 @@ TEST_F(AlgebraicSimplifierTest, AbsEliminationMultiply) { GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0)))); } +TEST_F(AlgebraicSimplifierTest, BroadcastCompareSimplification) { + std::string module_string = R"( + HloModule m + test { + a = s32[] parameter(0) + b = s32[] parameter(1) + x = s32[10]{0} parameter(2) + broadcast_a = s32[10]{0} broadcast(a), dimensions={} + broadcast_b = s32[10]{0} broadcast(b), dimensions={} + add = s32[10]{0} add(broadcast_a, x) + ROOT cmp = pred[10]{0} compare(add, broadcast_b), direction=EQ + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_string)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Compare(m::Parameter(2), + m::Broadcast(m::Subtract( + m::Parameter(1), m::Parameter(0)))))); + + // Numerically unstable transformation shouldn't be applied to floating types. + std::string module_string_f32 = + absl::StrReplaceAll(module_string, {{"s32", "f32"}}); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); +} + TEST_F(AlgebraicSimplifierTest, AbsEliminationPower2) { const char* kModuleStr = R"( HloModule m diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 54822323137..4425a3681c1 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -63,7 +63,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" -#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -2200,20 +2199,21 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) { VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace"; CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); + FusedIrEmitter fused_emitter(&elemental_emitter); + BindFusionArguments(fusion, &fused_emitter); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); // Delegate to common implementation of fused in-place dynamic-update-slice. return llvm_ir::EmitFusedDynamicUpdateSliceInPlace( - fusion, GetGeneratorForOperandIrArrays(fusion), GetIrArrayFor(fusion), - &elemental_emitter, &b_); + fusion, GetIrArrayFor(fusion), &fused_emitter, &b_); } else if (fusion->IsLoopFusion()) { VLOG(3) << "HandleFusion kLoop"; CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); - auto operands = GetIrArraysForOperandsOf(fusion); - FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion), - &elemental_emitter); - TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); - - return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator()); + FusedIrEmitter fused_emitter(&elemental_emitter); + BindFusionArguments(fusion, &fused_emitter); + TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator( + fusion->fused_expression_root())); + return EmitTargetElementLoop(fusion, generator); } else if (fusion->IsOutputFusion()) { VLOG(3) << "HandleFusion kOutput"; int64 dot_op_index = root->operand(0)->opcode() == HloOpcode::kDot ? 0 : 1; @@ -3451,5 +3451,17 @@ llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue( return EmitBufferPointer(root_buffer, root_inst->shape()); } +void IrEmitter::BindFusionArguments(const HloInstruction* fusion, + FusedIrEmitter* fused_emitter) { + for (int i = 0; i < fusion->operand_count(); i++) { + const HloInstruction* operand = fusion->operand(i); + fused_emitter->BindGenerator( + fusion->fused_parameter(i), + [this, operand](llvm_ir::IrArray::Index index) { + return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_); + }); + } +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index f136e3470e5..891d53c889d 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -234,10 +235,9 @@ class IrEmitter : public DfsHloVisitorWithDefault, std::vector<llvm_ir::IrArray> GetIrArraysForOperandsOf( const HloInstruction* hlo); - GeneratorForOperandIrArrays GetGeneratorForOperandIrArrays( - HloInstruction* unnested_hlo) { - return [=]() { return GetIrArraysForOperandsOf(unnested_hlo); }; - } + // Bind all argument IrArrays of `fusion` to `fused_emitter`. + void BindFusionArguments(const HloInstruction* fusion, + FusedIrEmitter* fused_emitter); // Augments IrArray with aliasing information. void AddAliasingInformationToIrArray(const HloInstruction& hlo, diff --git a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc index 17d3fb2b3d6..25b9658ba98 100644 --- a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc +++ b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc @@ -38,14 +38,60 @@ FusionNodeIndexingEvaluation::FusionNodeIndexingEvaluation( // a tradeoff between compilation time and runtime here. const int64 FusionNodeIndexingEvaluation::kAllowedCodeDuplication = 15; +namespace { + +// Returns which ops invalidate the cache of emitted instructions by creating a +// new BasicBlock and setting the insertion point to the newly created +// BasicBlock. We can only reuse cached values if they were emitted in the same +// BasicBlock as the current BasicBlock. +bool OpInvalidatesCache(const HloInstruction* hlo) { + switch (hlo->opcode()) { + // This list of ops was created by inspecting the code. There is no + // guarantee that it is complete. + case HloOpcode::kConcatenate: + case HloOpcode::kDot: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kPad: + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + return true; + default: + return false; + } +} + +// Counts the number of "real" users of 'hlo'. When 'hlo' has a fusion node as +// user, we consider the users of the fusion parameter corresponding to 'hlo' as +// the real users. +int64 UserCount(const HloInstruction* hlo) { + int64 cnt = 0; + for (HloInstruction* user : hlo->users()) { + if (user->opcode() == HloOpcode::kFusion) { + // Count the number of users of the parameter corresponding to the fusion + // operand. + int64 operand_index = user->operand_index(hlo); + cnt += user->fused_parameter(operand_index)->user_count(); + } else { + ++cnt; + } + } + return cnt; +} +} // namespace + bool FusionNodeIndexingEvaluation::CodeDuplicationTooHigh( const HloInstruction* producer) const { - return EvaluateEmittedInstructions(producer) > kAllowedCodeDuplication; + int64 emitted_instructions = EvaluateEmittedInstructions(producer); + return emitted_instructions > kAllowedCodeDuplication || + (OpInvalidatesCache(producer) && + (emitted_instructions > 1 || UserCount(producer) > 1)); } bool FusionNodeIndexingEvaluation::MaxCodeDuplicationTooHigh() const { for (const auto& entry : index_usage_count_) { - if (entry.second > kAllowedCodeDuplication) { + if (entry.second > kAllowedCodeDuplication || + (OpInvalidatesCache(entry.first) && + (entry.second > 1 || UserCount(entry.first) > 1))) { return true; } } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 2215881271c..5682fcedf1d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -773,11 +773,11 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { CHECK_EQ(HloInstruction::FusionKind::kLoop, fusion->fusion_kind()); GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_, GetNestedComputer()); - FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion), - &elemental_emitter); - TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); - - return EmitTargetElementLoop(*fusion, fused_emitter.GetRootGenerator()); + FusedIrEmitter fused_emitter(&elemental_emitter); + BindFusionArguments(fusion, &fused_emitter); + TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator( + fusion->fused_expression_root())); + return EmitTargetElementLoop(*fusion, generator); } Status IrEmitter::HandleCall(HloInstruction* call) { @@ -876,5 +876,17 @@ std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs( return output_arrays; } +void IrEmitter::BindFusionArguments(const HloInstruction* fusion, + FusedIrEmitter* fused_emitter) { + for (int i = 0; i < fusion->operand_count(); i++) { + const HloInstruction* operand = fusion->operand(i); + fused_emitter->BindGenerator( + fusion->fused_parameter(i), + [this, operand, fusion](llvm_ir::IrArray::Index index) { + return GetIrArray(*operand, *fusion).EmitReadArrayElement(index, &b_); + }); + } +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 1a387528220..894f1401e0d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" @@ -182,18 +183,9 @@ class IrEmitter : public DfsHloVisitorWithDefault, const HloModuleConfig& hlo_module_config_; protected: - GeneratorForOperandIrArrays GetGeneratorForOperandIrArrays( - const HloInstruction* fusion) { - return [=]() { - std::vector<llvm_ir::IrArray> ir_arrays; - ir_arrays.reserve(fusion->operand_count()); - absl::c_transform(fusion->operands(), std::back_inserter(ir_arrays), - [&](const HloInstruction* operand) { - return GetIrArray(*operand, *fusion); - }); - return ir_arrays; - }; - } + // Bind all argument IrArrays of `fusion` to `fused_emitter`. + void BindFusionArguments(const HloInstruction* fusion, + FusedIrEmitter* fused_emitter); private: // A helper method for EmitAtomicOperationForNestedComputation. Certain diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index ed4d75effba..3896d8e870c 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -960,19 +960,24 @@ Status IrEmitterUnnested::EmitLoopFusionFromMlir(MlirEmitterInput input, GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_, GetNestedComputer()); + FusedIrEmitter fused_emitter(&elemental_emitter); - FusedIrEmitter fused_emitter( - [&] { - auto operand_ir_arrays = - absl::MakeSpan(ir_arrays).subspan(0, fusion_operands.size()); - return std::vector<llvm_ir::IrArray>(operand_ir_arrays.begin(), - operand_ir_arrays.end()); - }, - &elemental_emitter); - TF_RETURN_IF_ERROR( - fused_computation->root_instruction()->Accept(&fused_emitter)); + for (int i = 0; i < fusion_operands.size(); i++) { + auto operand_ir_arrays = + absl::MakeSpan(ir_arrays).subspan(0, fusion_operands.size()); + + auto* builder = &b_; + auto ir_array = operand_ir_arrays[i]; + fused_emitter.BindGenerator( + fused_computation->parameter_instruction(i), + [builder, ir_array](llvm_ir::IrArray::Index index) { + return ir_array.EmitReadArrayElement(index, builder); + }); + } + TF_ASSIGN_OR_RETURN( + auto element_generator, + fused_emitter.GetGenerator(fused_computation->root_instruction())); - auto element_generator = fused_emitter.GetRootGenerator(); Shape element_shape = TypeToShape(fusion_outputs[0].getType()); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor); @@ -1022,14 +1027,14 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { GpuElementalIrEmitter operand_elemental_emitter( hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, GetNestedComputer()); - FusedIrEmitter operand_fused_emitter( - GetGeneratorForOperandIrArrays(fusion), - &operand_elemental_emitter); - TF_RETURN_IF_ERROR( - root->mutable_operand(0)->Accept(&operand_fused_emitter)); + FusedIrEmitter operand_fused_emitter(&operand_elemental_emitter); + BindFusionArguments(fusion, &operand_fused_emitter); + TF_ASSIGN_OR_RETURN( + auto generator, + operand_fused_emitter.GetGenerator(root->operand(0))); TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk( - *fusion, operand_fused_emitter.GetGenerator(root->operand(0)), + *fusion, generator, static_cast<KernelThunk*>(thunks.back().get()), ComputeMaxUnrollFactor(fusion))); } @@ -1044,10 +1049,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { GpuElementalIrEmitter scatter_elemental_emitter( hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, GetNestedComputer()); - FusedIrEmitter scatter_fused_emitter( - GetGeneratorForOperandIrArrays(fusion), - &scatter_elemental_emitter); - TF_RETURN_IF_ERROR(root->Accept(&scatter_fused_emitter)); + FusedIrEmitter scatter_fused_emitter(&scatter_elemental_emitter); + BindFusionArguments(fusion, &scatter_fused_emitter); CHECK_EQ(root->parent()->FusionInstruction(), fusion); TF_ASSIGN_OR_RETURN( @@ -1063,10 +1066,12 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { desc.unique_indices = root->unique_indices(); desc.update_computation = root->called_computations()[0]; desc.output = GetIrArray(*fusion, *fusion); - desc.scatter_indices_gen = - scatter_fused_emitter.GetGenerator(root->operand(1)); - desc.updates_gen = - scatter_fused_emitter.GetGenerator(root->operand(2)); + TF_ASSIGN_OR_RETURN( + desc.scatter_indices_gen, + scatter_fused_emitter.GetGenerator(root->operand(1))); + TF_ASSIGN_OR_RETURN( + desc.updates_gen, + scatter_fused_emitter.GetGenerator(root->operand(2))); desc.get_index_type = [&](int64 launch_size) { return GetIndexTypeForKernel(root, launch_size, &b_); }; @@ -1133,9 +1138,11 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { ir_emitter_context_->llvm_module()); AddThunkToThunkSequence(std::move(fusion_thunk)); + FusedIrEmitter fused_emitter(&elemental_emitter); + BindFusionArguments(fusion, &fused_emitter); + return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace( - fusion, GetGeneratorForOperandIrArrays(fusion), output_array, - &elemental_emitter, launch_dimensions, &b_); + fusion, output_array, &fused_emitter, launch_dimensions, &b_); } CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop) @@ -2596,14 +2603,14 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( ir_emitter_context_->llvm_module(), &b_, GetNestedComputer()); - FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo), - &elemental_emitter); - TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter)); - TF_RETURN_IF_ERROR( - ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand), - GetIrArray(*hlo, *hlo, index), launch_dimensions, - &b_) - .EmitLoop(IrName(hlo))); + FusedIrEmitter fused_emitter(&elemental_emitter); + BindFusionArguments(hlo, &fused_emitter); + TF_ASSIGN_OR_RETURN(auto generator, + fused_emitter.GetGenerator(init_value_operand)); + TF_RETURN_IF_ERROR(ParallelLoopEmitter(generator, + GetIrArray(*hlo, *hlo, index), + launch_dimensions, &b_) + .EmitLoop(IrName(hlo))); } else { // In the unfused case the element is already there, just read from it. TF_RETURN_IF_ERROR(ParallelLoopEmitter( @@ -3097,15 +3104,35 @@ void IrEmitterUnnested::EmitTileElementForFusion( std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo); GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, GetNestedComputer()); - FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo), - &elem_emitter, x_loc, y_loc, - param_shmem_buffers); - - TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter)); + FusedIrEmitter fused_emitter(&elem_emitter); + for (int i = 0; i < hlo->operand_count(); i++) { + llvm_ir::ElementGenerator gen; + if (llvm::Value* param_tile_buffer = param_shmem_buffers[i]) { + gen = [this, param_tile_buffer, x_loc, + y_loc](llvm_ir::IrArray::Index index) { + // TODO(jlebar): Add AA metadata to this load. Tile buffers are + // global variables, so LLVM's points-to analysis doesn't help us + // much. And we want the AA info to be present before address + // spaces are inferred (which is pretty late in the pipeline), so + // even if we had address-space-based AA in LLVM, it wouldn't help + // us much here. + return b_.CreateLoad( + b_.CreateGEP(param_tile_buffer, + {index.GetConstantWithIndexType(0), x_loc, y_loc}), + "tiled_buffer"); + }; + } else { + const HloInstruction* operand = hlo->operand(i); + gen = [this, operand, hlo](llvm_ir::IrArray::Index index) { + return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_); + }; + } + fused_emitter.BindGenerator(hlo->fused_parameter(i), std::move(gen)); + } IrArray::Index untiled_index = GetUnnormalizedIndex( index, output_arrays[0].GetShape(), &b_, mapping_scheme); - const llvm_ir::ElementGenerator& output_generator = - fused_emitter.GetRootGenerator(); + llvm_ir::ElementGenerator output_generator = + *fused_emitter.GetGenerator(hlo->fused_expression_root()); llvm::Value* output_value = output_generator(untiled_index).ValueOrDie(); if (hlo->IsMultiOutputFusion()) { DCHECK(output_value->getType()->isStructTy()); @@ -3161,14 +3188,12 @@ void IrEmitterUnnested::EmitPrologueForReduction( llvm::Value* init_ir_value; const HloInstruction* init_value = reduce_inst->operand(1); if (unnested_hlo->opcode() == HloOpcode::kFusion) { - FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo), - &elemental_emitter); + FusedIrEmitter fused_emitter(&elemental_emitter); + BindFusionArguments(unnested_hlo, &fused_emitter); - TF_CHECK_OK(init_value->Accept(&fused_emitter)); - init_ir_value = - fused_emitter - .GetGenerator(init_value)(IrArray::Index(b_.getInt32Ty())) - .ValueOrDie(); + init_ir_value = (*fused_emitter.GetGenerator(init_value))( + IrArray::Index(b_.getInt32Ty())) + .ValueOrDie(); } else { init_ir_value = GetIrArray(*init_value, *unnested_hlo) @@ -3507,21 +3532,21 @@ void IrEmitterUnnested::EmitTileElementForReduction( extra_output_gens; GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, GetNestedComputer()); - FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo), - &elem_emitter); + FusedIrEmitter fused_emitter(&elem_emitter); + // Construct the ElementGenerator for each reduction and extra output in the // the group of output instructions. if (unnested_hlo->opcode() == HloOpcode::kFusion) { - TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter)); + BindFusionArguments(unnested_hlo, &fused_emitter); for (int i = 0, e = output_instructions.size(); i != e; ++i) { const HloInstruction* inst = output_instructions[i]; ShapeIndex idx = CreateShapeIndexForOutputInstruction(*unnested_hlo, *inst); if (IsReductionFromOrToContiguousDimensions(*inst)) { - input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); + input_gens.push_back(*fused_emitter.GetGenerator(inst->operand(0))); } else { - extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst), + extra_output_gens.emplace_back(*fused_emitter.GetGenerator(inst), std::move(idx)); } } @@ -4506,11 +4531,10 @@ void IrEmitterUnnested::EmitElementForInputFusibleSlices( std::vector<llvm::Value*> input_ir_values; GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, GetNestedComputer()); - FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo), - &elem_emitter); - TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter)); + FusedIrEmitter fused_emitter(&elem_emitter); + BindFusionArguments(unnested_hlo, &fused_emitter); for (const HloInstruction* slice : slice_instructions) { - auto input_generator = fused_emitter.GetGenerator(slice->operand(0)); + auto input_generator = *fused_emitter.GetGenerator(slice->operand(0)); input_ir_values.push_back(input_generator(index).ValueOrDie()); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index 406fe84019e..ff132b40605 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -17,7 +17,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" -#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" @@ -190,9 +189,8 @@ Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays, // // Emits a sequential loop if launch_dimensions is null. static Status EmitFusedDynamicUpdateSliceInPlaceImpl( - HloInstruction* fusion, - GeneratorForOperandIrArrays operand_arrays_generator, - const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, + HloInstruction* fusion, const IrArray& fusion_output_array, + FusedIrEmitter* fused_emitter, const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) { CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); VLOG(2) << "EmitFusedDynamicUpdateSliceInPlace for " @@ -221,14 +219,14 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( LayoutUtil::CopyLayoutBetweenShapes(fusion->shape(), &update_shape)); // Create element generators for update and start_indices. - FusedIrEmitter fused_emitter(std::move(operand_arrays_generator), - elemental_emitter); - TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter)); - ElementGenerator update_array_generator = fused_emitter.GetGenerator(update); + TF_ASSIGN_OR_RETURN(ElementGenerator update_array_generator, + fused_emitter->GetGenerator(update)); - IndexGenerator start_indices_generator = [&](int64 index) { - ElementGenerator element_generator = - fused_emitter.GetGenerator(dynamic_update_slice->operand(2 + index)); + IndexGenerator start_indices_generator = + [&](int64 index) -> StatusOr<llvm::Value*> { + TF_ASSIGN_OR_RETURN( + ElementGenerator element_generator, + fused_emitter->GetGenerator(dynamic_update_slice->operand(2 + index))); return element_generator(IrArray::Index(b->getInt64Ty())); }; bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape()); @@ -237,25 +235,21 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( fusion_output_array, launch_dimensions, IrName(fusion), b); } -Status EmitFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, - GeneratorForOperandIrArrays operand_arrays_generator, - const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, - llvm::IRBuilder<>* b) { +Status EmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion, + const IrArray& fusion_output_array, + FusedIrEmitter* fused_emitter, + llvm::IRBuilder<>* b) { return EmitFusedDynamicUpdateSliceInPlaceImpl( - fusion, std::move(operand_arrays_generator), fusion_output_array, - elemental_emitter, + fusion, fusion_output_array, fused_emitter, /*launch_dimensions=*/nullptr, b); } Status EmitParallelFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, - GeneratorForOperandIrArrays operand_arrays_generator, - const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, + HloInstruction* fusion, const IrArray& fusion_output_array, + FusedIrEmitter* fused_emitter, const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) { return EmitFusedDynamicUpdateSliceInPlaceImpl( - fusion, std::move(operand_arrays_generator), fusion_output_array, - elemental_emitter, &launch_dimensions, b); + fusion, fusion_output_array, fused_emitter, &launch_dimensions, b); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h index b40501b738c..b5e343b367f 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" // Utilities related to emitting LLVM IR for various HLO ops. @@ -71,18 +72,16 @@ Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays, // array-to-be-updated and output share the same buffer slice, emits // (sequential) code for a fusion node that does the dynamic-update-slice in // place. -Status EmitFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, - GeneratorForOperandIrArrays operand_arrays_generator, - const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, - llvm::IRBuilder<>* b); +Status EmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion, + const IrArray& fusion_output_array, + FusedIrEmitter* fused_emitter, + llvm::IRBuilder<>* b); // Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with // the given launch dimensions. Status EmitParallelFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, - GeneratorForOperandIrArrays operand_arrays_generator, - const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, + HloInstruction* fusion, const IrArray& fusion_output_array, + FusedIrEmitter* fused_emitter, const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b); } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 245ab8182af..0a26a2bb7ce 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -114,25 +114,9 @@ Status FusedIrEmitter::HandleGetTupleElement( } Status FusedIrEmitter::HandleParameter(const HloInstruction* parameter) { - indexed_generators_[parameter] = - [=](const IrArray::Index& index) -> llvm::Value* { - int64 param_num = parameter->parameter_number(); - if (param_shmem_buffers_.size() > param_num) { - if (llvm::Value* param_tile_buffer = param_shmem_buffers_[param_num]) { - // TODO(jlebar): Add AA metadata to this load. Tile buffers are global - // variables, so LLVM's points-to analysis doesn't help us much. And we - // want the AA info to be present before address spaces are inferred - // (which is pretty late in the pipeline), so even if we had - // address-space-based AA in LLVM, it wouldn't help us much here. - return b_->CreateLoad( - b_->CreateGEP(param_tile_buffer, {index.GetConstantWithIndexType(0), - thread_id_x_, thread_id_y_}), - "tiled_buffer"); - } - } - return GetIrArrayForFusedParameter(param_num).EmitReadArrayElement(index, - b_); - }; + if (indexed_generators_.find(parameter) == indexed_generators_.end()) { + return InvalidArgument("Unbound parameter: %s", parameter->ToString()); + } return Status::OK(); } @@ -157,22 +141,6 @@ Status FusedIrEmitter::HandleTuple(const HloInstruction* tuple) { return Status::OK(); } -Status FusedIrEmitter::FinishVisit(const HloInstruction* root) { - fused_root_ = root; - return Status::OK(); -} - -FusedIrEmitter::IndexedGenerator FusedIrEmitter::GetRootGenerator() const { - CHECK_NE(nullptr, fused_root_) - << "GetRootGenerator should be called after Accept."; - return indexed_generators_.at(fused_root_); -} - -FusedIrEmitter::IndexedGenerator FusedIrEmitter::GetGenerator( - const HloInstruction* instruction) const { - return indexed_generators_.at(instruction); -} - bool FusedIrEmitter::IsFusedIrEmitterInefficient( const HloInstruction* consumer, const HloInstruction* producer) { if (consumer->opcode() != HloOpcode::kFusion) { @@ -189,4 +157,39 @@ bool FusedIrEmitter::IsFusedIrEmitterInefficient( return eval_producer.MaxCodeDuplicationTooHigh(); } +StatusOr<FusedIrEmitter::IndexedGenerator> FusedIrEmitter::GetGenerator( + const HloInstruction* instruction) { + std::vector<const HloInstruction*> stack; + stack.push_back(instruction); + while (!stack.empty()) { + const HloInstruction* instr = stack.back(); + stack.pop_back(); + if (indexed_generators_.count(instr)) { + continue; + } + for (const HloInstruction* operand : instr->operands()) { + stack.push_back(operand); + } + switch (instr->opcode()) { + case HloOpcode::kConstant: + TF_RETURN_IF_ERROR(HandleConstant(instr)); + break; + case HloOpcode::kGetTupleElement: + TF_RETURN_IF_ERROR(HandleGetTupleElement(instr)); + break; + case HloOpcode::kParameter: + TF_RETURN_IF_ERROR(HandleParameter(instr)); + break; + case HloOpcode::kTuple: + TF_RETURN_IF_ERROR(HandleTuple(instr)); + break; + default: + TF_RETURN_IF_ERROR(DefaultAction(instr)); + break; + } + CHECK(indexed_generators_.count(instr)); + } + return indexed_generators_.at(instruction); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index d13b0262180..9059d150065 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -24,7 +24,6 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" -#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -51,47 +50,22 @@ namespace xla { // created produces an LLVM struct with N elements, one for each element of the // arrays in the tuple. It follows that the arrays in the tuple must have the // same length. -class FusedIrEmitter : public ConstDfsHloVisitorWithDefault { +class FusedIrEmitter { public: using IndexedGenerator = llvm_ir::ElementGenerator; - using NonIndexedGenerator = std::function<StatusOr<llvm::Value*>()>; - using GeneratorForOperandIrArrays = - std::function<std::vector<llvm_ir::IrArray>()>; - FusedIrEmitter(GeneratorForOperandIrArrays operand_arrays_generator, - ElementalIrEmitter* elemental_emitter, - llvm::Value* thread_id_x = nullptr, - llvm::Value* thread_id_y = nullptr, - absl::Span<llvm::Value* const> param_shmem_buffers = {}) - : operand_arrays_(), - operand_arrays_generator_(std::move(operand_arrays_generator)), - thread_id_x_(thread_id_x), - thread_id_y_(thread_id_y), - param_shmem_buffers_(param_shmem_buffers.begin(), - param_shmem_buffers.end()), - elemental_emitter_(elemental_emitter), + explicit FusedIrEmitter(ElementalIrEmitter* elemental_emitter) + : elemental_emitter_(elemental_emitter), b_(elemental_emitter->b()), module_(elemental_emitter->module()) {} - Status DefaultAction(const HloInstruction* hlo) override; - - Status HandleConstant(const HloInstruction* constant) override; - - Status HandleGetTupleElement( - const HloInstruction* get_tuple_element) override; - - Status HandleParameter(const HloInstruction* parameter) override; - - // Emits the ir value for each element in the tuple. - Status HandleTuple(const HloInstruction* tuple) override; - - Status FinishVisit(const HloInstruction* root) override; - - // Returns the generator function for the root of the fused computation. - IndexedGenerator GetRootGenerator() const; + void BindGenerator(const HloInstruction* hlo, + llvm_ir::ElementGenerator generator) { + indexed_generators_[hlo] = std::move(generator); + } // Returns the generator function for the given instruction. - IndexedGenerator GetGenerator(const HloInstruction* instruction) const; + StatusOr<IndexedGenerator> GetGenerator(const HloInstruction* instruction); // Evaluates whether fusing 'producer' into 'consumer' might cause exponential // behavior in FusedIrEmitter. We currently can have exponential time/memory @@ -101,40 +75,20 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault { static bool IsFusedIrEmitterInefficient(const HloInstruction* consumer, const HloInstruction* producer); - protected: - // Returns the IrArrays for the fusion instruction operands. - llvm_ir::IrArray& GetIrArrayForFusedParameter(int64 parameter_number) { - if (!operand_arrays_.has_value()) { - operand_arrays_ = operand_arrays_generator_(); - } - return operand_arrays_.value()[parameter_number]; - } - - llvm::Value* GetBasePointerForFusedParameter(int64 parameter_number) { - return GetIrArrayForFusedParameter(parameter_number).GetBasePointer(); - } - private: - // IrArrays for the fusion instruction operands, whose base addresses are the - // base address of the corresponding parameters in the fused computation. - absl::optional<std::vector<llvm_ir::IrArray>> operand_arrays_; - GeneratorForOperandIrArrays operand_arrays_generator_; + Status DefaultAction(const HloInstruction* hlo); - // The x coordinate within a tile. - llvm::Value* thread_id_x_; + Status HandleConstant(const HloInstruction* constant); - // The y coordinate within a tile. - llvm::Value* thread_id_y_; + Status HandleGetTupleElement(const HloInstruction* get_tuple_element); - // Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr - // if the parameter is not tiled. - std::vector<llvm::Value*> param_shmem_buffers_; + Status HandleParameter(const HloInstruction* parameter); + + // Emits the ir value for each element in the tuple. + Status HandleTuple(const HloInstruction* tuple); ElementalIrEmitter* elemental_emitter_; - // This member will be set by FinishVisit and used in GetRootGenerator. - const HloInstruction* fused_root_ = nullptr; - // Borrowed llvm::IRBuilder<>* b_; llvm::Module* module_; @@ -145,12 +99,6 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault { std::unordered_map<const HloInstruction*, IndexedGenerator> indexed_generators_; - // Map from tuple-result-producing GetTupleELement instructions to functions - // that generate the base pointers for the output elements. This is used to - // support the translation of nested GetTupleElement instructions. - std::unordered_map<const HloInstruction*, NonIndexedGenerator> - non_indexed_generators_; - // Cache of generated values, lest we regenerate an element of a node with // multiple outgoing edges absl::flat_hash_map< diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index 9df83e30ad4..fe27a8c6963 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -521,8 +521,7 @@ XLA_TEST_F(ConcatTest, ConcatDeeplyNested) { ComputeAndCompareR1<float>(&builder, expected, {a_data.get()}); } -// TODO(b/169314478): Enable the test when the slow compilation is fixed. -XLA_TEST_F(ConcatTestHlo, DISABLED_ConcatWithBitcast) { +XLA_TEST_F(ConcatTestHlo, ConcatWithBitcast) { auto module = ParseAndReturnVerifiedModule(R"( HloModule jit_broken.874 @@ -762,7 +761,7 @@ ENTRY jit_broken.874 { auto input_array = absl::make_unique<Array2D<float>>(4, 2); input_array->FillUnique(1.0f); auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array); - EXPECT_TRUE(RunAndCompare(std::move(module), {&input}, absl::nullopt)); + EXPECT_TRUE(RunAndCompare(std::move(module), {&input}, error_spec_)); } // Describes a binary rank-2 concatenation test. diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc index fe48b3f6079..a2cfce1111c 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.cc +++ b/tensorflow/core/common_runtime/base_collective_executor.cc @@ -354,8 +354,11 @@ void BaseCollectiveExecutor::CompleteParamsAsync( // timeout callback executes, done_safe will become a no-op and the timeout // callback is responsible for invoking done() at the end. const auto is_callback_called = std::make_shared<std::atomic<bool>>(false); - auto done_safe = [this, is_callback_called, cancel_mgr, + auto trace_id = + profiler::TraceMe::ActivityStart("CollectiveExecutor::CompleteParams"); + auto done_safe = [this, is_callback_called, cancel_mgr, trace_id, done](const Status& s) { + profiler::TraceMe::ActivityEnd(trace_id); bool called = is_callback_called->exchange(true); if (!called) { if (!s.ok() && !IsCancelled(cancel_mgr)) { diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index 613449f572e..aceacf40132 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -2587,11 +2587,9 @@ TEST(DirectSessionTest, // A simple benchmark for the overhead of `DirectSession::Run()` calls // with varying numbers of feeds/fetches. -void FeedFetchBenchmarkHelper(int iters, int num_feeds, bool use_make_callable, - int inter_op_threads, +void FeedFetchBenchmarkHelper(::testing::benchmark::State& state, int num_feeds, + bool use_make_callable, int inter_op_threads, bool use_single_threaded_executor) { - testing::StopTiming(); - Tensor value(DT_FLOAT, TensorShape()); value.flat<float>()(0) = 37.0; @@ -2643,13 +2641,11 @@ void FeedFetchBenchmarkHelper(int iters, int num_feeds, bool use_make_callable, } TF_CHECK_OK(session->MakeCallable(callable_options, &handle)); - testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { std::vector<Tensor> output_values; TF_CHECK_OK( session->RunCallable(handle, input_tensors, &output_values, nullptr)); } - testing::StopTiming(); } else { { // NOTE(mrry): Ignore the first run, which will incur the graph @@ -2661,32 +2657,40 @@ void FeedFetchBenchmarkHelper(int iters, int num_feeds, bool use_make_callable, std::vector<Tensor> output_values; TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values)); } - testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + + for (auto s : state) { std::vector<Tensor> output_values; TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values)); } - testing::StopTiming(); } } -void BM_FeedFetch(int iters, int num_feeds) { - FeedFetchBenchmarkHelper(iters, num_feeds, /* use_make_callable */ false, +void BM_FeedFetch(::testing::benchmark::State& state) { + const int num_feeds = state.range(0); + + FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ false, /* inter_op_threads */ 0, /* use_single_threaded_executor */ false); } -void BM_FeedFetchCallable(int iters, int num_feeds) { - FeedFetchBenchmarkHelper(iters, num_feeds, /* use_make_callable */ true, +void BM_FeedFetchCallable(::testing::benchmark::State& state) { + const int num_feeds = state.range(0); + + FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true, /* inter_op_threads */ 0, /* use_single_threaded_executor */ false); } -void BM_FeedFetchCallableSingleThread(int iters, int num_feeds) { - FeedFetchBenchmarkHelper(iters, num_feeds, /* use_make_callable */ true, +void BM_FeedFetchCallableSingleThread(::testing::benchmark::State& state) { + const int num_feeds = state.range(0); + + FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true, /* inter_op_threads */ -1, /* use_single_threaded_executor */ false); } -void BM_FeedFetchCallableSingleThreadExecutor(int iters, int num_feeds) { - FeedFetchBenchmarkHelper(iters, num_feeds, /* use_make_callable */ true, +void BM_FeedFetchCallableSingleThreadExecutor( + ::testing::benchmark::State& state) { + const int num_feeds = state.range(0); + + FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true, /* inter_op_threads */ -1, /* use_single_threaded_executor */ true); } diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index 6d1ecf64fcc..41ab54a91e9 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -378,6 +378,12 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) { } else if (!input_def.type_attr().empty() && !input_def.number_attr().empty()) { InferSingleTypeInputListAttrs(input_def, inputs_[start]->dtype, num_inputs); + } else if (!input_def.number_attr().empty()) { + if (inference_attrs_.find(input_def.number_attr()) == + inference_attrs_.end()) { + MutableAttrs()->Set(input_def.number_attr(), num_inputs); + inference_attrs_.insert(input_def.number_attr()); + } } else { return errors::InvalidArgument("Invalid input list definition"); } diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc index 33e85b25fb4..169dbb5fe4b 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc @@ -69,8 +69,8 @@ class TestEnv { Device* cpu_device_; }; -void BM_CreateGraph(int iters) { - for (int i = 0; i < iters; ++i) { +void BM_CreateGraph(::testing::benchmark::State& state) { + for (auto s : state) { Scope root = Scope::NewRootScope(); auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}}); auto M = ops::MatMul(root, C, C); @@ -79,8 +79,7 @@ void BM_CreateGraph(int iters) { } BENCHMARK(BM_CreateGraph); -void BM_RunGraph(int iters) { - tensorflow::testing::StopTiming(); +void BM_RunGraph(::testing::benchmark::State& state) { Scope root = Scope::NewRootScope(); auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}}); auto M = ops::MatMul(root, C, C); @@ -89,28 +88,24 @@ void BM_RunGraph(int iters) { opts.config.set_intra_op_parallelism_threads(1); ClientSession sess(root, opts); std::vector<Tensor> outputs; - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { outputs.clear(); TF_CHECK_OK(sess.Run({M}, &outputs)); } } BENCHMARK(BM_RunGraph); -void BM_CreateAndDestroySession(int iters) { - tensorflow::testing::StopTiming(); +void BM_CreateAndDestroySession(::testing::benchmark::State& state) { Scope root = Scope::NewRootScope(); auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}}); auto M = ops::MatMul(root, C, C); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { ClientSession sess(root); } } BENCHMARK(BM_CreateAndDestroySession); -void BM_KernelAndDeviceInit(int iters) { - tensorflow::testing::StopTiming(); +void BM_KernelAndDeviceInit(::testing::benchmark::State& state) { NodeDef ndef(AttrBuilder("MatMul") .Set("T", DT_FLOAT) .Set("transpose_a", false) @@ -120,15 +115,13 @@ void BM_KernelAndDeviceInit(int iters) { TestEnv env; KernelAndDeviceOp k(nullptr, false, env.function_library_runtime(), nullptr, nullptr, env.cpu_device()); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { TF_CHECK_OK(k.Init({}, ndef, nullptr)); } } BENCHMARK(BM_KernelAndDeviceInit); -void BM_KernelAndDeviceRun(int iters) { - tensorflow::testing::StopTiming(); +void BM_KernelAndDeviceRun(::testing::benchmark::State& state) { Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor()); gtl::InlinedVector<TensorValue, 4> inputs; inputs.push_back(TensorValue(&t)); @@ -145,8 +138,7 @@ void BM_KernelAndDeviceRun(int iters) { nullptr, env.cpu_device()); TF_CHECK_OK(k.Init({}, ndef, nullptr)); const EagerKernelArgs args(std::move(inputs)); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { TF_CHECK_OK(k.Run(nullptr, args, &outputs, nullptr, absl::nullopt)); } } diff --git a/tensorflow/core/common_runtime/executor_test.cc b/tensorflow/core/common_runtime/executor_test.cc index dd65b5dce1d..2d483451d8f 100644 --- a/tensorflow/core/common_runtime/executor_test.cc +++ b/tensorflow/core/common_runtime/executor_test.cc @@ -433,11 +433,10 @@ TEST_F(ExecutorTest, NoInputTensors) { // Create a graph that is 'depth' deep. At each level, fan-in and fan-out a // maximum of 'width' nodes. All nodes are no-ops and all dependencies are // control dependencies. -static void BM_executor(int iters, int width, int depth) { - testing::StopTiming(); -#ifdef PLATFORM_GOOGLE - BenchmarkUseRealTime(); -#endif // PLATFORM_GOOGLE +static void BM_executor(::testing::benchmark::State& state) { + const int width = state.range(0); + const int depth = state.range(1); + Graph* g = new Graph(OpRegistry::Global()); random::PhiloxRandom philox(1729, 17); random::SimplePhilox rand(&philox); @@ -466,30 +465,29 @@ static void BM_executor(int iters, int width, int depth) { ++cur; } } -#ifdef PLATFORM_GOOGLE - SetBenchmarkLabel(strings::StrCat("Nodes = ", cur)); - SetBenchmarkItemsProcessed(cur * static_cast<int64>(iters)); -#endif // PLATFORM_GOOGLE + FixupSourceAndSinkEdges(g); - testing::StartTiming(); - test::Benchmark("cpu", g).Run(iters); + test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state); + + state.SetLabel(strings::StrCat("Nodes = ", cur)); + state.SetItemsProcessed(cur * static_cast<int64>(state.iterations())); } // Tall skinny graphs -BENCHMARK(BM_executor)->ArgPair(16, 1024); -BENCHMARK(BM_executor)->ArgPair(32, 8192); +BENCHMARK(BM_executor)->UseRealTime()->ArgPair(16, 1024); +BENCHMARK(BM_executor)->UseRealTime()->ArgPair(32, 8192); // Short fat graphs -BENCHMARK(BM_executor)->ArgPair(1024, 16); -BENCHMARK(BM_executor)->ArgPair(8192, 32); +BENCHMARK(BM_executor)->UseRealTime()->ArgPair(1024, 16); +BENCHMARK(BM_executor)->UseRealTime()->ArgPair(8192, 32); // Tall fat graph -BENCHMARK(BM_executor)->ArgPair(1024, 1024); +BENCHMARK(BM_executor)->UseRealTime()->ArgPair(1024, 1024); + +static void BM_const_identity(::testing::benchmark::State& state) { + const int width = state.range(0); + const int outputs_per_const = state.range(1); -static void BM_const_identity(int iters, int width, int outputs_per_const) { -#ifdef PLATFORM_GOOGL - BenchmarkUseRealTime(); -#endif // PLATFORM_GOOGLE Graph* g = new Graph(OpRegistry::Global()); for (int i = 0; i < width; ++i) { Tensor i_t(i); @@ -499,23 +497,21 @@ static void BM_const_identity(int iters, int width, int outputs_per_const) { } } FixupSourceAndSinkEdges(g); -#ifdef PLATFORM_GOOGLE - SetBenchmarkLabel( - strings::StrCat("Nodes = ", (1 + outputs_per_const) * width)); - SetBenchmarkItemsProcessed((1 + outputs_per_const) * width * - static_cast<int64>(iters)); -#endif // PLATFORM_GOOGLE - test::Benchmark("cpu", g).Run(iters); + test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state); + state.SetLabel(strings::StrCat("Nodes = ", (1 + outputs_per_const) * width)); + state.SetItemsProcessed((1 + outputs_per_const) * width * + static_cast<int64>(state.iterations())); } // Graph with actual op execution. -BENCHMARK(BM_const_identity)->ArgPair(1, 1); -BENCHMARK(BM_const_identity)->ArgPair(1, 100); -BENCHMARK(BM_const_identity)->ArgPair(100, 1); -BENCHMARK(BM_const_identity)->ArgPair(100, 100); +BENCHMARK(BM_const_identity) + ->UseRealTime() + ->ArgPair(1, 1) + ->ArgPair(1, 100) + ->ArgPair(100, 1) + ->ArgPair(100, 100); -static void BM_FeedInputFetchOutput(int iters) { - testing::StopTiming(); +static void BM_FeedInputFetchOutput(::testing::benchmark::State& state) { Graph* g = new Graph(OpRegistry::Global()); // z = x + y: x and y are provided as benchmark inputs. z is the // output of the benchmark. Conceptually, the caller is ALICE, the @@ -531,13 +527,10 @@ static void BM_FeedInputFetchOutput(int iters) { Tensor val(DT_FLOAT, TensorShape({})); val.scalar<float>()() = 3.14; -#ifdef PLATFORM_GOOGLE - SetBenchmarkItemsProcessed(static_cast<int64>(iters)); -#endif // PLATFORM_GOOGLE FixupSourceAndSinkEdges(g); - testing::StartTiming(); - test::Benchmark("cpu", g).RunWithRendezvousArgs({{x_key, val}, {y_key, val}}, - {z_key}, iters); + test::Benchmark("cpu", g, /*old_benchmark_api=*/false) + .RunWithRendezvousArgs({{x_key, val}, {y_key, val}}, {z_key}, state); + state.SetItemsProcessed(static_cast<int64>(state.iterations())); } BENCHMARK(BM_FeedInputFetchOutput); @@ -549,9 +542,8 @@ BENCHMARK(BM_FeedInputFetchOutput); // // ...using the functional `WhileOp` (if `lower` is false) or the // `Switch`/`Merge`-style of control flow (if `lower` is true). -static void BM_WhileLoopHelper(int iters, int loop_iters, int loop_vars, - bool lower) { - testing::StopTiming(); +static void BM_WhileLoopHelper(::testing::benchmark::State& state, + int loop_iters, int loop_vars, bool lower) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); // Add test functions for cond and body. @@ -661,12 +653,15 @@ static void BM_WhileLoopHelper(int iters, int loop_iters, int loop_vars, } FixupSourceAndSinkEdges(graph.get()); - testing::StartTiming(); - test::Benchmark("cpu", graph.release()).Run(iters); + test::Benchmark("cpu", graph.release(), /*old_benchmark_api=*/false) + .Run(state); } -static void BM_LoweredWhileLoop(int iters, int loop_iters, int loop_vars) { - BM_WhileLoopHelper(iters, loop_iters, loop_vars, /* lower= */ true); +static void BM_LoweredWhileLoop(::testing::benchmark::State& state) { + const int loop_iters = state.range(0); + const int loop_vars = state.range(1); + + BM_WhileLoopHelper(state, loop_iters, loop_vars, /* lower= */ true); } BENCHMARK(BM_LoweredWhileLoop) ->ArgPair(0, 1) @@ -680,8 +675,11 @@ BENCHMARK(BM_LoweredWhileLoop) ->ArgPair(100, 100) ->ArgPair(1000, 100); -static void BM_FunctionalWhileLoop(int iters, int loop_iters, int loop_vars) { - BM_WhileLoopHelper(iters, loop_iters, loop_vars, /* lower= */ false); +static void BM_FunctionalWhileLoop(::testing::benchmark::State& state) { + const int loop_iters = state.range(0); + const int loop_vars = state.range(1); + + BM_WhileLoopHelper(state, loop_iters, loop_vars, /* lower= */ false); } BENCHMARK(BM_FunctionalWhileLoop) ->ArgPair(0, 1) diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc index 07bb8e3eeea..d50b60cf899 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc @@ -931,7 +931,6 @@ TEST(SessionTest, InvalidOpInputName) { attr { key: 'T' value { type: DT_FLOAT } } attr { key: 'transpose_a' value { b: false } } attr { key: 'transpose_b' value { b: false } } - attr { key: '_kernel' value { s: 'eigen' } } } )", "Illegal op input name"); @@ -950,7 +949,6 @@ TEST(SessionTest, InvalidOpInputName) { attr { key: 'T' value { type: DT_FLOAT } } attr { key: 'transpose_a' value { b: false } } attr { key: 'transpose_b' value { b: false } } - attr { key: '_kernel' value { s: 'eigen' } } } )", "Illegal op input name"); @@ -969,7 +967,6 @@ TEST(SessionTest, InvalidOpInputName) { attr { key: 'T' value { type: DT_FLOAT } } attr { key: 'transpose_a' value { b: false } } attr { key: 'transpose_b' value { b: false } } - attr { key: '_kernel' value { s: 'eigen' } } } )", "Illegal op input name"); @@ -988,7 +985,6 @@ TEST(SessionTest, InvalidOpInputName) { attr { key: 'T' value { type: DT_FLOAT } } attr { key: 'transpose_a' value { b: false } } attr { key: 'transpose_b' value { b: false } } - attr { key: '_kernel' value { s: 'eigen' } } } )", "Illegal op input name"); @@ -1026,7 +1022,6 @@ TEST(SessionTest, ExtendValidation) { attr { key: 'T' value { type: DT_FLOAT } } attr { key: 'transpose_a' value { b: false } } attr { key: 'transpose_b' value { b: false } } - attr { key: '_kernel' value { s: 'eigen' } } } )", &extension); @@ -1043,7 +1038,6 @@ TEST(SessionTest, ExtendValidation) { attr { key: 'T' value { type: DT_FLOAT } } attr { key: 'transpose_a' value { b: false } } attr { key: 'transpose_b' value { b: false } } - attr { key: '_kernel' value { s: 'eigen' } } } )", &extension); @@ -1057,7 +1051,6 @@ TEST(SessionTest, ExtendValidation) { attr { key: 'T' value { type: DT_FLOAT } } attr { key: 'transpose_a' value { b: false } } attr { key: 'transpose_b' value { b: false } } - attr { key: '_kernel' value { s: 'eigen' } } } )", &extension); diff --git a/tensorflow/core/framework/allocator_test.cc b/tensorflow/core/framework/allocator_test.cc index 0ac3da1a19c..f1001e7ab24 100644 --- a/tensorflow/core/framework/allocator_test.cc +++ b/tensorflow/core/framework/allocator_test.cc @@ -221,14 +221,16 @@ TEST(CustomAllocatorAttributes, TestSetterAndGetter) { EXPECT_FALSE(HasDeviceAllocatorAttribute(AllocatorAttributes())); } -static void BM_Allocation(int iters, int arg) { +static void BM_Allocation(::testing::benchmark::State& state) { + const int arg = state.range(0); + Allocator* a = cpu_allocator(); // Exercise a few different allocation sizes std::vector<int> sizes = {256, 4096, 16384, 524288, 512, 1048576}; int size_index = 0; if (arg) EnableCPUAllocatorStats(); - while (--iters > 0) { + for (auto s : state) { int bytes = sizes[size_index++ % sizes.size()]; void* p = a->AllocateRaw(1, bytes); a->DeallocateRaw(p); diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc index 0de298cfce8..bed21ddcc99 100644 --- a/tensorflow/core/framework/bfloat16_test.cc +++ b/tensorflow/core/framework/bfloat16_test.cc @@ -39,60 +39,60 @@ TEST(Bfloat16Test, Conversion) { } } -static void BM_FloatToBFloat16(int iters) { - testing::StopTiming(); +void BM_FloatToBFloat16(::testing::benchmark::State& state) { static const int N = 32 << 20; - const int64 tot = static_cast<int64>(iters) * N; - testing::ItemsProcessed(tot); - testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16))); float* inp = new float[N]; bfloat16* out = new bfloat16[N]; - testing::StartTiming(); - while (iters--) { + for (auto s : state) { FloatToBFloat16(inp, out, N); } + + const int64 tot = static_cast<int64>(state.iterations()) * N; + state.SetItemsProcessed(tot); + state.SetBytesProcessed(tot * (sizeof(float) + sizeof(bfloat16))); + delete[] inp; delete[] out; } BENCHMARK(BM_FloatToBFloat16); -static void BM_RoundFloatToBFloat16(int iters) { - testing::StopTiming(); +void BM_RoundFloatToBFloat16(::testing::benchmark::State& state) { static const int N = 32 << 20; - const int64 tot = static_cast<int64>(iters) * N; - testing::ItemsProcessed(tot); - testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16))); float* inp = new float[N]; bfloat16* out = new bfloat16[N]; - testing::StartTiming(); - while (iters--) { + for (auto s : state) { RoundFloatToBFloat16(inp, out, N); tensorflow::testing::DoNotOptimize(inp); tensorflow::testing::DoNotOptimize(out); } + + const int64 tot = static_cast<int64>(state.iterations()) * N; + state.SetItemsProcessed(tot); + state.SetBytesProcessed(tot * (sizeof(float) + sizeof(bfloat16))); + delete[] inp; delete[] out; } BENCHMARK(BM_RoundFloatToBFloat16); -static void BM_BFloat16ToFloat(int iters) { - testing::StopTiming(); +void BM_BFloat16ToFloat(::testing::benchmark::State& state) { static const int N = 32 << 20; - const int64 tot = static_cast<int64>(iters) * N; - testing::ItemsProcessed(tot); - testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16))); bfloat16* inp = new bfloat16[N]; float* out = new float[N]; - testing::StartTiming(); - while (iters--) { + for (auto s : state) { BFloat16ToFloat(inp, out, N); } + + const int64 tot = static_cast<int64>(state.iterations()) * N; + state.SetItemsProcessed(tot); + state.SetBytesProcessed(tot * (sizeof(float) + sizeof(bfloat16))); + delete[] inp; delete[] out; } diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc index 38ab8be291d..cafe343ef3a 100644 --- a/tensorflow/core/framework/function_test.cc +++ b/tensorflow/core/framework/function_test.cc @@ -406,7 +406,7 @@ XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) { TEST(TFunc, WXPlusB) { auto expect = R"P( WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) { - mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x) + mm = MatMul[T=$T, transpose_a=false, transpose_b=false](w, x) y = Add[T=$T](mm:product:0, b) return y = y:z:0 } diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc index 5919ed7831b..dcda948c083 100644 --- a/tensorflow/core/framework/function_testlib.cc +++ b/tensorflow/core/framework/function_testlib.cc @@ -346,10 +346,7 @@ FunctionDef WXPlusB() { {{{"mm"}, "MatMul", {"w", "x"}, - {{"T", "$T"}, - {"transpose_a", false}, - {"transpose_b", false}, - {"_kernel", "eigen"}}}, + {{"T", "$T"}, {"transpose_a", false}, {"transpose_b", false}}}, {{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}}); } diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index 9b5648927d1..f11f85df8de 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -1002,9 +1002,9 @@ TEST_F(LabelTest, Duplicate) { error::INVALID_ARGUMENT); } -void BM_InputRangeHelper(int iters, const NodeDef& node_def, - const char* input_name, int expected_start, - int expected_stop) { +void BM_InputRangeHelper(::testing::benchmark::State& state, + const NodeDef& node_def, const char* input_name, + int expected_start, int expected_stop) { Status status; auto device = absl::make_unique<DummyDevice>(Env::Default()); @@ -1013,24 +1013,20 @@ void BM_InputRangeHelper(int iters, const NodeDef& node_def, TF_GRAPH_DEF_VERSION, &status)); TF_CHECK_OK(status); - testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { int start; int stop; TF_CHECK_OK(op->InputRange(input_name, &start, &stop)); EXPECT_EQ(expected_start, start); EXPECT_EQ(expected_stop, stop); } - testing::StopTiming(); } REGISTER_KERNEL_BUILDER(Name("ConcatV2").Device(DEVICE_CPU), DummyKernel); REGISTER_KERNEL_BUILDER(Name("Select").Device(DEVICE_CPU), DummyKernel); REGISTER_KERNEL_BUILDER(Name("MatMul").Device(DEVICE_CPU), DummyKernel); -void BM_ConcatInputRange(int iters) { - testing::StopTiming(); - +void BM_ConcatInputRange(::testing::benchmark::State& state) { // Create a ConcatV2 NodeDef with 4 inputs (plus the axis). NodeDef node_def; node_def.set_name("concat-op"); @@ -1048,12 +1044,10 @@ void BM_ConcatInputRange(int iters) { node_def.add_input(strings::StrCat("a:", i)); } - BM_InputRangeHelper(iters, node_def, "values", 0, 4); + BM_InputRangeHelper(state, node_def, "values", 0, 4); } -void BM_SelectInputRange(int iters) { - testing::StopTiming(); - +void BM_SelectInputRange(::testing::benchmark::State& state) { // Create a Select NodeDef with 3 inputs. NodeDef node_def; node_def.set_name("select-op"); @@ -1065,11 +1059,11 @@ void BM_SelectInputRange(int iters) { node_def.add_input(strings::StrCat("a:", i)); } - BM_InputRangeHelper(iters, node_def, "condition", 0, 1); + BM_InputRangeHelper(state, node_def, "condition", 0, 1); } -void BM_TraceString(const int iters, const int verbose) { - testing::StopTiming(); +void BM_TraceString(::testing::benchmark::State& state) { + const int verbose = state.range(0); // Create a MatMul NodeDef with 2 inputs. NodeDef node_def; @@ -1103,11 +1097,9 @@ void BM_TraceString(const int iters, const int verbose) { params.inputs = &inputs; auto ctx = absl::make_unique<OpKernelContext>(¶ms); - testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { auto trace = op->TraceString(*ctx, verbose); } - testing::StopTiming(); } BENCHMARK(BM_ConcatInputRange); diff --git a/tensorflow/core/framework/rendezvous_test.cc b/tensorflow/core/framework/rendezvous_test.cc index d02d090f32b..5b8b9ff79e8 100644 --- a/tensorflow/core/framework/rendezvous_test.cc +++ b/tensorflow/core/framework/rendezvous_test.cc @@ -434,83 +434,89 @@ TEST_F(LocalRendezvousTest, TransferDummyDeviceContext) { args1.device_context->Unref(); } -void BM_SendRecv(int iters) { +void BM_SendRecv(::testing::benchmark::State& state) { Rendezvous* rendez = NewLocalRendezvous(); Tensor orig = V("val"); Tensor val(DT_STRING, TensorShape({})); bool is_dead = false; Rendezvous::Args args; - if (iters > 0) { - while (iters--) { - TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead)); - TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &val, &is_dead)); - } - CHECK_EQ(V(val), V(orig)); + + for (auto s : state) { + TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead)); + TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &val, &is_dead)); } + CHECK_EQ(V(val), V(orig)); + rendez->Unref(); } BENCHMARK(BM_SendRecv); -void BM_RecvSend(int iters) { +void BM_RecvSend(::testing::benchmark::State& state) { Rendezvous* rendez = NewLocalRendezvous(); Tensor orig = V("val"); Tensor val(DT_STRING, TensorShape({})); bool is_dead = false; Rendezvous::Args args; - if (iters > 0) { - while (iters--) { - bool received = false; - rendez->RecvAsync( - KeyFoo(), args, - [&val, &received](const Status& s, const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args, - const Tensor& tensor, bool is_dead) { - val = tensor; - received = true; - }); - TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead)); - CHECK(received); - } - CHECK_EQ(V(val), V(orig)); + + for (auto s : state) { + bool received = false; + rendez->RecvAsync( + KeyFoo(), args, + [&val, &received](const Status& /*s*/, + const Rendezvous::Args& /*send_args*/, + const Rendezvous::Args& /*recv_args*/, + const Tensor& tensor, bool /*is_dead*/) { + val = tensor; + received = true; + }); + TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead)); + CHECK(received); } + CHECK_EQ(V(val), V(orig)); + rendez->Unref(); } BENCHMARK(BM_RecvSend); -void BM_PingPong(int iters) { - CHECK_GT(iters, 0); +void BM_PingPong(::testing::benchmark::State& state) { + const int messages_count = state.range(0); auto* cm = new CancellationManager(); thread::ThreadPool* pool = new thread::ThreadPool(Env::Default(), "test", 1); - // The main thread sends "foo" for iters times and receives "bar" - // for iters times. The other thread sends "bar" for iters times - // and receives "foo" for iters times. - Rendezvous* rendez = NewLocalRendezvous(); - pool->Schedule([rendez, iters]() { - Tensor bar = V("bar"); - Tensor foo(DT_STRING, TensorShape({})); + // Benchmark loop + // In each iteration: + // The main thread sends "foo" for messages_count times and receives "bar" + // for messages_count times. The other thread sends "bar" for + // messages_count times and receives "foo" for messages_count times. + for (auto s : state) { + Rendezvous* rendez = NewLocalRendezvous(); + pool->Schedule([rendez, messages_count]() { + Tensor bar = V("bar"); + Tensor foo(DT_STRING, TensorShape({})); + bool is_dead = false; + Rendezvous::Args args; + for (int i = 0; i < messages_count; ++i) { + TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &foo, &is_dead)); + TF_CHECK_OK(rendez->Send(KeyBar(), args, bar, is_dead)); + } + CHECK_EQ("foo", V(foo)); + }); + Tensor foo = V("foo"); + Tensor bar(DT_STRING, TensorShape({})); bool is_dead = false; Rendezvous::Args args; - for (int i = 0; i < iters; ++i) { - TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &foo, &is_dead)); - TF_CHECK_OK(rendez->Send(KeyBar(), args, bar, is_dead)); + args.cancellation_manager = cm; + for (int i = 0; i < messages_count; ++i) { + TF_CHECK_OK(rendez->Send(KeyFoo(), args, foo, is_dead)); + TF_CHECK_OK(rendez->Recv(KeyBar(), args, &bar, &is_dead)); } - CHECK_EQ("foo", V(foo)); - }); - Tensor foo = V("foo"); - Tensor bar(DT_STRING, TensorShape({})); - bool is_dead = false; - Rendezvous::Args args; - args.cancellation_manager = cm; - for (int i = 0; i < iters; ++i) { - TF_CHECK_OK(rendez->Send(KeyFoo(), args, foo, is_dead)); - TF_CHECK_OK(rendez->Recv(KeyBar(), args, &bar, &is_dead)); + CHECK_EQ("bar", V(bar)); } - CHECK_EQ("bar", V(bar)); + state.SetItemsProcessed(messages_count * state.iterations()); delete pool; delete cm; } -BENCHMARK(BM_PingPong); +BENCHMARK(BM_PingPong)->Arg(100)->Arg(200)->Arg(300); } // namespace } // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_shape_test.cc b/tensorflow/core/framework/tensor_shape_test.cc index ea93009ef40..076a5a53d76 100644 --- a/tensorflow/core/framework/tensor_shape_test.cc +++ b/tensorflow/core/framework/tensor_shape_test.cc @@ -684,19 +684,24 @@ static std::vector<int64> MakeSizes(int arg) { return sizes; } -static void BM_TensorShape_Init(int iters, int arg) { +void BM_TensorShape_Init(::testing::benchmark::State& state) { + const int arg = state.range(0); + auto sizes = MakeSizes(arg); - while (--iters > 0) { + for (auto s : state) { TensorShape shape(sizes); tensorflow::testing::DoNotOptimize(shape.num_elements()); } } BENCHMARK(BM_TensorShape_Init)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4); -static void BM_TensorShape_Assign(int iters, int arg) { - TensorShape s(MakeSizes(arg)); - while (--iters > 0) { - TensorShape s2 = s; +void BM_TensorShape_Assign(::testing::benchmark::State& state) { + const int arg = state.range(0); + + TensorShape shape(MakeSizes(arg)); + for (auto s : state) { + const TensorShape s2 = shape; + tensorflow::testing::DoNotOptimize(s2); } } BENCHMARK(BM_TensorShape_Assign)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4); diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc index fdfeef9e84a..daaba4bc3a5 100644 --- a/tensorflow/core/framework/tensor_test.cc +++ b/tensorflow/core/framework/tensor_test.cc @@ -1468,19 +1468,19 @@ TEST(SummarizeValue, STRING_PRINT_V2) { x.SummarizeValue(16, true)); } -void BM_CreateAndDestroy(int iters) { +void BM_CreateAndDestroy(::testing::benchmark::State& state) { TensorShape shape({10, 20}); - while (--iters) { + for (auto s : state) { Tensor t(DT_FLOAT, shape); } } BENCHMARK(BM_CreateAndDestroy); -void BM_Assign(int iters) { +void BM_Assign(::testing::benchmark::State& state) { Tensor a(DT_FLOAT, TensorShape({10, 20})); Tensor b(DT_FLOAT, TensorShape({10, 20})); bool a_to_b = true; - while (--iters) { + for (auto s : state) { if (a_to_b) { b = a; } else { @@ -1498,20 +1498,20 @@ TEST(Tensor, EmptyTensorData) { } // Benchmark create and destroy a tensor, with an allocated buffer. -void BM_CreateAndDestroyWithBuf(int iters) { +void BM_CreateAndDestroyWithBuf(::testing::benchmark::State& state) { TensorShape shape({10, 20}); Allocator* allocator = cpu_allocator(); - while (--iters) { + for (auto s : state) { Tensor a(allocator, DT_FLOAT, shape); } } BENCHMARK(BM_CreateAndDestroyWithBuf); // Benchmark create+copy a tensor, with an allocated buffer. -void BM_CreateAndCopyCtrWithBuf(int iters) { +void BM_CreateAndCopyCtrWithBuf(::testing::benchmark::State& state) { TensorShape shape({10, 20}); Allocator* allocator = cpu_allocator(); - while (--iters) { + for (auto s : state) { Tensor a(allocator, DT_FLOAT, shape); Tensor b(a); } @@ -1519,10 +1519,10 @@ void BM_CreateAndCopyCtrWithBuf(int iters) { BENCHMARK(BM_CreateAndCopyCtrWithBuf); // Benchmark create+move a tensor, with an allocated buffer. -void BM_CreateAndMoveCtrWithBuf(int iters) { +void BM_CreateAndMoveCtrWithBuf(::testing::benchmark::State& state) { TensorShape shape({10, 20}); Allocator* allocator = cpu_allocator(); - while (--iters) { + for (auto s : state) { Tensor a(allocator, DT_FLOAT, shape); Tensor b(std::move(a)); } @@ -1531,10 +1531,11 @@ BENCHMARK(BM_CreateAndMoveCtrWithBuf); // Benchmark creating and destroy a host-scalar tensor, using the allocator // interface. -void BM_CreateAndDestroyHostScalarNonOptimized(int iters) { +void BM_CreateAndDestroyHostScalarNonOptimized( + ::testing::benchmark::State& state) { TensorShape shape({}); Allocator* allocator = cpu_allocator(); - while (--iters) { + for (auto s : state) { Tensor a(allocator, DT_FLOAT, shape); a.scalar<float>()() = 37.0; } @@ -1543,32 +1544,33 @@ BENCHMARK(BM_CreateAndDestroyHostScalarNonOptimized); // Benchmark creating and destroy a host-scalar tensor, using the specialized // constructor. -void BM_CreateAndDestroyHostScalarOptimized(int iters) { - while (--iters) { +void BM_CreateAndDestroyHostScalarOptimized( + ::testing::benchmark::State& state) { + for (auto s : state) { Tensor a(37.0); } } BENCHMARK(BM_CreateAndDestroyHostScalarOptimized); -static void BM_FromProto(int iters, int size) { - testing::StopTiming(); +void BM_FromProto(::testing::benchmark::State& state) { + const int size = state.range(0); + TensorShape shape({size}); Allocator* allocator = cpu_allocator(); Tensor a(allocator, DT_FLOAT, shape); std::fill_n(a.flat<float>().data(), size, 42.0); TensorProto p; a.AsProtoField(&p); - testing::StartTiming(); - while (--iters) { + for (auto s : state) { Tensor b; ASSERT_TRUE(b.FromProto(p)); } - testing::StopTiming(); } BENCHMARK(BM_FromProto)->Range(1, 1 << 20); -static void BM_FromProtoCompressed(int iters, int size) { - testing::StopTiming(); +void BM_FromProtoCompressed(::testing::benchmark::State& state) { + const int size = state.range(0); + TensorShape shape({size}); Allocator* allocator = cpu_allocator(); Tensor a(allocator, DT_FLOAT, shape); @@ -1576,17 +1578,16 @@ static void BM_FromProtoCompressed(int iters, int size) { TensorProto p; a.AsProtoField(&p); tensor::CompressTensorProtoInPlace(&p); - testing::StartTiming(); - while (--iters) { + for (auto s : state) { Tensor b; ASSERT_TRUE(b.FromProto(p)); } - testing::StopTiming(); } BENCHMARK(BM_FromProtoCompressed)->Range(1, 1 << 20); -static void BM_FromProtoCompressedZero(int iters, int size) { - testing::StopTiming(); +void BM_FromProtoCompressedZero(::testing::benchmark::State& state) { + const int size = state.range(0); + TensorShape shape({size}); Allocator* allocator = cpu_allocator(); Tensor a(allocator, DT_FLOAT, shape); @@ -1595,12 +1596,10 @@ static void BM_FromProtoCompressedZero(int iters, int size) { TensorProto p; a.AsProtoField(&p); tensor::CompressTensorProtoInPlace(&p); - testing::StartTiming(); - while (--iters) { + for (auto s : state) { Tensor b; ASSERT_TRUE(b.FromProto(p)); } - testing::StopTiming(); } BENCHMARK(BM_FromProtoCompressedZero)->Range(1, 1 << 20); diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc index ab73616eb99..848d833d340 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc @@ -767,16 +767,49 @@ Status AvgPoolGradTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status BiasAddGradTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { - DCHECK(IsBiasAddGrad(*node->node())); - if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4)) { +Status BiasAddTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { + DCHECK(IsBiasAdd(*node->node())); + const int rank = GetFanoutPortRank(*node, 0); + if (rank != 4 && rank != 5) { + return Status::OK(); + } + if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) { return Status::OK(); } VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName() << "' with op '" << node->GetOp() << "' from data format '" << context->src_format << "' to '" << context->dst_format << "'"; + // BiasAdd itself only needs NCHW/NHWC to determine whether C dim is the + // second or the last dim. Therefore, we use the original 4D data format in + // the context to update the node. For the input/output tensor, the + // corresponding 4D or 5D data format is needed. TF_RETURN_IF_ERROR(UpdateNode(context, node)); + ScopedDataFormatUpgrader data_format_upgrader(context, rank); + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); + return context->graph_view->GetMutationBuilder()->Apply(); +} + +Status BiasAddGradTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { + DCHECK(IsBiasAddGrad(*node->node())); + const int rank = GetFaninPortRank(*node, 0); + if (rank != 4 && rank != 5) { + return Status::OK(); + } + if (!ShouldProcess(*context, *node)) { + return Status::OK(); + } + VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName() + << "' with op '" << node->GetOp() << "' from data format '" + << context->src_format << "' to '" << context->dst_format << "'"; + // BiasAddGrad itself only needs NCHW/NHWC to determine whether C dim is the + // second or the last dim. Therefore, we use the original 4D data format in + // the context to update the node. For the input tensor, the corresponding 4D + // or 5D data format is needed. + TF_RETURN_IF_ERROR(UpdateNode(context, node)); + ScopedDataFormatUpgrader data_format_upgrader(context, rank); TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); // No need to update output shape, as it is always of shape 1-D with size the // feature dimension of `out_backprop`, regardless of whether NCHW or NHWC is @@ -839,7 +872,12 @@ Status Conv2DBackpropInputTransposer::TransposeNode( Status Conv3DTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsConv3D(*node->node())); - if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 5)) { + const int rank = GetFanoutPortRank(*node, 0); + if (rank != 5) { + return Status::OK(); + } + ScopedDataFormatUpgrader data_format_upgrader(context, rank); + if (!ShouldProcess(*context, *node)) { return Status::OK(); } VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName() @@ -854,7 +892,12 @@ Status Conv3DTransposer::TransposeNode(TransposeContext* context, Status Conv3DBackpropFilterTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsConv3DBackpropFilterV2(*node->node())); - if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 5)) { + const int rank = GetFanoutPortRank(*node, 0); + if (rank != 5) { + return Status::OK(); + } + ScopedDataFormatUpgrader data_format_upgrader(context, rank); + if (!ShouldProcess(*context, *node)) { return Status::OK(); } VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName() @@ -872,7 +915,12 @@ Status Conv3DBackpropFilterTransposer::TransposeNode( Status Conv3DBackpropInputTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsConv3DBackpropInputV2(*node->node())); - if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 5)) { + const int rank = GetFanoutPortRank(*node, 0); + if (rank != 5) { + return Status::OK(); + } + ScopedDataFormatUpgrader data_format_upgrader(context, rank); + if (!ShouldProcess(*context, *node)) { return Status::OK(); } VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName() @@ -1081,8 +1129,9 @@ bool LayoutAgnosticOpTransposer::IsAfterDstToSrcTransform( return false; } -std::vector<int> LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts( - const TransposeContext& context, const utils::MutableNodeView& node) const { +std::vector<int> LayoutAgnosticOpTransposer::GetVariadicNDFaninPorts( + const TransposeContext& context, const utils::MutableNodeView& node, + int rank) const { std::vector<int> ports; const int num_regular_fanins = node.NumRegularFanins(); ports.reserve(num_regular_fanins); @@ -1090,7 +1139,7 @@ std::vector<int> LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts( const auto& regular_fanin = node.GetRegularFanin(i); auto* regular_fanin_node = regular_fanin.node_view(); int regular_fanin_port = regular_fanin.index(); - if (IsFanoutPortRankN(*regular_fanin_node, regular_fanin_port, 4) && + if ((IsFanoutPortRankN(*regular_fanin_node, regular_fanin_port, rank)) && ((IsAfterDstToSrcTransform(context, *regular_fanin_node) && IsLayoutAgnosticOp(*regular_fanin_node->node())) || IsLayoutOptimizerAddedDstToSrcTranspose(context, @@ -1124,10 +1173,18 @@ Status DefaultLayoutAgnosticOpTransposer::TransposeNode( Status AddNTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsAddN(*node->node())); - if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) || + const int rank = GetFanoutPortRank(*node, 0); + if (rank != 4 && rank != 5) { + return Status::OK(); + } + ScopedDataFormatUpgrader data_format_upgrader(context, rank); + if (!ShouldProcess(*context, *node) || !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } + VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName() + << "' with op '" << node->GetOp() << "' from data format '" + << context->src_format << "' to '" << context->dst_format << "'"; TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, GetDataFaninPorts(*node), node, kOpTranspose)); TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); @@ -1284,7 +1341,12 @@ Status BinaryOpTransposer::TransposeNode(TransposeContext* context, Status ConcatOpTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsConcat(*node->node())); - if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) || + const int rank = GetFanoutPortRank(*node, 0); + if (rank != 4 && rank != 5) { + return Status::OK(); + } + ScopedDataFormatUpgrader data_format_upgrader(context, rank); + if (!ShouldProcess(*context, *node) || !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } @@ -1297,6 +1359,9 @@ Status ConcatOpTransposer::TransposeNode(TransposeContext* context, axis_node = n_attr->i(); } } + VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName() + << "' with op '" << node->GetOp() << "' from data format '" + << context->src_format << "' to '" << context->dst_format << "'"; TF_RETURN_IF_ERROR( UpdateFaninEdgesWithOp(context, {axis_node}, node, kOpDataFormatDimMap)); TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); @@ -1320,14 +1385,33 @@ Status FillOpTransposer::TransposeNode(TransposeContext* context, Status IdentityNTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsIdentityN(*node->node())); - const auto ports = GetVariadic4DFaninPorts(*context, *node); - if (!ShouldProcess(*context, *node) || ports.empty()) { + const auto ports_4d = GetVariadicNDFaninPorts(*context, *node, 4); + + // Temporarily upgrade the context to obtain the number of 5D fanin ports. + std::vector<int> ports_5d; + { + ScopedDataFormatUpgrader data_format_upgrader(context, 5); + ports_5d = GetVariadicNDFaninPorts(*context, *node, 5); + } + + if (!ShouldProcess(*context, *node)) { return Status::OK(); } - TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, ports, node, kOpTranspose)); - TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, ports, node, kOpTranspose)); + + if (!ports_4d.empty()) { + TF_RETURN_IF_ERROR( + UpdateFaninEdgesWithOp(context, ports_4d, node, kOpTranspose)); + TF_RETURN_IF_ERROR( + UpdateFanoutEdgesWithOp(context, ports_4d, node, kOpTranspose)); + } + + if (!ports_5d.empty()) { + ScopedDataFormatUpgrader data_format_upgrader(context, 5); + TF_RETURN_IF_ERROR( + UpdateFaninEdgesWithOp(context, ports_5d, node, kOpTranspose)); + TF_RETURN_IF_ERROR( + UpdateFanoutEdgesWithOp(context, ports_5d, node, kOpTranspose)); + } return context->graph_view->GetMutationBuilder()->Apply(); } @@ -1519,10 +1603,18 @@ Status SelectTransposer::TransposeNode(TransposeContext* context, Status ShapeTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsShape(*node->node())); - if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4) || + const int rank = GetFaninPortRank(*node, 0); + if (rank != 4 && rank != 5) { + return Status::OK(); + } + ScopedDataFormatUpgrader data_format_upgrader(context, rank); + if (!ShouldProcess(*context, *node) || !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } + VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName() + << "' with op '" << node->GetOp() << "' from data format '" + << context->src_format << "' to '" << context->dst_format << "'"; TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); TF_RETURN_IF_ERROR( UpdateFanoutEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute)); @@ -1532,10 +1624,20 @@ Status ShapeTransposer::TransposeNode(TransposeContext* context, Status ShapeNTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsShapeN(*node->node())); - const auto ports = GetVariadic4DFaninPorts(*context, *node); + // ShapeN requires all input tensors to have the same dimensions. Therefore, + // we simply use the 0th fanin port. + const int rank = GetFaninPortRank(*node, 0); + if (rank != 4 && rank != 5) { + return Status::OK(); + } + ScopedDataFormatUpgrader data_format_upgrader(context, rank); + const auto ports = GetVariadicNDFaninPorts(*context, *node, rank); if (!ShouldProcess(*context, *node) || ports.empty()) { return Status::OK(); } + VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName() + << "' with op '" << node->GetOp() << "' from data format '" + << context->src_format << "' to '" << context->dst_format << "'"; TF_RETURN_IF_ERROR( UpdateFaninEdgesWithOp(context, ports, node, kOpTranspose)); TF_RETURN_IF_ERROR( @@ -1546,11 +1648,19 @@ Status ShapeNTransposer::TransposeNode(TransposeContext* context, Status SliceTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsSlice(*node->node())); - if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) || - !IsFaninPortsDimsNIfConst(*node, {1, 2}, {4}) || + const int rank = GetFanoutPortRank(*node, 0); + if (rank != 4 && rank != 5) { + return Status::OK(); + } + ScopedDataFormatUpgrader data_format_upgrader(context, rank); + if (!ShouldProcess(*context, *node) || + !IsFaninPortsDimsNIfConst(*node, {1, 2}, {rank}) || !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } + VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName() + << "' with op '" << node->GetOp() << "' from data format '" + << context->src_format << "' to '" << context->dst_format << "'"; TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); TF_RETURN_IF_ERROR( UpdateFaninEdgesWithOp(context, {1, 2}, node, kOpDataFormatVecPermute)); @@ -1839,18 +1949,17 @@ string GetDeviceName(const VirtualPlacer* virtual_placer, const NodeDef& node) { bool IsDefaultLayoutSensitiveOp(const NodeDef& node) { static absl::flat_hash_set<string>* default_layout_sensitive_ops = new absl::flat_hash_set<std::string>( - {"AvgPool", "BiasAdd", "Conv2D", "DepthwiseConv2dNative", - "DepthToSpace", "FusedBatchNorm", "FusedBatchNormV2", - "FusedBatchNormV3", "FusedConv2DBiasActivation", "MaxPool", - "SpaceToDepth"}); + {"AvgPool", "Conv2D", "DepthwiseConv2dNative", "DepthToSpace", + "FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3", + "FusedConv2DBiasActivation", "MaxPool", "SpaceToDepth"}); return default_layout_sensitive_ops->find(node.op()) != default_layout_sensitive_ops->end(); } bool IsLayoutSensitiveOp(const NodeDef& node) { return IsDefaultLayoutSensitiveOp(node) || IsAvgPoolGrad(node) || - IsBiasAddGrad(node) || IsConv2DBackpropFilter(node) || - IsConv2DBackpropInput(node) || + IsBiasAdd(node) || IsBiasAddGrad(node) || + IsConv2DBackpropFilter(node) || IsConv2DBackpropInput(node) || IsDepthwiseConv2dNativeBackpropFilter(node) || IsDepthwiseConv2dNativeBackpropInput(node) || IsFusedBatchNormEx(node) || IsFusedBatchNormGrad(node) || diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h index bfc67c0633d..1d8991c05dc 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h @@ -210,6 +210,14 @@ class DefaultLayoutSensitiveOpTransposer : public LayoutSensitiveOpTransposer { utils::MutableNodeView* node) override; }; +class BiasAddTransposer : public LayoutSensitiveOpTransposer { + public: + explicit BiasAddTransposer() : LayoutSensitiveOpTransposer() {} + + Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + class AvgPoolGradTransposer : public LayoutSensitiveOpTransposer { public: explicit AvgPoolGradTransposer() : LayoutSensitiveOpTransposer() {} @@ -319,9 +327,9 @@ class LayoutAgnosticOpTransposer : public Transposer { bool IsAfterDstToSrcTransform(const TransposeContext& context, const utils::MutableNodeView& node) const; - std::vector<int> GetVariadic4DFaninPorts( - const TransposeContext& context, - const utils::MutableNodeView& node) const; + std::vector<int> GetVariadicNDFaninPorts(const TransposeContext& context, + const utils::MutableNodeView& node, + int rank) const; }; class DefaultLayoutAgnosticOpTransposer : public LayoutAgnosticOpTransposer { diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_factory.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_factory.cc index 15bbc08079c..8d4f55a71e8 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_factory.cc +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_factory.cc @@ -27,6 +27,9 @@ std::shared_ptr<Transposer> TransposerFactory::GetTransposer( return GetOrCreateIfNotFound<DefaultLayoutSensitiveOpTransposer>( "DefaultLayoutSensitiveOp"); } + if (IsBiasAdd(node)) { + return GetOrCreateIfNotFound<BiasAddTransposer>("BiasAdd"); + } if (IsAvgPoolGrad(node)) { return GetOrCreateIfNotFound<AvgPoolGradTransposer>("AvgPoolGrad"); } diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 67fcd2174bf..c0a33232f8c 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -48,7 +48,6 @@ load( ) load( "//third_party/mkl:build_defs.bzl", - "if_mkl_ml", "mkl_deps", ) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") @@ -3241,7 +3240,6 @@ cc_library( deps = [ ":aggregate_ops", ":argmax_op", - ":batch_matmul_op", ":betainc_op", ":bincount_op", ":bucketize_op", @@ -3337,14 +3335,27 @@ tf_kernel_library( tf_kernel_library( name = "batch_matmul_op", + deps = [":matmul_op"], +) + +tf_kernel_library( + name = "matmul_op", # <prefix>*impl.h are excluded by default from the CPU build, add explicitly. - hdrs = ["batch_matmul_op_impl.h"], - prefix = "batch_matmul_op", - deps = MATH_DEPS + [":eigen_contraction_kernel"] + if_mkl_ml([ - "//third_party/mkl:intel_binary_blob", - ]) + if_cuda_or_rocm([ - "//tensorflow/core/kernels:gpu_utils", - ]), + hdrs = ["matmul_op_impl.h"], + defines = select({ + ":xsmm": ["TENSORFLOW_USE_LIBXSMM"], + "//conditions:default": [], + }), + prefix = "matmul_op", + deps = MATH_DEPS + [ + ":eigen_contraction_kernel", + ":fused_eigen_output_kernels", + ] + select({ + ":xsmm": ["@libxsmm_archive//:xsmm_avx"], + "//conditions:default": [], + }) + mkl_deps() + if_cuda([ + "//tensorflow/core/platform/default/build_config:cublas_plugin", + ]) + if_cuda_or_rocm([":gpu_utils"]), ) tf_kernel_library( @@ -3406,28 +3417,6 @@ tf_kernel_library( ]), ) -tf_kernel_library( - name = "matmul_op", - srcs = [ - "matmul_op.cc", - "matmul_op_fused.cc", - ], - hdrs = ["matmul_op.h"], - defines = select({ - ":xsmm": ["TENSORFLOW_USE_LIBXSMM"], - "//conditions:default": [], - }), - deps = MATH_DEPS + [ - ":eigen_contraction_kernel", - ":fused_eigen_output_kernels", - ] + select({ - ":xsmm": ["@libxsmm_archive//:xsmm_avx"], - "//conditions:default": [], - }) + mkl_deps() + if_cuda([ - "//tensorflow/core/platform/default/build_config:cublas_plugin", - ]) + if_cuda_or_rocm([":gpu_utils"]), -) - tf_kernel_library( name = "reduction_ops", gpu_srcs = ["reduction_gpu_kernels.cu.h"], @@ -3620,25 +3609,6 @@ tf_cuda_cc_test( ], ) -tf_cuda_cc_test( - name = "batch_matmul_op_test", - size = "small", - srcs = ["batch_matmul_op_test.cc"], - deps = [ - ":batch_matmul_op", - ":broadcast_to_op", - ":ops_testutil", - ":ops_util", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], -) - tf_cuda_cc_test( name = "scan_ops_test", size = "small", @@ -5868,8 +5838,8 @@ filegroup( "identity_op.h", "immutable_constant_op.cc", "immutable_constant_op.h", - "matmul_op.cc", - "matmul_op.h", + "matmul_op_impl.h", + "matmul_op_real.cc", "no_op.cc", "no_op.h", "one_hot_op.cc", @@ -5948,7 +5918,6 @@ filegroup( srcs = [ "argmax_op.h", "avgpooling_op.h", - "batch_matmul_op_impl.h", "batch_norm_op.h", "bincount_op.h", "broadcast_to_op.h", @@ -6039,7 +6008,6 @@ filegroup( ":android_extended_ops_headers", "argmax_op.cc", "avgpooling_op.cc", - "batch_matmul_op_real.cc", "batch_norm_op.cc", "bcast_ops.cc", "check_numerics_op.cc", @@ -6480,6 +6448,7 @@ cc_library( "//tensorflow/core/platform:strong_hash", "//third_party/eigen3", "//third_party/fft2d:fft2d_headers", + "//third_party/icu/data:conversion_data", "@com_google_absl//absl/base", "@com_google_protobuf//:protobuf", "@fft2d", @@ -7431,7 +7400,6 @@ test_suite( "manual", # Avoid redundancy when using wildcard test patterns. ], tests = [ - ":batch_matmul_op_test", ":batch_norm_op_test", ":broadcast_to_op_test", ":cast_op_test", diff --git a/tensorflow/core/kernels/batch_matmul_op_test.cc b/tensorflow/core/kernels/batch_matmul_op_test.cc deleted file mode 100644 index 0c04a82818f..00000000000 --- a/tensorflow/core/kernels/batch_matmul_op_test.cc +++ /dev/null @@ -1,257 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/node_builder.h" -#include "tensorflow/core/graph/testlib.h" -#include "tensorflow/core/kernels/broadcast_to_op.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" - -namespace tensorflow { -namespace { - -Node* BroadcastTo(Graph* g, Node* input, Node* shape) { - Node* ret; - TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastTo") - .Input(input) - .Input(shape) - .Finalize(g, &ret)); - return ret; -} - -Node* BatchMatmulV2(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) { - Node* ret; - TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMulV2") - .Input(in0) - .Input(in1) - .Attr("adj_x", adj_x) - .Attr("adj_y", adj_y) - .Finalize(g, &ret)); - return ret; -} - -template <typename T> -static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a, - bool adjoint_b, DataType type) { - Graph* g = new Graph(OpRegistry::Global()); - Tensor in0(type, adjoint_a ? TensorShape({b, k, m}) : TensorShape({b, m, k})); - in0.flat<T>().setRandom(); - Tensor in1(type, adjoint_b ? TensorShape({b, n, k}) : TensorShape({b, k, n})); - in1.flat<T>().setRandom(); - test::graph::BatchMatmul(g, test::graph::Constant(g, in0), - test::graph::Constant(g, in1), adjoint_a, adjoint_b); - return g; -} - -template <typename T> -static Graph* BatchMatmulWithBroadcast(int b0, int b1, int m, int k, int n, - bool manual_broadcast, DataType type) { - Graph* g = new Graph(OpRegistry::Global()); - Tensor in0(type, TensorShape({b0, m, k})); - in0.flat<T>().setRandom(); - Tensor in1(type, TensorShape({b1, k, n})); - in1.flat<T>().setRandom(); - - Tensor broadcasted_in0_shape(DT_INT64, TensorShape({3})); - Tensor broadcasted_in1_shape(DT_INT64, TensorShape({3})); - - Node* in0_node = nullptr; - Node* in1_node = nullptr; - if (manual_broadcast) { - for (int i = 0; i < 3; ++i) { - auto vec0 = broadcasted_in0_shape.vec<int64>(); - auto vec1 = broadcasted_in1_shape.vec<int64>(); - vec0(i) = (i == 0 ? std::max(b0, b1) : in0.shape().dim_size(i)); - vec1(i) = (i == 0 ? std::max(b0, b1) : in1.shape().dim_size(i)); - } - in0_node = BroadcastTo(g, test::graph::Constant(g, in0), - test::graph::Constant(g, broadcasted_in0_shape)); - in1_node = BroadcastTo(g, test::graph::Constant(g, in1), - test::graph::Constant(g, broadcasted_in1_shape)); - } else { - in0_node = test::graph::Constant(g, in0); - in1_node = test::graph::Constant(g, in1); - } - - BatchMatmulV2(g, in0_node, in1_node, false, false); - return g; -} - -#define BM_BatchMatmulDev(B, M, K, N, TA, TB, T, TFTYPE, DEVICE) \ - static void \ - BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE( \ - int iters) { \ - testing::UseRealTime(); \ - testing::ItemsProcessed(static_cast<int64>(iters) * B * M * K * N * 2); \ - test::Benchmark(#DEVICE, BatchMatmul<T>(B, M, K, N, TA, TB, TFTYPE)) \ - .Run(iters); \ - } \ - BENCHMARK( \ - BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE); - -#define BM_BatchMatmul(B, M, K, N, TA, TB) \ - BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, cpu); -// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64, -// cpu); -// BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, gpu); -/* Uncomment to enable benchmarks for double & complex types: */ -// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64, -// gpu); -// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, cpu); \ -// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, cpu); -// \ -// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu); \ -// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, gpu); - -// Macro arguments names: --------------------------------------------------- // -// B1: batch size of LHS -// B2: batch size of RHS -// M: outer dimension of LHS -// K: inner dimensions of LHS and RHS -// N: outer dimension of RHS -// MB: boolean indicating whether to use manual broadcasting -// T: C++ type of scalars (e.g. float, std::complex) -// TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128 -// D: Device (e.g. cpu, gpu) -#define BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, T, TT, D) \ - static void \ - BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D( \ - int iters) { \ - testing::UseRealTime(); \ - testing::ItemsProcessed(static_cast<int64>(iters) * std::max(B1, B2) * M * \ - K * N * 2); \ - test::Benchmark(#D, BatchMatmulWithBroadcast<T>(B1, B2, M, K, N, MB, TT)) \ - .Run(iters); \ - } \ - BENCHMARK( \ - BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D); - -#define BM_BatchMatmulBCast(B1, B2, M, K, N, MB) \ - BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, float, DT_FLOAT, cpu); - -// Typical fully connected layers -BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, true); -BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, false); -BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, true); -BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, false); -BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, true); -BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, false); -BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, true); -BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, false); - -// Square matmul. -BM_BatchMatmulBCast(1, 128, 512, 512, 512, true); -BM_BatchMatmulBCast(1, 128, 512, 512, 512, false); -BM_BatchMatmulBCast(128, 1, 512, 512, 512, true); -BM_BatchMatmulBCast(128, 1, 512, 512, 512, false); -BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, true); -BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, false); -BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, true); -BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, false); - -// Matrix-vector multiplies. -BM_BatchMatmulBCast(1, 128, 10000, 200, 1, true); -BM_BatchMatmulBCast(1, 128, 10000, 200, 1, false); -BM_BatchMatmulBCast(128, 1, 10000, 200, 1, true); -BM_BatchMatmulBCast(128, 1, 10000, 200, 1, false); - -// Vector-matrix multiplies. -BM_BatchMatmulBCast(1, 128, 1, 200, 10000, true); -BM_BatchMatmulBCast(1, 128, 1, 200, 10000, false); -BM_BatchMatmulBCast(128, 1, 1, 200, 10000, true); -BM_BatchMatmulBCast(128, 1, 1, 200, 10000, false); - -// Typical fully connected layers -BM_BatchMatmul(1, 1, 1024, 1024, false, false); -BM_BatchMatmul(1, 8, 1024, 1024, false, false); -BM_BatchMatmul(1, 16, 1024, 1024, false, false); -BM_BatchMatmul(1, 128, 1024, 1024, false, false); -BM_BatchMatmul(2, 1, 1024, 1024, false, false); -BM_BatchMatmul(2, 8, 1024, 1024, false, false); -BM_BatchMatmul(2, 16, 1024, 1024, false, false); -BM_BatchMatmul(2, 128, 1024, 1024, false, false); -BM_BatchMatmul(8, 1, 1024, 1024, false, false); -BM_BatchMatmul(8, 8, 1024, 1024, false, false); -BM_BatchMatmul(8, 16, 1024, 1024, false, false); -BM_BatchMatmul(8, 128, 1024, 1024, false, false); -BM_BatchMatmul(32, 1, 1024, 1024, false, false); -BM_BatchMatmul(32, 8, 1024, 1024, false, false); -BM_BatchMatmul(32, 16, 1024, 1024, false, false); -BM_BatchMatmul(32, 128, 1024, 1024, false, false); - -// Square matmul. -BM_BatchMatmul(1, 32, 32, 32, false, false); -BM_BatchMatmul(1, 128, 128, 128, false, false); -BM_BatchMatmul(1, 256, 256, 256, false, false); -BM_BatchMatmul(1, 1024, 1024, 1024, false, false); -BM_BatchMatmul(1, 2048, 2048, 2048, false, false); -BM_BatchMatmul(2, 32, 32, 32, false, false); -BM_BatchMatmul(2, 128, 128, 128, false, false); -BM_BatchMatmul(2, 256, 256, 256, false, false); -BM_BatchMatmul(2, 1024, 1024, 1024, false, false); -BM_BatchMatmul(2, 2048, 2048, 2048, false, false); -BM_BatchMatmul(4, 32, 32, 32, false, false); -BM_BatchMatmul(4, 128, 128, 128, false, false); -BM_BatchMatmul(4, 256, 256, 256, false, false); -BM_BatchMatmul(4, 1024, 1024, 1024, false, false); -BM_BatchMatmul(4, 2048, 2048, 2048, false, false); -BM_BatchMatmul(8, 32, 32, 32, false, false); -BM_BatchMatmul(8, 128, 128, 128, false, false); -BM_BatchMatmul(8, 256, 256, 256, false, false); -BM_BatchMatmul(8, 1024, 1024, 1024, false, false); -BM_BatchMatmul(8, 2048, 2048, 2048, false, false); -BM_BatchMatmul(32, 32, 32, 32, false, false); -BM_BatchMatmul(32, 128, 128, 128, false, false); -BM_BatchMatmul(32, 256, 256, 256, false, false); -BM_BatchMatmul(32, 1024, 1024, 1024, false, false); -BM_BatchMatmul(32, 2048, 2048, 2048, false, false); - -// Matrix-vector multiplies. -BM_BatchMatmul(1, 10000, 200, 1, false, false); -BM_BatchMatmul(8, 10000, 200, 1, false, false); -BM_BatchMatmul(32, 10000, 200, 1, false, false); -BM_BatchMatmul(1, 10000, 200, 1, true, false); -BM_BatchMatmul(8, 10000, 200, 1, true, false); -BM_BatchMatmul(32, 10000, 200, 1, true, false); -BM_BatchMatmul(1, 10000, 200, 1, false, true); -BM_BatchMatmul(8, 10000, 200, 1, false, true); -BM_BatchMatmul(32, 10000, 200, 1, false, true); -BM_BatchMatmul(1, 10000, 200, 1, true, true); -BM_BatchMatmul(8, 10000, 200, 1, true, true); -BM_BatchMatmul(32, 10000, 200, 1, true, true); - -// Vector-matrix multiplies. -BM_BatchMatmul(1, 1, 200, 10000, false, false); -BM_BatchMatmul(8, 1, 200, 10000, false, false); -BM_BatchMatmul(32, 1, 200, 10000, false, false); -BM_BatchMatmul(1, 1, 200, 10000, true, false); -BM_BatchMatmul(8, 1, 200, 10000, true, false); -BM_BatchMatmul(32, 1, 200, 10000, true, false); -BM_BatchMatmul(1, 1, 200, 10000, false, true); -BM_BatchMatmul(8, 1, 200, 10000, false, true); -BM_BatchMatmul(32, 1, 200, 10000, false, true); -BM_BatchMatmul(1, 1, 200, 10000, true, true); -BM_BatchMatmul(8, 1, 200, 10000, true, true); -BM_BatchMatmul(32, 1, 200, 10000, true, true); - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index c9883f9c938..3d0a51404ba 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -38,6 +38,8 @@ namespace data { /* static */ constexpr const char* const CacheDatasetOp::kOutputTypes; /* static */ constexpr const char* const CacheDatasetOp::kOutputShapes; +namespace { + constexpr char kKeyStrFormat[] = "%%%zuzu_%%%zuzu"; constexpr char kPaddingSizeStrFormat[] = "%zu"; constexpr char kFileDatasetPrefix[] = "File"; @@ -57,6 +59,14 @@ constexpr char kCacheCompleted[] = "cache_completed"; constexpr char kIndex[] = "index"; constexpr char kImpl[] = "Impl"; constexpr char kCacheDataset[] = "CacheDataset"; +constexpr char kIncompleteCacheErrorMessage[] = + "The calling iterator did not fully read the dataset being cached. In " + "order to avoid unexpected truncation of the dataset, the partially cached " + "contents of the dataset will be discarded. This can happen if you have " + "an input pipeline similar to `dataset.cache().take(k).repeat()`. You " + "should use `dataset.take(k).cache().repeat()` instead."; + +} // namespace class CacheDatasetOp::FileDatasetBase : public DatasetBase { public: @@ -220,6 +230,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { ~FileWriterIterator() override { if (!dataset()->env_->FileExists(MetaFilename(filename_)).ok()) { + LOG(WARNING) << kIncompleteCacheErrorMessage; std::vector<string> cache_files; Status s = dataset()->env_->GetMatchingPaths( strings::StrCat(filename_, "*"), &cache_files); @@ -754,13 +765,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { ~MemoryWriterIterator() override { mutex_lock l(mu_); if (!temp_cache_.empty() && !cache_->IsCompleted()) { - LOG(WARNING) - << "The calling iterator did not fully read the dataset being " - "cached. In order to avoid unexpected truncation of the " - "dataset, the partially cached contents of the dataset " - "will be discarded. This can happen if you have an input " - "pipeline similar to `dataset.cache().take(k).repeat()`. " - "You should use `dataset.take(k).cache().repeat()` instead."; + LOG(WARNING) << kIncompleteCacheErrorMessage; cache_->Reset(); } } diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc index bdea5724911..c7fac733b32 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc @@ -482,7 +482,10 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { VLOG(1) << "Failed to get element from worker " << task_to_process->address << ": " << s; task_to_process->in_use = false; - status_ = s; + status_ = Status( + s.code(), + absl::StrCat("Failed to get element from worker ", + task_to_process->address, ": ", s.error_message())); get_next_cv_.notify_all(); return; } diff --git a/tensorflow/core/kernels/dynamic_partition_op_test.cc b/tensorflow/core/kernels/dynamic_partition_op_test.cc index ac34c4ff09f..ba9cde22845 100644 --- a/tensorflow/core/kernels/dynamic_partition_op_test.cc +++ b/tensorflow/core/kernels/dynamic_partition_op_test.cc @@ -188,15 +188,19 @@ static Graph* DynamicPartition(int num_partitions, int dim) { return g; } -#define BM_DYNAMIC_PARTITION(DEVICE, T, num) \ - static void BM_##DEVICE##_dynpart_##T##_##num(int iters, int dim) { \ - const int64 items = ((128 << 20) / sizeof(T)); \ - const int64 tot = static_cast<int64>(iters) * items; \ - testing::ItemsProcessed(tot); \ - testing::UseRealTime(); \ - test::Benchmark(#DEVICE, DynamicPartition<T>(num, dim)).Run(iters); \ - } \ - BENCHMARK(BM_##DEVICE##_dynpart_##T##_##num)->Arg(1)->Arg(256) +#define BM_DYNAMIC_PARTITION(DEVICE, T, num) \ + static void BM_##DEVICE##_dynpart_##T##_##num( \ + ::testing::benchmark::State& state) { \ + const int dim = state.range(0); \ + \ + const int64 items = ((128 << 20) / sizeof(T)); \ + test::Benchmark(#DEVICE, DynamicPartition<T>(num, dim), \ + /*old_benchmark_api=*/false) \ + .Run(state); \ + const int64 tot = static_cast<int64>(state.iterations()) * items; \ + state.SetItemsProcessed(tot); \ + } \ + BENCHMARK(BM_##DEVICE##_dynpart_##T##_##num)->UseRealTime()->Arg(1)->Arg(256) BM_DYNAMIC_PARTITION(cpu, float, 2); BM_DYNAMIC_PARTITION(cpu, float, 100); diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions_test.cc b/tensorflow/core/kernels/eigen_spatial_convolutions_test.cc index ed4b65cd398..6f16df351f5 100644 --- a/tensorflow/core/kernels/eigen_spatial_convolutions_test.cc +++ b/tensorflow/core/kernels/eigen_spatial_convolutions_test.cc @@ -1376,7 +1376,7 @@ TEST(EigenSpatialConvolutionsTest, SpatialConvContractionMapper) { } template <typename T> -static void PackRhsHelper(int iters, +static void PackRhsHelper(::testing::benchmark::State& state, /* Input dimensions: */ int input_batches, int input_cols, int input_rows, int input_depth, @@ -1393,9 +1393,6 @@ static void PackRhsHelper(int iters, // Set random seed for benchmark repeatability. srand(12345); - tensorflow::testing::UseRealTime(); - tensorflow::testing::StopTiming(); - using Dimensions = Eigen::DSizes<Eigen::Index, 4>; // Default Eigen::Tensor layout is column major, so we configure dimensions @@ -1547,8 +1544,7 @@ static void PackRhsHelper(int iters, return (idx / packet_size) * packet_size; }; - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { int input_idx = num_inputs == 1 ? 1 : internal::random<int>(0, num_inputs - 1); @@ -1571,15 +1567,15 @@ static void PackRhsHelper(int iters, input_mappers[input_idx].getSubMapper(depth_offset, col_offset); pack_rhs(packed.data() + packed_offset, sub_mapper, depth, cols); } - tensorflow::testing::StopTiming(); - tensorflow::testing::SetLabel( + + state.SetLabel( absl::StrCat("patch: ", patch_rows, "x", patch_cols, " D", patch_depth, "; num_patches=", num_patches, " patch_size=", patch_size, " num_inputs=", num_inputs, " padding=", padding)); } template <typename T> -static void PackLhsHelper(int iters, +static void PackLhsHelper(::testing::benchmark::State& state, /* Input dimensions: */ int input_depth, /* Filter (kernel) dimensions: */ @@ -1592,9 +1588,6 @@ static void PackLhsHelper(int iters, eigen_assert(block_rows <= filter_count); eigen_assert(block_cols <= input_depth * filter_rows * filter_cols); - tensorflow::testing::UseRealTime(); - tensorflow::testing::StopTiming(); - using Dimensions = Eigen::DSizes<Eigen::Index, 4>; // Default Eigen::Tensor layout is column major, so we configure dimensions @@ -1716,8 +1709,7 @@ static void PackLhsHelper(int iters, const Index max_row = filter_count; const Index max_col = filter_rows * filter_cols * input_depth; - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { int filter_idx = num_filters == 1 ? 1 : internal::random<int>(0, num_filters - 1); @@ -1743,8 +1735,7 @@ static void PackLhsHelper(int iters, pack_lhs(packed.data() + packed_offset, sub_mapper, cols, rows); #endif } - tensorflow::testing::StopTiming(); - tensorflow::testing::SetLabel(absl::StrCat( + state.SetLabel(absl::StrCat( "filter: count=", filter_count, " dims=", filter_rows, "x", filter_cols, "; input: depth=", input_depth, "; num_filers=", num_filters)); } @@ -1777,12 +1768,14 @@ static void PackLhsHelper(int iters, #define BM_PackRhs(T, N, H, W, C, FC, FH, FW, PAD, SH, SW, ISH, ISW, BR, BC) \ static void BM_RHS_NAME(PackRhs, T, N, H, W, C, FC, FH, FW, PAD, SH, SW, \ - ISH, ISW, BR, BC)(int iters) { \ - PackRhsHelper<T>(iters, N, H, W, C, FC, FH, FW, PADDING_##PAD, SH, SW, \ + ISH, ISW, BR, \ + BC)(::testing::benchmark::State & state) { \ + PackRhsHelper<T>(state, N, H, W, C, FC, FH, FW, PADDING_##PAD, SH, SW, \ ISH, ISW, BR, BC); \ } \ BENCHMARK(BM_RHS_NAME(PackRhs, T, N, H, W, C, FC, FH, FW, PAD, SH, SW, ISH, \ - ISW, BR, BC)) + ISW, BR, BC)) \ + ->UseRealTime() // Number of input channel (input depth) it equal to the number of patch // channels (patch depth). @@ -2019,11 +2012,12 @@ BM_PackRhs(/*type*/ qint8, // #define BM_LHS_NAME(prefix, T, C, FC, FH, FW, BR, BC) \ BM_CONCAT(BM_##prefix##_##T##_##C##_FC##FC##_##FH##x##FW, _B##BR##x##BC) -#define BM_PackLhs(T, C, FC, FH, FW, BR, BC) \ - static void BM_LHS_NAME(PackLhs, T, C, FC, FH, FW, BR, BC)(int iters) { \ - PackLhsHelper<T>(iters, C, FC, FH, FW, BR, BC); \ - } \ - BENCHMARK(BM_LHS_NAME(PackLhs, T, C, FC, FH, FW, BR, BC)) +#define BM_PackLhs(T, C, FC, FH, FW, BR, BC) \ + static void BM_LHS_NAME(PackLhs, T, C, FC, FH, FW, BR, \ + BC)(::testing::benchmark::State & state) { \ + PackLhsHelper<T>(state, C, FC, FH, FW, BR, BC); \ + } \ + BENCHMARK(BM_LHS_NAME(PackLhs, T, C, FC, FH, FW, BR, BC))->UseRealTime() // Number of input channel (input depth) it equal to the number of patch // channels (patch depth). diff --git a/tensorflow/core/kernels/example_parsing_ops_test.cc b/tensorflow/core/kernels/example_parsing_ops_test.cc index 7d30d00b266..8c933eff704 100644 --- a/tensorflow/core/kernels/example_parsing_ops_test.cc +++ b/tensorflow/core/kernels/example_parsing_ops_test.cc @@ -349,15 +349,16 @@ typedef BenchmarkOptions<ExampleStore<FloatFiller>, kRagged> RaggedFloat; // B == batch_size, K == num_keys. F == feature_size. // K must be one of 10, 100, 1000 #define BM_ParseExample(TYPE, B, K, F) \ - static void BM_ParseExample##_##TYPE##_##B##_##K##_##F(int iters) { \ + static void BM_ParseExample##_##TYPE##_##B##_##K##_##F( \ + ::testing::benchmark::State& state) { \ int64 items_per_iter = static_cast<int64>(B) * K * F; \ - testing::UseRealTime(); \ - testing::ItemsProcessed(static_cast<int64>(iters) * items_per_iter); \ test::Benchmark("cpu", ParseExample<TYPE>(B, K, F), nullptr, nullptr, \ - nullptr, "SINGLE_THREADED_EXECUTOR") \ - .Run(iters); \ + nullptr, "SINGLE_THREADED_EXECUTOR", false) \ + .Run(state); \ + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * \ + items_per_iter); \ } \ - BENCHMARK(BM_ParseExample##_##TYPE##_##B##_##K##_##F); + BENCHMARK(BM_ParseExample##_##TYPE##_##B##_##K##_##F)->UseRealTime(); #define BM_AllParseExample(Type) \ BM_ParseExample(Type, 1, 10, 1); \ @@ -385,15 +386,17 @@ BM_AllParseExample(VarLenDenseFloat); // K must be one of 10, 100, 1000 // B=0 indicates that a scalar input should be used (instead of a vector). #define BM_ParseExampleV2(TYPE, B, K, F) \ - static void BM_ParseExampleV2##_##TYPE##_##B##_##K##_##F(int iters) { \ + static void BM_ParseExampleV2##_##TYPE##_##B##_##K##_##F( \ + ::testing::benchmark::State& state) { \ int64 items_per_iter = static_cast<int64>(std::max(B, 1)) * K * F; \ - testing::UseRealTime(); \ - testing::ItemsProcessed(static_cast<int64>(iters) * items_per_iter); \ test::Benchmark("cpu", ParseExampleV2<TYPE>(B, K, F), nullptr, nullptr, \ - nullptr, "SINGLE_THREADED_EXECUTOR") \ - .Run(iters); \ + nullptr, "SINGLE_THREADED_EXECUTOR", \ + /*old_benchmark_api=*/false) \ + .Run(state); \ + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * \ + items_per_iter); \ } \ - BENCHMARK(BM_ParseExampleV2##_##TYPE##_##B##_##K##_##F); + BENCHMARK(BM_ParseExampleV2##_##TYPE##_##B##_##K##_##F)->UseRealTime(); #define BM_AllParseExampleV2(Type) \ /* Vector Inputs */ \ @@ -437,15 +440,17 @@ BM_AllParseExampleV2(RaggedFloat); // K == num_keys. F == feature_size. // K must be one of 10, 100, 1000 #define BM_ParseSingleExample(TYPE, K, F) \ - static void BM_ParseSingleExample##_##TYPE##_1_##K##_##F(int iters) { \ + void BM_ParseSingleExample##_##TYPE##_1_##K##_##F( \ + ::testing::benchmark::State& state) { \ int64 items_per_iter = K * F; \ - testing::UseRealTime(); \ - testing::ItemsProcessed(static_cast<int64>(iters) * items_per_iter); \ test::Benchmark("cpu", ParseSingleExample<TYPE>(K, F), nullptr, nullptr, \ - nullptr, "SINGLE_THREADED_EXECUTOR") \ - .Run(iters); \ + nullptr, "SINGLE_THREADED_EXECUTOR", \ + /*old_benchmark_api=*/false) \ + .Run(state); \ + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * \ + items_per_iter); \ } \ - BENCHMARK(BM_ParseSingleExample##_##TYPE##_1_##K##_##F); + BENCHMARK(BM_ParseSingleExample##_##TYPE##_1_##K##_##F)->UseRealTime(); #define BM_AllParseSingleExample(Type) \ BM_ParseSingleExample(Type, 10, 1); \ diff --git a/tensorflow/core/kernels/gather_nd_op_test.cc b/tensorflow/core/kernels/gather_nd_op_test.cc index b0b5c958b5a..130345b68f6 100644 --- a/tensorflow/core/kernels/gather_nd_op_test.cc +++ b/tensorflow/core/kernels/gather_nd_op_test.cc @@ -132,18 +132,22 @@ static Graph* GatherNd(int dim) { return g; } -#define BM_GATHER_ND(DEVICE, INDEX) \ - static void BM_##DEVICE##_gather_nd_##INDEX(int iters, int dim) { \ - const int64 tot = static_cast<int64>(iters) * kLookups * 4; \ - testing::ItemsProcessed(tot); \ - testing::BytesProcessed(tot * sizeof(float)); \ - testing::UseRealTime(); \ - test::Benchmark(#DEVICE, GatherNd<INDEX>(dim)).Run(iters); \ - } \ - BENCHMARK(BM_##DEVICE##_gather_nd_##INDEX) \ - ->Arg(10) \ - ->Arg(100) \ - ->Arg(1000) \ +#define BM_GATHER_ND(DEVICE, INDEX) \ + static void BM_##DEVICE##_gather_nd_##INDEX( \ + ::testing::benchmark::State& state) { \ + const int dim = state.range(0); \ + test::Benchmark(#DEVICE, GatherNd<INDEX>(dim), \ + /*old_benchmark_api=*/false) \ + .Run(state); \ + const int64 tot = static_cast<int64>(state.iterations()) * kLookups * 4; \ + state.SetItemsProcessed(tot); \ + state.SetBytesProcessed(tot * sizeof(float)); \ + } \ + BENCHMARK(BM_##DEVICE##_gather_nd_##INDEX) \ + ->UseRealTime() \ + ->Arg(10) \ + ->Arg(100) \ + ->Arg(1000) \ ->Arg(10000) BM_GATHER_ND(cpu, int32); diff --git a/tensorflow/core/kernels/gather_op_test.cc b/tensorflow/core/kernels/gather_op_test.cc index e4c77881ea8..f2d96e9475f 100644 --- a/tensorflow/core/kernels/gather_op_test.cc +++ b/tensorflow/core/kernels/gather_op_test.cc @@ -222,21 +222,24 @@ static Graph* Gather(int dim) { return g; } -#define BM_GATHER(DEVICE, INDEX) \ - static void BM_##DEVICE##_gather_##INDEX(int iters, int dim) { \ - const int64 tot = static_cast<int64>(iters) * kLookups * dim; \ - testing::ItemsProcessed(tot); \ - testing::BytesProcessed(tot * sizeof(float)); \ - testing::UseRealTime(); \ - test::Benchmark(#DEVICE, Gather<INDEX>(dim)).Run(iters); \ - } \ - BENCHMARK(BM_##DEVICE##_gather_##INDEX) \ - ->Arg(1) \ - ->Arg(10) \ - ->Arg(20) \ - ->Arg(64) \ - ->Arg(100) \ - ->Arg(200) \ +#define BM_GATHER(DEVICE, INDEX) \ + static void BM_##DEVICE##_gather_##INDEX( \ + ::testing::benchmark::State& state) { \ + const int dim = state.range(0); \ + test::Benchmark(#DEVICE, Gather<INDEX>(dim), /*old_benchmark_api=*/false) \ + .Run(state); \ + const int64 tot = static_cast<int64>(state.iterations()) * kLookups * dim; \ + state.SetItemsProcessed(tot); \ + state.SetBytesProcessed(tot * sizeof(float)); \ + } \ + BENCHMARK(BM_##DEVICE##_gather_##INDEX) \ + ->UseRealTime() \ + ->Arg(1) \ + ->Arg(10) \ + ->Arg(20) \ + ->Arg(64) \ + ->Arg(100) \ + ->Arg(200) \ ->Arg(1000) BM_GATHER(cpu, int32); diff --git a/tensorflow/core/kernels/in_topk_op_test.cc b/tensorflow/core/kernels/in_topk_op_test.cc index 9e4da735c5a..75476a6323d 100644 --- a/tensorflow/core/kernels/in_topk_op_test.cc +++ b/tensorflow/core/kernels/in_topk_op_test.cc @@ -65,13 +65,16 @@ static Graph* InTopK(int num_targets, int num_classes, T top_k) { #define BM_NAME(T, TARGETS, CLASSES, K, DEVICE) \ BM_InTopK##_##T##_##TARGETS##_##CLASSES##_##K##_##DEVICE -#define BM_InTopK(T, TARGETS, CLASSES, K, DEVICE) \ - static void BM_NAME(T, TARGETS, CLASSES, K, DEVICE)(int iters) { \ - testing::UseRealTime(); \ - testing::ItemsProcessed(static_cast<int64>(iters) * TARGETS * CLASSES); \ - test::Benchmark(#DEVICE, InTopK<T>(TARGETS, CLASSES, K)).Run(iters); \ - } \ - BENCHMARK(BM_NAME(T, TARGETS, CLASSES, K, DEVICE)); +#define BM_InTopK(T, TARGETS, CLASSES, K, DEVICE) \ + static void BM_NAME(T, TARGETS, CLASSES, K, \ + DEVICE)(::testing::benchmark::State & state) { \ + test::Benchmark(#DEVICE, InTopK<T>(TARGETS, CLASSES, K), \ + /*old_benchmark_api=*/false) \ + .Run(state); \ + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * TARGETS * \ + CLASSES); \ + } \ + BENCHMARK(BM_NAME(T, TARGETS, CLASSES, K, DEVICE))->UseRealTime(); BM_InTopK(int64, 64, 1000, 10, cpu); BM_InTopK(int64, 64, 10000, 10, cpu); diff --git a/tensorflow/core/kernels/linalg/einsum_op_impl.h b/tensorflow/core/kernels/linalg/einsum_op_impl.h index b9b2d1f0eae..61e509f331b 100644 --- a/tensorflow/core/kernels/linalg/einsum_op_impl.h +++ b/tensorflow/core/kernels/linalg/einsum_op_impl.h @@ -30,9 +30,9 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/kernels/batch_matmul_op_impl.h" #include "tensorflow/core/kernels/fill_functor.h" #include "tensorflow/core/kernels/linalg/einsum_op.h" +#include "tensorflow/core/kernels/matmul_op_impl.h" #include "tensorflow/core/kernels/reduction_ops_common.h" #include "tensorflow/core/kernels/transpose_functor.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/core/kernels/lrn_op_test.cc b/tensorflow/core/kernels/lrn_op_test.cc index 68aa3428399..3b8e7a43b2d 100644 --- a/tensorflow/core/kernels/lrn_op_test.cc +++ b/tensorflow/core/kernels/lrn_op_test.cc @@ -199,7 +199,7 @@ TCASE(T3, 128, 4, 3, 2.0f, 1.0f, 1.0f) #undef TCASE -static Graph* BM_LRNGrad(int batches, int rows, int cols, int depth, +static Graph* MakeRNGrad(int batches, int rows, int cols, int depth, int depth_radius) { Graph* g = new Graph(OpRegistry::Global()); Tensor grads(DT_FLOAT, TensorShape({batches, rows, cols, depth})); @@ -223,12 +223,15 @@ static Graph* BM_LRNGrad(int batches, int rows, int cols, int depth, return g; } -#define BM_LRNGradDev(DEVICE, B, R, C, D, DR) \ - static void BM_LRNGrad_##DEVICE##_##B##_##R##_##C##_##D##_##DR(int iters) { \ - testing::ItemsProcessed(static_cast<int64>(iters) * B * R * C * D * DR * \ - 4); \ - test::Benchmark(#DEVICE, BM_LRNGrad(B, R, C, D, DR)).Run(iters); \ - } \ +#define BM_LRNGradDev(DEVICE, B, R, C, D, DR) \ + static void BM_LRNGrad_##DEVICE##_##B##_##R##_##C##_##D##_##DR( \ + ::testing::benchmark::State& state) { \ + test::Benchmark(#DEVICE, MakeRNGrad(B, R, C, D, DR), \ + /*old_benchmark_api*/ false) \ + .Run(state); \ + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * B * R * \ + C * D * DR * 4); \ + } \ BENCHMARK(BM_LRNGrad_##DEVICE##_##B##_##R##_##C##_##D##_##DR) BM_LRNGradDev(cpu, 128, 12, 12, 64, 4); diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc deleted file mode 100644 index 3b57f093e23..00000000000 --- a/tensorflow/core/kernels/matmul_op.cc +++ /dev/null @@ -1,567 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// See docs in ../ops/math_ops.cc. - -#define EIGEN_USE_THREADS - -#include "tensorflow/core/kernels/matmul_op.h" - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/kernels/fill_functor.h" -#include "tensorflow/core/util/matmul_autotune.h" -#if GOOGLE_CUDA -#include "third_party/gpus/cuda/include/cuda.h" -#endif -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "tensorflow/core/kernels/gpu_utils.h" -#include "tensorflow/core/platform/stream_executor.h" -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -namespace tensorflow { - -typedef Eigen::ThreadPoolDevice CPUDevice; -typedef Eigen::GpuDevice GPUDevice; - -template <typename Device, typename T, bool USE_CUBLAS> -struct LaunchMatMul; - -namespace { -// Converts a TensorFlow Tensor to an Eigen Matrix. -template <typename T> -Eigen::Map< - const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> -ToEigenMatrix(const Tensor& tensor) { - auto matrix = tensor.matrix<T>(); - return Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>::Map( - matrix.data(), matrix.dimension(0), matrix.dimension(1)); -} - -// Converts a TensorFlow Tensor to an Eigen Vector. -template <typename T> -Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(Tensor* tensor) { - auto v = tensor->flat<T>(); - return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0)); -} -template <typename T> -Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector( - const Tensor& tensor) { - auto v = tensor.flat<T>(); - return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0)); -} -} // namespace - -// If either side can be represented as a vector, do an explicit vector -// matrix multiply and return true; else return false. -// -// Note: this uses plain Eigen and not Eigen Tensor because it is more -// efficient. -template <typename T> -bool ExplicitVectorMatrixOptimization( - const Tensor& a, const Tensor& b, - const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair, - Tensor* out) { - if (out->dim_size(0) == 1) { - if (dim_pair[0].second == 0) { - // Note: this case is optimized in Eigen Tensors. - return false; - } else { - auto out_v = ToEigenVector<T>(out); - auto a_v = ToEigenVector<T>(a); - auto b_m = ToEigenMatrix<T>(b); - out_v.noalias() = b_m * a_v; - } - return true; - } else if (out->dim_size(1) == 1) { - auto out_v = ToEigenVector<T>(out); - auto a_m = ToEigenMatrix<T>(a); - auto b_v = ToEigenVector<T>(b); - if (dim_pair[0].first == 0) { - out_v.noalias() = a_m.transpose() * b_v; - } else { - out_v.noalias() = a_m * b_v; - } - return true; - } - return false; -} -// Half is not supported. -template <> -bool ExplicitVectorMatrixOptimization<Eigen::half>( - const Tensor& a, const Tensor& b, - const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair, - Tensor* out) { - return false; -} - -template <typename Device, typename T> -struct LaunchMatMulBase { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - typedef se::blas::AlgorithmType AlgorithmType; -#else - typedef int64 AlgorithmType; -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - - static void launch( - OpKernelContext* ctx, const Tensor& a, const Tensor& b, - const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair, - std::vector<AlgorithmType>* algorithms, bool use_autotune, Tensor* out) { - // An explicit vector-matrix multiply is much better optimized than an - // implicit one and this is a bottleneck during non-batched inference. - bool was_vector = ExplicitVectorMatrixOptimization<T>(a, b, dim_pair, out); - if (!was_vector) { - functor::MatMulFunctor<Device, T>()(ctx->eigen_device<Device>(), - out->matrix<T>(), a.matrix<T>(), - b.matrix<T>(), dim_pair); - } - } - - static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx, - std::vector<int64>* algorithms, - bool* algorithm_set_flag) {} -}; -// On CPUs, we ignore USE_CUBLAS -template <typename T> -struct LaunchMatMulCPU : LaunchMatMulBase<CPUDevice, T> {}; - -template <typename T, bool USE_CUBLAS> -struct LaunchMatMul<CPUDevice, T, USE_CUBLAS> : public LaunchMatMulCPU<T> {}; - - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -namespace { - -template <typename T> -struct LaunchBlasGemv { - static void Compute(OpKernelContext* ctx, se::Stream* stream, bool trans, - uint64 m, uint64 n, const se::DeviceMemory<T>& a, - const se::DeviceMemory<T>& b, se::DeviceMemory<T>* c, - se::blas::ProfileResult* output_profile) { - const auto blas_trans = trans ? se::blas::Transpose::kTranspose - : se::blas::Transpose::kNoTranspose; - if (output_profile == nullptr) { - bool blas_launch_status = - stream - ->ThenBlasGemv(blas_trans, m, n, static_cast<T>(1.0), a, m, b, 1, - static_cast<T>(0.0), c, 1) - .ok(); - if (!blas_launch_status) { - ctx->SetStatus( - errors::Internal("Blas GEMV launch failed: m=", m, ", n=", n)); - } - } else { - bool blas_launch_status = - stream - ->ThenBlasGemvWithProfiling(blas_trans, m, n, static_cast<T>(1.0), - a, m, b, 1, static_cast<T>(0.0), c, 1, - output_profile) - .ok(); - if (!blas_launch_status) { - ctx->SetStatus(errors::Internal( - "Blas GEMV with profiling launch failed: m=", m, ", n=", n)); - } - } - } - - static bool IsSupported() { return true; } -}; - -template <> -void LaunchBlasGemv<Eigen::half>::Compute( - OpKernelContext* ctx, se::Stream* stream, bool trans, uint64 m, uint64 n, - const se::DeviceMemory<Eigen::half>& a, - const se::DeviceMemory<Eigen::half>& b, se::DeviceMemory<Eigen::half>* c, - se::blas::ProfileResult* output_profile) { - ctx->SetStatus(errors::Internal( - "Blas GEMV launch failed: GEMV is not implemented for float16.")); -} - -template <> -bool LaunchBlasGemv<Eigen::half>::IsSupported() { - return false; -} - -template <typename T> -bool ShouldUseGemv(uint64 n) { - return (LaunchBlasGemv<T>::IsSupported() && n == 1); -} - -} // namespace - -bool GetCublasAutotuneComputationType(const DataType& dtype, - se::blas::ComputationType* compute_type) { - using se::blas::ComputationType; - switch (dtype) { - case DT_HALF: - case DT_BFLOAT16: - static bool use_f32_for_f16_computation = - MatmulDoFP32ComputationFP16Input(); - if (use_f32_for_f16_computation) { - *compute_type = ComputationType::kF32; - } else { - *compute_type = ComputationType::kF16; - } - return false; - case DT_FLOAT: - *compute_type = ComputationType::kF32; - return true; - case DT_DOUBLE: - *compute_type = ComputationType::kF64; - return true; - default: - // Unsupported compute_type, return false. - return false; - } -} - -// A dummy type to group matmul autotune results together. -struct MatmulAutoTuneGroup { - static string name() { return "Matmul"; } -}; -typedef AutoTuneSingleton<MatmulAutoTuneGroup, MatmulParameters, - se::blas::AlgorithmConfig> - AutoTuneMatmul; - -template <typename T> -struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> { - static void launch( - OpKernelContext* ctx, const Tensor& a, const Tensor& b, - const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair, - std::vector<int64>* algorithms, bool use_autotune, Tensor* out) { - using se::blas::AlgorithmConfig; - using se::blas::ComputationType; - using se::blas::kDefaultAlgorithm; - using se::blas::kDefaultBlasGemm; - using se::blas::kDefaultBlasGemv; - using se::blas::kNoAlgorithm; - using se::blas::ProfileResult; - using se::blas::Transpose; - Transpose trans[] = {Transpose::kNoTranspose, Transpose::kTranspose}; - const uint64 m = a.dim_size(1 - dim_pair[0].first); - const uint64 k = a.dim_size(dim_pair[0].first); - const uint64 n = b.dim_size(1 - dim_pair[0].second); - bool transpose_a = dim_pair[0].first == 0; - bool transpose_b = dim_pair[0].second == 1; - auto blas_transpose_a = trans[transpose_a]; - auto blas_transpose_b = trans[transpose_b]; - - auto* stream = ctx->op_device_context()->stream(); - OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available.")); - - auto a_ptr = AsDeviceMemory(a.template flat<T>().data(), - a.template flat<T>().size()); - auto b_ptr = AsDeviceMemory(b.template flat<T>().data(), - b.template flat<T>().size()); - auto c_ptr = AsDeviceMemory(out->template flat<T>().data(), - out->template flat<T>().size()); - auto alpha = static_cast<T>(1.0); - auto beta = static_cast<T>(0.0); - - int device_id = stream->parent()->device_ordinal(); - DataType dtype = a.dtype(); - MatmulParameters matmul_parameters = { - transpose_a, transpose_b, m, n, k, dtype, device_id, - }; - AlgorithmConfig algorithm_config(kNoAlgorithm); - - ComputationType computation_type; - bool compute_type_supported = - GetCublasAutotuneComputationType(dtype, &computation_type); - if (use_autotune && compute_type_supported && !algorithms->empty()) { - ProfileResult best_result; - // TODO(yangzihao): Unify this code with conv autotuning. - if (!AutoTuneMatmul::GetInstance()->Find(matmul_parameters, - &algorithm_config)) { - ProfileResult profile_result; - for (auto profile_algorithm : (*algorithms)) { - // Cublas does - // C = A x B - // where A, B and C are assumed to be in column major. - // We want the output to be in row-major, so we can compute - // C' = B' x A' (' stands for transpose) - bool cublas_launch_status = - stream - ->ThenBlasGemmWithAlgorithm( - blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr, - transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta, - &c_ptr, n, computation_type, profile_algorithm, - &profile_result) - .ok(); - if (cublas_launch_status) { - if (profile_result.is_valid()) { - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - } - } - } - } - // Try BlasGemmWithProfiling - bool cublas_launch_status = - stream - ->ThenBlasGemmWithProfiling( - blas_transpose_b, blas_transpose_a, n, m, k, 1.0, b_ptr, - transpose_b ? k : n, a_ptr, transpose_a ? m : k, 0.0, - &c_ptr, n, &profile_result) - .ok(); - if (cublas_launch_status) { - if (profile_result.is_valid()) { - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - } - } - } - // Try BlasGemvWithProfiling - if (ShouldUseGemv<T>(n)) { - LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a, - transpose_a ? m : k, transpose_a ? k : m, - a_ptr, b_ptr, &c_ptr, &profile_result); - if (profile_result.is_valid()) { - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - } - } - } - } - // We make sure that each matmul parameter set only gets one pass of - // autotune. If the best result is found, assign it to algorithm_type - // and insert it to autotune map. If all internal kernels of - // cublasGemmEx() returns invalid results, we add kNoAlgorithm to the - // autotune map. - if (best_result.is_valid()) { - algorithm_config.set_algorithm(best_result.algorithm()); - } - AutoTuneMatmul::GetInstance()->Insert(matmul_parameters, - algorithm_config); - if (algorithm_config.algorithm() != kNoAlgorithm && - algorithm_config.algorithm() != kDefaultBlasGemm && - algorithm_config.algorithm() != kDefaultBlasGemv) { - bool cublas_launch_status = - stream - ->ThenBlasGemmWithAlgorithm( - blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr, - transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta, - &c_ptr, n, computation_type, algorithm_config.algorithm(), - nullptr) - .ok(); - if (!cublas_launch_status) { - ctx->SetStatus(errors::Internal( - "Blas GEMM with algorithm launch failed : a.shape=(", - a.dim_size(0), ", ", a.dim_size(1), "), b.shape=(", b.dim_size(0), - ", ", b.dim_size(1), "), m=", m, ", n=", n, ", k=", k)); - } - } - } - // For the following case, we use normal BlasGemm(): - // 1) We didn't set the use_autotune flag; - // 2) compute type does not support autotune; - // 3) no algorithm is found; - // 4) all internal kernels in autotune return invalid results. - // For the following case, we use normal BlasGemv(): - // 1) We didn't set the use_autotune flag but LaunchBlasGemv is supported - // and n == 1. - // 2) We set the use_autotune flag and it picked up BlasGemv() and set the - // algorithm_config.algorithm() to be kDefaultBlasGemv. - if (!use_autotune || !compute_type_supported || algorithms->empty() || - algorithm_config.algorithm() == kNoAlgorithm || - algorithm_config.algorithm() == kDefaultBlasGemm || - algorithm_config.algorithm() == kDefaultBlasGemv) { - if (algorithm_config.algorithm() == kDefaultBlasGemv || - ShouldUseGemv<T>(n)) { - // This is a matrix*vector multiply so use GEMV to compute A * b. - // Here we are multiplying in the natural order, so we have to flip - // the transposition flag to compensate for the tensor being stored - // row-major. - // TODO(yangzihao): Add Gemv as an autotuning option too. - LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a, - transpose_a ? m : k, transpose_a ? k : m, - a_ptr, b_ptr, &c_ptr, nullptr); - } else { - // Use C' = B' x A' (' stands for transpose) - bool blas_launch_status = - stream - ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k, - 1.0f, b_ptr, transpose_b ? k : n, a_ptr, - transpose_a ? m : k, 0.0f, &c_ptr, n) - .ok(); - if (!blas_launch_status) { - ctx->SetStatus(errors::Internal( - "Blas GEMM launch failed : a.shape=(", a.dim_size(0), ", ", - a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1), - "), m=", m, ", n=", n, ", k=", k)); - } - } - } - } - - static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx, - std::vector<int64>* algorithms, - bool* algorithm_set_flag) { - if (*algorithm_set_flag == false) { - auto* stream = ctx->device()->tensorflow_gpu_device_info()->stream; - stream->parent()->GetBlasGemmAlgorithms(algorithms); - *algorithm_set_flag = true; - } - } -}; - -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -template <typename Device, typename T, bool USE_CUBLAS> -class MatMulOp : public OpKernel { - public: - explicit MatMulOp(OpKernelConstruction* ctx) - : OpKernel(ctx), algorithms_set_already_(false) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_)); - - LaunchMatMul<Device, T, USE_CUBLAS>::GetBlasGemmAlgorithm( - ctx, &algorithms_, &algorithms_set_already_); - use_autotune_ = MatmulAutotuneEnable(); - } - - void Compute(OpKernelContext* ctx) override { - const Tensor& a = ctx->input(0); - const Tensor& b = ctx->input(1); - - // Check that the dimensions of the two matrices are valid. - OP_REQUIRES( - ctx, TensorShapeUtils::IsMatrix(a.shape()), - errors::InvalidArgument("In[0] is not a matrix. Instead it has shape ", - a.shape().DebugString())); - OP_REQUIRES( - ctx, TensorShapeUtils::IsMatrix(b.shape()), - errors::InvalidArgument("In[1] is not a matrix. Instead it has shape ", - b.shape().DebugString())); - Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair; - dim_pair[0].first = transpose_a_ ? 0 : 1; - dim_pair[0].second = transpose_b_ ? 1 : 0; - - OP_REQUIRES( - ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second), - errors::InvalidArgument( - "Matrix size-incompatible: In[0]: ", a.shape().DebugString(), - ", In[1]: ", b.shape().DebugString())); - int a_dim_remaining = 1 - dim_pair[0].first; - int b_dim_remaining = 1 - dim_pair[0].second; - TensorShape out_shape( - {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)}); - Tensor* out = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out)); - - if (out->NumElements() == 0) { - // If a has shape [0, x] or b has shape [x, 0], the output shape - // is a 0-element matrix, so there is nothing to do. - return; - } - - if (a.NumElements() == 0 && b.NumElements() == 0) { - // If a has shape [x, 0] and b has shape [0, y], the - // output shape is [x, y] where x and y are non-zero, so we fill - // the output with zeros. - functor::SetZeroFunctor<Device, T> f; - f(ctx->eigen_device<Device>(), out->flat<T>()); - return; - } - - if (std::is_same<T, bfloat16>::value) { - bool is_cpu = std::is_same<Device, CPUDevice>::value; - OP_REQUIRES(ctx, is_cpu, - errors::Internal("bfloat16 matmul is not supported by GPU")); - Tensor a_float, b_float, out_float; - OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, a.shape(), &a_float)); - OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, b.shape(), &b_float)); - OP_REQUIRES_OK(ctx, - ctx->allocate_temp(DT_FLOAT, out->shape(), &out_float)); - - // TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU. - BFloat16ToFloat(a.flat<bfloat16>().data(), a_float.flat<float>().data(), - a.NumElements()); - BFloat16ToFloat(b.flat<bfloat16>().data(), b_float.flat<float>().data(), - b.NumElements()); - - LaunchMatMul<Device, float, USE_CUBLAS>::launch( - ctx, a_float, b_float, dim_pair, &algorithms_, use_autotune_, - &out_float); - FloatToBFloat16(out_float.flat<float>().data(), - out->flat<bfloat16>().data(), out->NumElements()); - } else { - LaunchMatMul<Device, T, USE_CUBLAS>::launch( - ctx, a, b, dim_pair, &algorithms_, use_autotune_, out); - } - } - - private: - std::vector<int64> algorithms_; - bool algorithms_set_already_; - bool use_autotune_; - bool transpose_a_; - bool transpose_b_; -}; - -namespace functor { - -// Partial specialization MatMulFunctor<Device=CPUDevice, T>. -template <typename T> -struct MatMulFunctor<CPUDevice, T> { - void operator()( - const CPUDevice& d, typename MatMulTypes<T>::out_type out, - typename MatMulTypes<T>::in_type in0, - typename MatMulTypes<T>::in_type in1, - const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) { - MatMul<CPUDevice>(d, out, in0, in1, dim_pair); - } -}; - - -} // end namespace functor - -#define REGISTER_CPU_EIGEN(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen"), \ - MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); - -#define REGISTER_CPU(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ - MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); \ - REGISTER_CPU_EIGEN(T); - -#define REGISTER_GPU(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("MatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ - MatMulOp<GPUDevice, T, true /* cublas, true by default */>); \ - REGISTER_KERNEL_BUILDER(Name("MatMul") \ - .Device(DEVICE_GPU) \ - .TypeConstraint<T>("T") \ - .Label("cublas"), \ - MatMulOp<GPUDevice, T, true /* cublas */>) - -TF_CALL_int32(REGISTER_CPU); -TF_CALL_int64(REGISTER_CPU); -TF_CALL_FLOAT_TYPES(REGISTER_CPU); -TF_CALL_COMPLEX_TYPES(REGISTER_CPU); - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); -TF_CALL_COMPLEX_TYPES(REGISTER_GPU); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -} // namespace tensorflow diff --git a/tensorflow/core/kernels/batch_matmul_op_complex.cc b/tensorflow/core/kernels/matmul_op_complex.cc similarity index 94% rename from tensorflow/core/kernels/batch_matmul_op_complex.cc rename to tensorflow/core/kernels/matmul_op_complex.cc index bc36b95d6a1..daec0220d10 100644 --- a/tensorflow/core/kernels/batch_matmul_op_complex.cc +++ b/tensorflow/core/kernels/matmul_op_complex.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/batch_matmul_op_impl.h" +#include "tensorflow/core/kernels/matmul_op_impl.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/matmul_op_impl.h similarity index 88% rename from tensorflow/core/kernels/batch_matmul_op_impl.h rename to tensorflow/core/kernels/matmul_op_impl.h index d6cc980633f..4e29be5eb11 100644 --- a/tensorflow/core/kernels/batch_matmul_op_impl.h +++ b/tensorflow/core/kernels/matmul_op_impl.h @@ -15,8 +15,8 @@ limitations under the License. // See docs in ../ops/math_ops.cc. -#ifndef TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_ -#define TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_ +#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_ #define EIGEN_USE_THREADS @@ -633,10 +633,21 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> { template <typename Device, typename Scalar> class BaseBatchMatMulOp : public OpKernel { public: - explicit BaseBatchMatMulOp(OpKernelConstruction* context) + explicit BaseBatchMatMulOp(OpKernelConstruction* context, + bool is_legacy_matmul) : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_)); - OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_)); + if (is_legacy_matmul) { + // The old MatMul kernel has "transpose_a/transpose_b" attributes. + OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &trans_x_)); + OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &trans_y_)); + adj_x_ = false; + adj_y_ = false; + } else { + OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_)); + OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_)); + trans_x_ = false; + trans_y_ = false; + } } ~BaseBatchMatMulOp() override {} @@ -672,8 +683,8 @@ class BaseBatchMatMulOp : public OpKernel { in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})), errors::Internal("Failed to reshape In[1] from ", in1.shape().DebugString())); - if (adj_x_) std::swap(d0, d1); - if (adj_y_) std::swap(d2, d3); + if (adj_x_ || trans_x_) std::swap(d0, d1); + if (adj_y_ || trans_y_) std::swap(d2, d3); OP_REQUIRES(ctx, d1 == d2, errors::InvalidArgument( "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ", @@ -696,9 +707,36 @@ class BaseBatchMatMulOp : public OpKernel { out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})), errors::Internal("Failed to reshape output from ", out->shape().DebugString())); - LaunchBatchMatMul<Device, Scalar>::Launch( - ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, /*trans_x=*/false, - /*trans_y=*/false, bcast, &out_reshaped); + if (std::is_same<Scalar, bfloat16>::value) { + bool is_cpu = std::is_same<Device, CPUDevice>::value; + OP_REQUIRES(ctx, is_cpu, + errors::Internal("bfloat16 matmul is not supported by GPU")); + Tensor in0_reshaped_float, in1_reshaped_float, out_reshaped_float; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in0_reshaped.shape(), + &in0_reshaped_float)); + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in1_reshaped.shape(), + &in1_reshaped_float)); + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, out_reshaped.shape(), + &out_reshaped_float)); + + // TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU. + BFloat16ToFloat(in0_reshaped.flat<bfloat16>().data(), + in0_reshaped_float.flat<float>().data(), + in0_reshaped.NumElements()); + BFloat16ToFloat(in1_reshaped.flat<bfloat16>().data(), + in1_reshaped_float.flat<float>().data(), + in1_reshaped.NumElements()); + + LaunchBatchMatMul<Device, float>::Launch( + ctx, in0_reshaped_float, in1_reshaped_float, adj_x_, adj_y_, trans_x_, + trans_y_, bcast, &out_reshaped_float); + FloatToBFloat16(out_reshaped_float.flat<float>().data(), + out_reshaped.flat<bfloat16>().data(), out->NumElements()); + } else { + LaunchBatchMatMul<Device, Scalar>::Launch(ctx, in0_reshaped, in1_reshaped, + adj_x_, adj_y_, trans_x_, + trans_y_, bcast, &out_reshaped); + } } protected: @@ -706,16 +744,19 @@ class BaseBatchMatMulOp : public OpKernel { const Tensor& in1) = 0; private: + // TODO(171979567) Make the ops take both adj and transpose attributes. bool adj_x_; bool adj_y_; + bool trans_x_; + bool trans_y_; }; // BatchMatMul Op implementation which disallows broadcasting. -template <typename Device, typename Scalar> +template <typename Device, typename Scalar, bool is_legacy_matmul = false> class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> { public: explicit BatchMatMulOp(OpKernelConstruction* context) - : BaseBatchMatMulOp<Device, Scalar>(context) {} + : BaseBatchMatMulOp<Device, Scalar>(context, is_legacy_matmul) {} ~BatchMatMulOp() override {} @@ -729,15 +770,21 @@ class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> { in0.shape().DebugString(), " vs. ", in1.shape().DebugString())); const int ndims = in0.dims(); - OP_REQUIRES( - ctx, ndims >= 2, - errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims)); - for (int i = 0; i < ndims - 2; ++i) { - OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i), + if (is_legacy_matmul) { + OP_REQUIRES(ctx, ndims == 2, errors::InvalidArgument( - "In[0].dim(", i, ") and In[1].dim(", i, - ") must be the same: ", in0.shape().DebugString(), " vs ", - in1.shape().DebugString())); + "In[0] and In[1] ndims must be == 2: ", ndims)); + } else { + OP_REQUIRES(ctx, ndims >= 2, + errors::InvalidArgument( + "In[0] and In[1] ndims must be >= 2: ", ndims)); + for (int i = 0; i < ndims - 2; ++i) { + OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i), + errors::InvalidArgument( + "In[0].dim(", i, ") and In[1].dim(", i, + ") must be the same: ", in0.shape().DebugString(), + " vs ", in1.shape().DebugString())); + } } } }; @@ -747,7 +794,8 @@ template <typename Device, typename Scalar> class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> { public: explicit BatchMatMulV2Op(OpKernelConstruction* context) - : BaseBatchMatMulOp<Device, Scalar>(context) {} + : BaseBatchMatMulOp<Device, Scalar>(context, + /* is_legacy_matmul= */ false) {} ~BatchMatMulV2Op() override {} @@ -771,7 +819,10 @@ class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> { BatchMatMulOp<CPUDevice, TYPE>); \ REGISTER_KERNEL_BUILDER( \ Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \ - BatchMatMulV2Op<CPUDevice, TYPE>) + BatchMatMulV2Op<CPUDevice, TYPE>); \ + REGISTER_KERNEL_BUILDER( \ + Name("MatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \ + BatchMatMulOp<CPUDevice, TYPE, /* is_legacy_matmul=*/true>) #define REGISTER_BATCH_MATMUL_GPU(TYPE) \ REGISTER_KERNEL_BUILDER( \ @@ -779,8 +830,11 @@ class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> { BatchMatMulOp<GPUDevice, TYPE>); \ REGISTER_KERNEL_BUILDER( \ Name("BatchMatMulV2").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \ - BatchMatMulV2Op<GPUDevice, TYPE>) + BatchMatMulV2Op<GPUDevice, TYPE>); \ + REGISTER_KERNEL_BUILDER( \ + Name("MatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \ + BatchMatMulOp<GPUDevice, TYPE, /* is_legacy_matmul=*/true>) } // namespace tensorflow -#endif // TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_ +#endif // TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_ diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/matmul_op_real.cc similarity index 76% rename from tensorflow/core/kernels/batch_matmul_op_real.cc rename to tensorflow/core/kernels/matmul_op_real.cc index 30ec13e6b4d..34d4b8c57b4 100644 --- a/tensorflow/core/kernels/batch_matmul_op_real.cc +++ b/tensorflow/core/kernels/matmul_op_real.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/batch_matmul_op_impl.h" +#include "tensorflow/core/kernels/matmul_op_impl.h" #if GOOGLE_CUDA #include "third_party/gpus/cuda/include/cuda.h" @@ -21,17 +21,13 @@ limitations under the License. namespace tensorflow { -TF_CALL_float(REGISTER_BATCH_MATMUL_CPU); -TF_CALL_double(REGISTER_BATCH_MATMUL_CPU); -TF_CALL_half(REGISTER_BATCH_MATMUL_CPU); +TF_CALL_FLOAT_TYPES(REGISTER_BATCH_MATMUL_CPU); TF_CALL_int16(REGISTER_BATCH_MATMUL_CPU); TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU); TF_CALL_int64(REGISTER_BATCH_MATMUL_CPU); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -TF_CALL_float(REGISTER_BATCH_MATMUL_GPU); -TF_CALL_double(REGISTER_BATCH_MATMUL_GPU); -TF_CALL_half(REGISTER_BATCH_MATMUL_GPU); +TF_CALL_GPU_NUMBER_TYPES(REGISTER_BATCH_MATMUL_GPU); #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/matmul_op_test.cc b/tensorflow/core/kernels/matmul_op_test.cc index aa4c8efb640..4f986e34acd 100644 --- a/tensorflow/core/kernels/matmul_op_test.cc +++ b/tensorflow/core/kernels/matmul_op_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/public/session.h" namespace tensorflow { +namespace { template <typename T> class FusedMatMulOpTest : public OpsTestBase { @@ -459,4 +460,230 @@ BM_Matmul(2000, 1, 2000, true, false); BM_Matmul(2000, 1, 2000, false, true); BM_Matmul(2000, 1, 2000, true, true); -} // end namespace tensorflow +// Benchmarks for batched matmul with broadcasting. +Node* BroadcastTo(Graph* g, Node* input, Node* shape) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastTo") + .Input(input) + .Input(shape) + .Finalize(g, &ret)); + return ret; +} + +Node* BatchMatmulV2(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMulV2") + .Input(in0) + .Input(in1) + .Attr("adj_x", adj_x) + .Attr("adj_y", adj_y) + .Finalize(g, &ret)); + return ret; +} + +template <typename T> +static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a, + bool adjoint_b, DataType type) { + Graph* g = new Graph(OpRegistry::Global()); + Tensor in0(type, adjoint_a ? TensorShape({b, k, m}) : TensorShape({b, m, k})); + in0.flat<T>().setRandom(); + Tensor in1(type, adjoint_b ? TensorShape({b, n, k}) : TensorShape({b, k, n})); + in1.flat<T>().setRandom(); + test::graph::BatchMatmul(g, test::graph::Constant(g, in0), + test::graph::Constant(g, in1), adjoint_a, adjoint_b); + return g; +} + +template <typename T> +static Graph* BatchMatmulWithBroadcast(int b0, int b1, int m, int k, int n, + bool manual_broadcast, DataType type) { + Graph* g = new Graph(OpRegistry::Global()); + Tensor in0(type, TensorShape({b0, m, k})); + in0.flat<T>().setRandom(); + Tensor in1(type, TensorShape({b1, k, n})); + in1.flat<T>().setRandom(); + + Tensor broadcasted_in0_shape(DT_INT64, TensorShape({3})); + Tensor broadcasted_in1_shape(DT_INT64, TensorShape({3})); + + Node* in0_node = nullptr; + Node* in1_node = nullptr; + if (manual_broadcast) { + for (int i = 0; i < 3; ++i) { + auto vec0 = broadcasted_in0_shape.vec<int64>(); + auto vec1 = broadcasted_in1_shape.vec<int64>(); + vec0(i) = (i == 0 ? std::max(b0, b1) : in0.shape().dim_size(i)); + vec1(i) = (i == 0 ? std::max(b0, b1) : in1.shape().dim_size(i)); + } + in0_node = BroadcastTo(g, test::graph::Constant(g, in0), + test::graph::Constant(g, broadcasted_in0_shape)); + in1_node = BroadcastTo(g, test::graph::Constant(g, in1), + test::graph::Constant(g, broadcasted_in1_shape)); + } else { + in0_node = test::graph::Constant(g, in0); + in1_node = test::graph::Constant(g, in1); + } + + BatchMatmulV2(g, in0_node, in1_node, false, false); + return g; +} + +#define BM_BatchMatmulDev(B, M, K, N, TA, TB, T, TFTYPE, DEVICE) \ + static void \ + BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE( \ + int iters) { \ + testing::UseRealTime(); \ + testing::ItemsProcessed(static_cast<int64>(iters) * B * M * K * N * 2); \ + test::Benchmark(#DEVICE, BatchMatmul<T>(B, M, K, N, TA, TB, TFTYPE)) \ + .Run(iters); \ + } \ + BENCHMARK( \ + BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE); + +#define BM_BatchMatmul(B, M, K, N, TA, TB) \ + BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, cpu); +// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64, +// cpu); +// BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, gpu); +/* Uncomment to enable benchmarks for double & complex types: */ +// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64, +// gpu); +// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, cpu); \ +// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, cpu); +// \ +// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu); \ +// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, gpu); + +// Macro arguments names: --------------------------------------------------- // +// B1: batch size of LHS +// B2: batch size of RHS +// M: outer dimension of LHS +// K: inner dimensions of LHS and RHS +// N: outer dimension of RHS +// MB: boolean indicating whether to use manual broadcasting +// T: C++ type of scalars (e.g. float, std::complex) +// TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128 +// D: Device (e.g. cpu, gpu) +#define BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, T, TT, D) \ + static void \ + BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D( \ + int iters) { \ + testing::UseRealTime(); \ + testing::ItemsProcessed(static_cast<int64>(iters) * std::max(B1, B2) * M * \ + K * N * 2); \ + test::Benchmark(#D, BatchMatmulWithBroadcast<T>(B1, B2, M, K, N, MB, TT)) \ + .Run(iters); \ + } \ + BENCHMARK( \ + BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D); + +#define BM_BatchMatmulBCast(B1, B2, M, K, N, MB) \ + BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, float, DT_FLOAT, cpu); + +// Typical fully connected layers +BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, true); +BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, false); +BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, true); +BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, false); +BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, true); +BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, false); +BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, true); +BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, false); + +// Square matmul. +BM_BatchMatmulBCast(1, 128, 512, 512, 512, true); +BM_BatchMatmulBCast(1, 128, 512, 512, 512, false); +BM_BatchMatmulBCast(128, 1, 512, 512, 512, true); +BM_BatchMatmulBCast(128, 1, 512, 512, 512, false); +BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, true); +BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, false); +BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, true); +BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, false); + +// Matrix-vector multiplies. +BM_BatchMatmulBCast(1, 128, 10000, 200, 1, true); +BM_BatchMatmulBCast(1, 128, 10000, 200, 1, false); +BM_BatchMatmulBCast(128, 1, 10000, 200, 1, true); +BM_BatchMatmulBCast(128, 1, 10000, 200, 1, false); + +// Vector-matrix multiplies. +BM_BatchMatmulBCast(1, 128, 1, 200, 10000, true); +BM_BatchMatmulBCast(1, 128, 1, 200, 10000, false); +BM_BatchMatmulBCast(128, 1, 1, 200, 10000, true); +BM_BatchMatmulBCast(128, 1, 1, 200, 10000, false); + +// Typical fully connected layers +BM_BatchMatmul(1, 1, 1024, 1024, false, false); +BM_BatchMatmul(1, 8, 1024, 1024, false, false); +BM_BatchMatmul(1, 16, 1024, 1024, false, false); +BM_BatchMatmul(1, 128, 1024, 1024, false, false); +BM_BatchMatmul(2, 1, 1024, 1024, false, false); +BM_BatchMatmul(2, 8, 1024, 1024, false, false); +BM_BatchMatmul(2, 16, 1024, 1024, false, false); +BM_BatchMatmul(2, 128, 1024, 1024, false, false); +BM_BatchMatmul(8, 1, 1024, 1024, false, false); +BM_BatchMatmul(8, 8, 1024, 1024, false, false); +BM_BatchMatmul(8, 16, 1024, 1024, false, false); +BM_BatchMatmul(8, 128, 1024, 1024, false, false); +BM_BatchMatmul(32, 1, 1024, 1024, false, false); +BM_BatchMatmul(32, 8, 1024, 1024, false, false); +BM_BatchMatmul(32, 16, 1024, 1024, false, false); +BM_BatchMatmul(32, 128, 1024, 1024, false, false); + +// Square matmul. +BM_BatchMatmul(1, 32, 32, 32, false, false); +BM_BatchMatmul(1, 128, 128, 128, false, false); +BM_BatchMatmul(1, 256, 256, 256, false, false); +BM_BatchMatmul(1, 1024, 1024, 1024, false, false); +BM_BatchMatmul(1, 2048, 2048, 2048, false, false); +BM_BatchMatmul(2, 32, 32, 32, false, false); +BM_BatchMatmul(2, 128, 128, 128, false, false); +BM_BatchMatmul(2, 256, 256, 256, false, false); +BM_BatchMatmul(2, 1024, 1024, 1024, false, false); +BM_BatchMatmul(2, 2048, 2048, 2048, false, false); +BM_BatchMatmul(4, 32, 32, 32, false, false); +BM_BatchMatmul(4, 128, 128, 128, false, false); +BM_BatchMatmul(4, 256, 256, 256, false, false); +BM_BatchMatmul(4, 1024, 1024, 1024, false, false); +BM_BatchMatmul(4, 2048, 2048, 2048, false, false); +BM_BatchMatmul(8, 32, 32, 32, false, false); +BM_BatchMatmul(8, 128, 128, 128, false, false); +BM_BatchMatmul(8, 256, 256, 256, false, false); +BM_BatchMatmul(8, 1024, 1024, 1024, false, false); +BM_BatchMatmul(8, 2048, 2048, 2048, false, false); +BM_BatchMatmul(32, 32, 32, 32, false, false); +BM_BatchMatmul(32, 128, 128, 128, false, false); +BM_BatchMatmul(32, 256, 256, 256, false, false); +BM_BatchMatmul(32, 1024, 1024, 1024, false, false); +BM_BatchMatmul(32, 2048, 2048, 2048, false, false); + +// Matrix-vector multiplies. +BM_BatchMatmul(1, 10000, 200, 1, false, false); +BM_BatchMatmul(8, 10000, 200, 1, false, false); +BM_BatchMatmul(32, 10000, 200, 1, false, false); +BM_BatchMatmul(1, 10000, 200, 1, true, false); +BM_BatchMatmul(8, 10000, 200, 1, true, false); +BM_BatchMatmul(32, 10000, 200, 1, true, false); +BM_BatchMatmul(1, 10000, 200, 1, false, true); +BM_BatchMatmul(8, 10000, 200, 1, false, true); +BM_BatchMatmul(32, 10000, 200, 1, false, true); +BM_BatchMatmul(1, 10000, 200, 1, true, true); +BM_BatchMatmul(8, 10000, 200, 1, true, true); +BM_BatchMatmul(32, 10000, 200, 1, true, true); + +// Vector-matrix multiplies. +BM_BatchMatmul(1, 1, 200, 10000, false, false); +BM_BatchMatmul(8, 1, 200, 10000, false, false); +BM_BatchMatmul(32, 1, 200, 10000, false, false); +BM_BatchMatmul(1, 1, 200, 10000, true, false); +BM_BatchMatmul(8, 1, 200, 10000, true, false); +BM_BatchMatmul(32, 1, 200, 10000, true, false); +BM_BatchMatmul(1, 1, 200, 10000, false, true); +BM_BatchMatmul(8, 1, 200, 10000, false, true); +BM_BatchMatmul(32, 1, 200, 10000, false, true); +BM_BatchMatmul(1, 1, 200, 10000, true, true); +BM_BatchMatmul(8, 1, 200, 10000, true, true); +BM_BatchMatmul(32, 1, 200, 10000, true, true); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc index fbba4116e3b..ec3f526bd9d 100644 --- a/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc @@ -33,8 +33,8 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/type_traits.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/batch_matmul_op_impl.h" #include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/kernels/matmul_op_impl.h" #include "tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/isinf.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/isinf.mlir.tmpl index b8477acca4e..4be6db3aebc 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/isinf.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/isinf.mlir.tmpl @@ -1,5 +1,5 @@ func @Isinf_elem_type(%arg0: tensor<*xelem_type>) - -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + -> tensor<*xi1> attributes {tf_entry, llvm.emit_c_interface} { %0 = "tf.IsInf"(%arg0) : (tensor<*xelem_type>) -> tensor<*xi1> return %0 : tensor<*xi1> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/isnan.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/isnan.mlir.tmpl index ac0c09d22d4..5610068935a 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/isnan.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/isnan.mlir.tmpl @@ -1,5 +1,5 @@ func @Isnan_elem_type(%arg0: tensor<*xelem_type>) - -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + -> tensor<*xi1> attributes {tf_entry, llvm.emit_c_interface} { %0 = "tf.IsNan"(%arg0) : (tensor<*xelem_type>) -> tensor<*xi1> return %0 : tensor<*xi1> } diff --git a/tensorflow/core/kernels/multinomial_op_test.cc b/tensorflow/core/kernels/multinomial_op_test.cc index 25326ac5ecf..e1cc9d7dcd3 100644 --- a/tensorflow/core/kernels/multinomial_op_test.cc +++ b/tensorflow/core/kernels/multinomial_op_test.cc @@ -40,11 +40,15 @@ static Graph* Multinomial(int batch_size, int num_classes, int num_samples) { return g; } -#define BM_MultinomialDev(DEVICE, B, C, S) \ - static void BM_Multinomial_##DEVICE##_##B##_##C##_##S(int iters) { \ - test::Benchmark(#DEVICE, Multinomial(B, C, S)).Run(iters); \ - testing::ItemsProcessed(static_cast<int64>(B) * C * S * iters); \ - } \ +#define BM_MultinomialDev(DEVICE, B, C, S) \ + static void BM_Multinomial_##DEVICE##_##B##_##C##_##S( \ + ::testing::benchmark::State& state) { \ + test::Benchmark(#DEVICE, Multinomial(B, C, S), \ + /*old_benchmark_api*/ false) \ + .Run(state); \ + state.SetItemsProcessed(static_cast<int64>(B) * C * S * \ + state.iterations()); \ + } \ BENCHMARK(BM_Multinomial_##DEVICE##_##B##_##C##_##S); #define BM_MultinomialBCS(B, C, S) \ diff --git a/tensorflow/core/kernels/nn_ops_test.cc b/tensorflow/core/kernels/nn_ops_test.cc index ced97481ca9..bff83abc4aa 100644 --- a/tensorflow/core/kernels/nn_ops_test.cc +++ b/tensorflow/core/kernels/nn_ops_test.cc @@ -103,18 +103,18 @@ enum CONV_OP { } // namespace -static void BM_ConvFloat(int iters, int batch, int rows, int cols, int in_depth, - int out_depth, int filter_rows, int filter_cols, - CONV_OP op, int num_threads, int stride, - Padding padding, bool use_gpu, DataType data_type, +static void BM_ConvFloat(::testing::benchmark::State& state, int batch, + int rows, int cols, int in_depth, int out_depth, + int filter_rows, int filter_cols, CONV_OP op, + int num_threads, int stride, Padding padding, + bool use_gpu, DataType data_type, const string& label) { - testing::StopTiming(); if (!IsGoogleCudaEnabled() && use_gpu) { - testing::SetLabel( + state.SetLabel( strings::StrCat("Skipping GPU test (no --config=cuda): ", label)); return; } - testing::SetLabel(label); + state.SetLabel(label); // Set the number of threads SessionOptions options; @@ -221,10 +221,10 @@ static void BM_ConvFloat(int iters, int batch, int rows, int cols, int in_depth, TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph, g)); string device = use_gpu ? "gpu" : "cpu"; - testing::UseRealTime(); - testing::StartTiming(); - test::Benchmark(device, g, &options).Run(iters); - testing::ItemsProcessed(num_ops * iters); + test::Benchmark(device, g, &options, nullptr, nullptr, "", + /*old_benchmark_api*/ false) + .Run(state); + state.SetItemsProcessed(num_ops * state.iterations()); } // BS: batch_size @@ -235,48 +235,52 @@ static void BM_ConvFloat(int iters, int batch, int rows, int cols, int in_depth, // KR: kernel_rows // KC: kernel_cols #define BM_ConvFloatFwd(BS, R, C, ID, OD, KR, KC, STR, PAD, LABEL) \ - static void BM_ConvFloatFwdCPU1_##LABEL(int iters) { \ - BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \ + static void BM_ConvFloatFwdCPU1_##LABEL( \ + ::testing::benchmark::State& state) { \ + BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \ PAD, false, DT_FLOAT, \ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ KR, "_", KC, "_", STR, "_", PAD, "_f_cpu1")); \ } \ - static void BM_ConvFloatFwdCPU4_##LABEL(int iters) { \ - BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 4, STR, \ + static void BM_ConvFloatFwdCPU4_##LABEL( \ + ::testing::benchmark::State& state) { \ + BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 4, STR, \ PAD, false, DT_FLOAT, \ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ KR, "_", KC, "_", STR, "_", PAD, "_f_cpu4")); \ } \ - static void BM_ConvFloatFusedCPU1_##LABEL(int iters) { \ - BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FUSED, 1, STR, PAD, \ + static void BM_ConvFloatFusedCPU1_##LABEL( \ + ::testing::benchmark::State& state) { \ + BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FUSED, 1, STR, PAD, \ false, DT_FLOAT, \ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ KR, "_", KC, "_", STR, "_", PAD, "_f_cpu1")); \ } \ - static void BM_ConvFloatFusedCPU4_##LABEL(int iters) { \ - BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FUSED, 4, STR, PAD, \ + static void BM_ConvFloatFusedCPU4_##LABEL( \ + ::testing::benchmark::State& state) { \ + BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FUSED, 4, STR, PAD, \ false, DT_FLOAT, \ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ KR, "_", KC, "_", STR, "_", PAD, "_f_cpu4")); \ } \ - static void BM_ConvFloatFwdGPU_##LABEL(int iters) { \ - BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \ + static void BM_ConvFloatFwdGPU_##LABEL(::testing::benchmark::State& state) { \ + BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \ PAD, true, DT_FLOAT, \ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ KR, "_", KC, "_", STR, "_", PAD, "_f_gpu")); \ } \ - static void BM_ConvHalfFwdGPU_##LABEL(int iters) { \ - BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \ + static void BM_ConvHalfFwdGPU_##LABEL(::testing::benchmark::State& state) { \ + BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \ PAD, true, DT_HALF, \ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ KR, "_", KC, "_", STR, "_", PAD, "_h_gpu")); \ } \ - BENCHMARK(BM_ConvFloatFwdCPU1_##LABEL); \ - BENCHMARK(BM_ConvFloatFwdCPU4_##LABEL); \ - BENCHMARK(BM_ConvFloatFusedCPU1_##LABEL); \ - BENCHMARK(BM_ConvFloatFusedCPU4_##LABEL); \ - BENCHMARK(BM_ConvFloatFwdGPU_##LABEL); \ - BENCHMARK(BM_ConvHalfFwdGPU_##LABEL) + BENCHMARK(BM_ConvFloatFwdCPU1_##LABEL)->UseRealTime(); \ + BENCHMARK(BM_ConvFloatFwdCPU4_##LABEL)->UseRealTime(); \ + BENCHMARK(BM_ConvFloatFusedCPU1_##LABEL)->UseRealTime(); \ + BENCHMARK(BM_ConvFloatFusedCPU4_##LABEL)->UseRealTime(); \ + BENCHMARK(BM_ConvFloatFwdGPU_##LABEL)->UseRealTime(); \ + BENCHMARK(BM_ConvHalfFwdGPU_##LABEL)->UseRealTime() BM_ConvFloatFwd(32, 5, 5, 1248, 128, 1, 1, 1, SAME, conv0); BM_ConvFloatFwd(32, 8, 8, 384, 384, 1, 3, 1, SAME, conv1); @@ -334,63 +338,70 @@ BM_ConvFloatFwd(32, 73, 73, 64, 192, 3, 3, 1, VALID, conv52); BM_ConvFloatFwd(32, 73, 73, 64, 64, 1, 1, 1, VALID, conv53); BM_ConvFloatFwd(32, 147, 147, 24, 64, 1, 1, 1, VALID, conv54); -#define BM_ConvFloatBkInAndFilter(BS, R, C, ID, OD, KR, KC, STR, PAD, LABEL) \ - static void BM_ConvFloatBkInCPU1_##LABEL(int iters) { \ - BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \ - STR, PAD, false, DT_FLOAT, \ - strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ - KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \ - } \ - static void BM_ConvFloatBkInCPU4_##LABEL(int iters) { \ - BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 4, \ - STR, PAD, false, DT_FLOAT, \ - strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ - KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \ - } \ - static void BM_ConvFloatBkInGPU_##LABEL(int iters) { \ - BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \ - STR, PAD, true, DT_FLOAT, \ - strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ - KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \ - } \ - static void BM_ConvFloatBkFilterCPU1_##LABEL(int iters) { \ - BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \ - STR, PAD, false, DT_FLOAT, \ - strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ - KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \ - } \ - static void BM_ConvFloatBkFilterCPU4_##LABEL(int iters) { \ - BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 4, \ - STR, PAD, false, DT_FLOAT, \ - strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ - KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \ - } \ - static void BM_ConvFloatBkFilterGPU_##LABEL(int iters) { \ - BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \ - STR, PAD, true, DT_FLOAT, \ - strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ - KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \ - } \ - static void BM_ConvHalfBkInGPU_##LABEL(int iters) { \ - BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \ - STR, PAD, true, DT_HALF, \ - strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ - KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \ - } \ - static void BM_ConvHalfBkFilterGPU_##LABEL(int iters) { \ - BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \ - STR, PAD, true, DT_HALF, \ - strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ - KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \ - } \ - BENCHMARK(BM_ConvFloatBkInCPU1_##LABEL); \ - BENCHMARK(BM_ConvFloatBkInCPU4_##LABEL); \ - BENCHMARK(BM_ConvFloatBkInGPU_##LABEL); \ - BENCHMARK(BM_ConvFloatBkFilterCPU1_##LABEL); \ - BENCHMARK(BM_ConvFloatBkFilterCPU4_##LABEL); \ - BENCHMARK(BM_ConvFloatBkFilterGPU_##LABEL); \ - BENCHMARK(BM_ConvHalfBkInGPU_##LABEL); \ - BENCHMARK(BM_ConvHalfBkFilterGPU_##LABEL) +#define BM_ConvFloatBkInAndFilter(BS, R, C, ID, OD, KR, KC, STR, PAD, LABEL) \ + static void BM_ConvFloatBkInCPU1_##LABEL( \ + ::testing::benchmark::State& state) { \ + BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \ + STR, PAD, false, DT_FLOAT, \ + strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ + KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \ + } \ + static void BM_ConvFloatBkInCPU4_##LABEL( \ + ::testing::benchmark::State& state) { \ + BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 4, \ + STR, PAD, false, DT_FLOAT, \ + strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ + KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \ + } \ + static void BM_ConvFloatBkInGPU_##LABEL( \ + ::testing::benchmark::State& state) { \ + BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \ + STR, PAD, true, DT_FLOAT, \ + strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ + KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \ + } \ + static void BM_ConvFloatBkFilterCPU1_##LABEL( \ + ::testing::benchmark::State& state) { \ + BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \ + STR, PAD, false, DT_FLOAT, \ + strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ + KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \ + } \ + static void BM_ConvFloatBkFilterCPU4_##LABEL( \ + ::testing::benchmark::State& state) { \ + BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 4, \ + STR, PAD, false, DT_FLOAT, \ + strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ + KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \ + } \ + static void BM_ConvFloatBkFilterGPU_##LABEL( \ + ::testing::benchmark::State& state) { \ + BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \ + STR, PAD, true, DT_FLOAT, \ + strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ + KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \ + } \ + static void BM_ConvHalfBkInGPU_##LABEL(::testing::benchmark::State& state) { \ + BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \ + STR, PAD, true, DT_HALF, \ + strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ + KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \ + } \ + static void BM_ConvHalfBkFilterGPU_##LABEL( \ + ::testing::benchmark::State& state) { \ + BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \ + STR, PAD, true, DT_HALF, \ + strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \ + KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \ + } \ + BENCHMARK(BM_ConvFloatBkInCPU1_##LABEL)->UseRealTime(); \ + BENCHMARK(BM_ConvFloatBkInCPU4_##LABEL)->UseRealTime(); \ + BENCHMARK(BM_ConvFloatBkInGPU_##LABEL)->UseRealTime(); \ + BENCHMARK(BM_ConvFloatBkFilterCPU1_##LABEL)->UseRealTime(); \ + BENCHMARK(BM_ConvFloatBkFilterCPU4_##LABEL)->UseRealTime(); \ + BENCHMARK(BM_ConvFloatBkFilterGPU_##LABEL)->UseRealTime(); \ + BENCHMARK(BM_ConvHalfBkInGPU_##LABEL)->UseRealTime(); \ + BENCHMARK(BM_ConvHalfBkFilterGPU_##LABEL)->UseRealTime() // Benchmarks from the inception model @@ -453,8 +464,8 @@ BM_ConvFloatBkInAndFilter(32, 147, 147, 24, 64, 1, 1, 1, VALID, conv54); #define BM_ConvFloatBkFCPU(BS, R, C, ID, OD, KR, KC, TH, LABEL) \ static void \ BM_ConvFloatBkFCPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC##_##TH( \ - int iters) { \ - BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, TH, \ + ::testing::benchmark::State& state) { \ + BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, TH, \ 1, VALID, false, DT_FLOAT, LABEL); \ } \ BENCHMARK( \ @@ -469,17 +480,19 @@ BM_ConvFloatBkFCPU(128, 13, 13, 384, 384, 3, 3, 4, "convnet-layer5"); #define BM_ConvFloatBkFGPU(BS, R, C, ID, OD, KR, KC, LABEL) \ static void BM_ConvFloatBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC( \ - int iters) { \ - BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \ + ::testing::benchmark::State& state) { \ + BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \ 1, VALID, true, DT_FLOAT, LABEL); \ } \ static void BM_ConvHalfBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC( \ - int iters) { \ - BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \ + ::testing::benchmark::State& state) { \ + BM_ConvFloat(state, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \ 1, VALID, true, DT_HALF, LABEL); \ } \ - BENCHMARK(BM_ConvFloatBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC); \ - BENCHMARK(BM_ConvHalfBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC) + BENCHMARK(BM_ConvFloatBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC) \ + ->UseRealTime(); \ + BENCHMARK(BM_ConvHalfBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC) \ + ->UseRealTime() // Benchmarks from https://github.com/soumith/convnet-benchmarks BM_ConvFloatBkFGPU(128, 128, 128, 3, 96, 11, 11, "convnet-layer1"); @@ -498,19 +511,19 @@ enum DEPTHWISE_CONV_OP { } // namespace -static void BM_ConvFloatDepthwise(int iters, int batch, int rows, int cols, - int in_depth, int depth_multiplier, - int out_depth, int filter_rows, - int filter_cols, DEPTHWISE_CONV_OP op, - int num_threads, int stride, Padding padding, - bool use_gpu, const string& label) { - testing::StopTiming(); +static void BM_ConvFloatDepthwise(::testing::benchmark::State& state, int batch, + int rows, int cols, int in_depth, + int depth_multiplier, int out_depth, + int filter_rows, int filter_cols, + DEPTHWISE_CONV_OP op, int num_threads, + int stride, Padding padding, bool use_gpu, + const string& label) { if (!IsGoogleCudaEnabled() && use_gpu) { - testing::SetLabel( + state.SetLabel( strings::StrCat("Skipping GPU test (no --config=cuda): ", label)); return; } - testing::SetLabel(label); + state.SetLabel(label); // Set the number of threads SessionOptions options; @@ -603,10 +616,10 @@ static void BM_ConvFloatDepthwise(int iters, int batch, int rows, int cols, TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph, g)); string device = use_gpu ? "gpu" : "cpu"; - testing::UseRealTime(); - testing::StartTiming(); - test::Benchmark(device, g, &options).Run(iters); - testing::ItemsProcessed(num_ops * iters); + test::Benchmark(device, g, &options, nullptr, nullptr, "", + /*old_benchmark_api=*/false) + .Run(state); + state.SetItemsProcessed(num_ops * state.iterations()); } // BS: batch_size @@ -622,30 +635,33 @@ static void BM_ConvFloatDepthwise(int iters, int batch, int rows, int cols, #define BM_ConvFloatDepthwiseFwd(BS, R, C, ID, DM, OD, KR, KC, STR, PAD, \ LABEL) \ - static void BM_ConvFloatDepthwiseFwdCPU1_##LABEL(int iters) { \ + static void BM_ConvFloatDepthwiseFwdCPU1_##LABEL( \ + ::testing::benchmark::State& state) { \ BM_ConvFloatDepthwise( \ - iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 1, STR, \ + state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 1, STR, \ PAD, false, \ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \ KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \ } \ - static void BM_ConvFloatDepthwiseFwdCPU4_##LABEL(int iters) { \ + static void BM_ConvFloatDepthwiseFwdCPU4_##LABEL( \ + ::testing::benchmark::State& state) { \ BM_ConvFloatDepthwise( \ - iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 4, STR, \ + state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 4, STR, \ PAD, false, \ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \ KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \ } \ - static void BM_ConvFloatDepthwiseFwdGPU_##LABEL(int iters) { \ + static void BM_ConvFloatDepthwiseFwdGPU_##LABEL( \ + ::testing::benchmark::State& state) { \ BM_ConvFloatDepthwise( \ - iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 1, STR, \ + state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 1, STR, \ PAD, true, \ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \ KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \ } \ - BENCHMARK(BM_ConvFloatDepthwiseFwdCPU1_##LABEL); \ - BENCHMARK(BM_ConvFloatDepthwiseFwdCPU4_##LABEL); \ - BENCHMARK(BM_ConvFloatDepthwiseFwdGPU_##LABEL); + BENCHMARK(BM_ConvFloatDepthwiseFwdCPU1_##LABEL)->UseRealTime(); \ + BENCHMARK(BM_ConvFloatDepthwiseFwdCPU4_##LABEL)->UseRealTime(); \ + BENCHMARK(BM_ConvFloatDepthwiseFwdGPU_##LABEL)->UseRealTime(); // The configurations below are mostly from mobilenet models. BM_ConvFloatDepthwiseFwd(32, 112, 112, 3, 8, 24, 3, 3, 1, SAME, conv0); @@ -662,53 +678,59 @@ BM_ConvFloatDepthwiseFwd(1, 100, 100, 72, 1, 72, 3, 3, 1, SAME, conv9); BM_ConvFloatDepthwiseFwd(1, 100, 100, 72, 1, 72, 5, 5, 1, SAME, conv10); #define BM_ConvFloatDepthwiseBk(BS, R, C, ID, DM, OD, KR, KC, STR, PAD, LABEL) \ - static void BM_ConvFloatDepthwiseBkInCPU1_##LABEL(int iters) { \ + static void BM_ConvFloatDepthwiseBkInCPU1_##LABEL( \ + ::testing::benchmark::State& state) { \ BM_ConvFloatDepthwise( \ - iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \ + state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \ 1, STR, PAD, false, \ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \ KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \ } \ - static void BM_ConvFloatDepthwiseBkInCPU4_##LABEL(int iters) { \ + static void BM_ConvFloatDepthwiseBkInCPU4_##LABEL( \ + ::testing::benchmark::State& state) { \ BM_ConvFloatDepthwise( \ - iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \ + state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \ 4, STR, PAD, false, \ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \ KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \ } \ - static void BM_ConvFloatDepthwiseBkInGPU_##LABEL(int iters) { \ + static void BM_ConvFloatDepthwiseBkInGPU_##LABEL( \ + ::testing::benchmark::State& state) { \ BM_ConvFloatDepthwise( \ - iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \ + state, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_BACKPROP_INPUT, \ 4, STR, PAD, true, \ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \ KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \ } \ - static void BM_ConvFloatDepthwiseBkFilterCPU1_##LABEL(int iters) { \ + static void BM_ConvFloatDepthwiseBkFilterCPU1_##LABEL( \ + ::testing::benchmark::State& state) { \ BM_ConvFloatDepthwise( \ - iters, BS, R, C, ID, DM, OD, KR, KC, \ + state, BS, R, C, ID, DM, OD, KR, KC, \ DEPTHWISE_CONV_OP_BACKPROP_FILTER, 1, STR, PAD, false, \ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \ KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \ } \ - static void BM_ConvFloatDepthwiseBkFilterCPU4_##LABEL(int iters) { \ + static void BM_ConvFloatDepthwiseBkFilterCPU4_##LABEL( \ + ::testing::benchmark::State& state) { \ BM_ConvFloatDepthwise( \ - iters, BS, R, C, ID, DM, OD, KR, KC, \ + state, BS, R, C, ID, DM, OD, KR, KC, \ DEPTHWISE_CONV_OP_BACKPROP_FILTER, 4, STR, PAD, false, \ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \ KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \ } \ - static void BM_ConvFloatDepthwiseBkFilterGPU_##LABEL(int iters) { \ + static void BM_ConvFloatDepthwiseBkFilterGPU_##LABEL( \ + ::testing::benchmark::State& state) { \ BM_ConvFloatDepthwise( \ - iters, BS, R, C, ID, DM, OD, KR, KC, \ + state, BS, R, C, ID, DM, OD, KR, KC, \ DEPTHWISE_CONV_OP_BACKPROP_FILTER, 4, STR, PAD, true, \ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \ KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \ } \ - BENCHMARK(BM_ConvFloatDepthwiseBkInCPU1_##LABEL); \ - BENCHMARK(BM_ConvFloatDepthwiseBkInCPU4_##LABEL); \ - BENCHMARK(BM_ConvFloatDepthwiseBkInGPU_##LABEL); \ - BENCHMARK(BM_ConvFloatDepthwiseBkFilterCPU1_##LABEL); \ - BENCHMARK(BM_ConvFloatDepthwiseBkFilterCPU4_##LABEL); \ + BENCHMARK(BM_ConvFloatDepthwiseBkInCPU1_##LABEL)->UseRealTime(); \ + BENCHMARK(BM_ConvFloatDepthwiseBkInCPU4_##LABEL)->UseRealTime(); \ + BENCHMARK(BM_ConvFloatDepthwiseBkInGPU_##LABEL)->UseRealTime(); \ + BENCHMARK(BM_ConvFloatDepthwiseBkFilterCPU1_##LABEL)->UseRealTime(); \ + BENCHMARK(BM_ConvFloatDepthwiseBkFilterCPU4_##LABEL)->UseRealTime(); \ BENCHMARK(BM_ConvFloatDepthwiseBkFilterGPU_##LABEL) // The configurations below are mostly from mobilenet models. @@ -732,10 +754,9 @@ BM_ConvFloatDepthwiseBk(32, 112, 112, 8, 3, 24, 3, 3, 1, SAME, conv12); BM_ConvFloatDepthwiseBk(32, 112, 112, 12, 2, 24, 3, 3, 1, SAME, conv13); BM_ConvFloatDepthwiseBk(32, 112, 112, 24, 1, 24, 3, 3, 1, SAME, conv14); -static void BM_LRNFloat(int iters, int depth, int cols, int rows, - int batch_size, int range, int num_threads, +static void BM_LRNFloat(::testing::benchmark::State& state, int depth, int cols, + int rows, int batch_size, int range, int num_threads, const string& label) { - tensorflow::testing::StopTiming(); std::unique_ptr<Device> device( DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); @@ -778,26 +799,24 @@ static void BM_LRNFloat(int iters, int depth, int cols, int rows, std::unique_ptr<OpKernelContext> context(new OpKernelContext(¶ms)); op->Compute(context.get()); - testing::UseRealTime(); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { delete context->release_output(0).tensor; op->Compute(context.get()); } - tensorflow::testing::StopTiming(); - testing::ItemsProcessed(context->mutable_output(0)->NumElements() * iters * - (2 * range + 1) * 2); - testing::SetLabel(label); + state.SetItemsProcessed(context->mutable_output(0)->NumElements() * + state.iterations() * (2 * range + 1) * 2); + state.SetLabel(label); } #define BM_LRNFloatFwdCPU(DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL) \ static void \ BM_LRNFloat_##DEPTH##_##COLS##_##ROWS##_##BATCH##_##RANGE##_##THREADS( \ - int iters) { \ - BM_LRNFloat(iters, DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL); \ + ::testing::benchmark::State& state) { \ + BM_LRNFloat(state, DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL); \ } \ BENCHMARK( \ - BM_LRNFloat_##DEPTH##_##COLS##_##ROWS##_##BATCH##_##RANGE##_##THREADS) + BM_LRNFloat_##DEPTH##_##COLS##_##ROWS##_##BATCH##_##RANGE##_##THREADS) \ + ->UseRealTime() // clang-format off // DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL @@ -815,10 +834,10 @@ BM_LRNFloatFwdCPU(192, 56, 56, 32, 5, 8, "lrn 8 threads"); /* AvgPooling Op */ -static void BM_AvgPool(int iters, int batch_size, int rows, int cols, int depth, - int kernel_rows, int kernel_cols, int stride, - Padding padding, int num_threads, const string& label) { - tensorflow::testing::StopTiming(); +static void BM_AvgPool(::testing::benchmark::State& state, int batch_size, + int rows, int cols, int depth, int kernel_rows, + int kernel_cols, int stride, Padding padding, + int num_threads, const string& label) { std::unique_ptr<Device> device( DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); @@ -860,16 +879,13 @@ static void BM_AvgPool(int iters, int batch_size, int rows, int cols, int depth, new OpKernelContext(¶ms)); op->Compute(avgpool_context.get()); - testing::UseRealTime(); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { delete avgpool_context->release_output(0).tensor; op->Compute(avgpool_context.get()); } - tensorflow::testing::StopTiming(); - testing::ItemsProcessed(avgpool_context->mutable_output(0)->NumElements() * - iters); - testing::SetLabel(label); + state.SetItemsProcessed(avgpool_context->mutable_output(0)->NumElements() * + state.iterations()); + state.SetLabel(label); } // BS: batch_size @@ -883,11 +899,12 @@ static void BM_AvgPool(int iters, int batch_size, int rows, int cols, int depth, #define BM_AvgPoolFwdCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \ static void \ BM_AvgPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH( \ - int iters) { \ - BM_AvgPool(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \ + ::testing::benchmark::State& state) { \ + BM_AvgPool(state, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \ } \ BENCHMARK( \ - BM_AvgPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH) + BM_AvgPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH) \ + ->UseRealTime() // Labels are taken from the 2014-July-24 version of imagenet BM_AvgPoolFwdCPU(32, 112, 112, 64, 3, 3, 2, VALID, 1, "avgpool0_VALID"); @@ -907,11 +924,10 @@ BM_AvgPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, SAME, 4, "avgpool1_SAME"); BM_AvgPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, SAME, 4, "avgpool4_SAME"); BM_AvgPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, SAME, 4, "avgpool10_SAME"); -static void BM_AvgPoolBk(int iters, int batch_size, int rows, int cols, - int depth, int kernel_rows, int kernel_cols, - int stride, Padding padding, int num_threads, - const string& label) { - tensorflow::testing::StopTiming(); +static void BM_AvgPoolBk(::testing::benchmark::State& state, int batch_size, + int rows, int cols, int depth, int kernel_rows, + int kernel_cols, int stride, Padding padding, + int num_threads, const string& label) { std::unique_ptr<Device> device( DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); @@ -966,16 +982,13 @@ static void BM_AvgPoolBk(int iters, int batch_size, int rows, int cols, new OpKernelContext(¶ms)); op->Compute(avgpool_context.get()); - testing::UseRealTime(); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { delete avgpool_context->release_output(0).tensor; op->Compute(avgpool_context.get()); } - tensorflow::testing::StopTiming(); - testing::ItemsProcessed(avgpool_context->mutable_output(0)->NumElements() * - iters); - testing::SetLabel(label); + state.SetItemsProcessed(avgpool_context->mutable_output(0)->NumElements() * + state.iterations()); + state.SetLabel(label); } // BS: batch_size @@ -987,14 +1000,17 @@ static void BM_AvgPoolBk(int iters, int batch_size, int rows, int cols, // ST: stride. We use the same stride for both directions. // PT: padding // The resulted symbol is too long. Need to use two macros to fit in 80-chars +// NOLINTBEGIN #define BM_AvgPoolBkCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \ static void \ BM_AvgPoolBk_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH( \ - int iters) { \ - BM_AvgPoolBk(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \ + ::testing::benchmark::State& state) { \ + BM_AvgPoolBk(state, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \ } \ BENCHMARK( \ - BM_AvgPoolBk_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH) + BM_AvgPoolBk_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH) \ + ->UseRealTime() +// NOLINTEND // Shapes taken from the 2015/05/16 inception model BM_AvgPoolBkCPU(32, 35, 35, 192, 3, 3, 1, SAME, 1, "avgpool_grad0_SAME"); @@ -1010,10 +1026,10 @@ BM_AvgPoolBkCPU(32, 8, 8, 2048, 8, 8, 1, VALID, 1, "avgpool_grad8_VALID"); /* MaxPooling Op */ -static void BM_MaxPool(int iters, int batch_size, int rows, int cols, int depth, - int kernel_rows, int kernel_cols, int stride, - Padding padding, int num_threads, const string& label) { - tensorflow::testing::StopTiming(); +static void BM_MaxPool(::testing::benchmark::State& state, int batch_size, + int rows, int cols, int depth, int kernel_rows, + int kernel_cols, int stride, Padding padding, + int num_threads, const string& label) { SessionOptions options; options.config.set_intra_op_parallelism_threads(num_threads); @@ -1057,16 +1073,13 @@ static void BM_MaxPool(int iters, int batch_size, int rows, int cols, int depth, new OpKernelContext(¶ms)); op->Compute(maxpool_context.get()); - testing::UseRealTime(); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { delete maxpool_context->release_output(0).tensor; op->Compute(maxpool_context.get()); } - tensorflow::testing::StopTiming(); - testing::ItemsProcessed(maxpool_context->mutable_output(0)->NumElements() * - iters); - testing::SetLabel(label); + state.SetItemsProcessed(maxpool_context->mutable_output(0)->NumElements() * + state.iterations()); + state.SetLabel(label); } // BS: batch_size @@ -1080,11 +1093,12 @@ static void BM_MaxPool(int iters, int batch_size, int rows, int cols, int depth, #define BM_MaxPoolFwdCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \ static void \ BM_MaxPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH( \ - int iters) { \ - BM_MaxPool(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \ + ::testing::benchmark::State& state) { \ + BM_MaxPool(state, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \ } \ BENCHMARK( \ - BM_MaxPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH) + BM_MaxPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH) \ + ->UseRealTime() // Labels are taken from the 2014-July-24 version of imagenet /* TODO XXX @@ -1106,10 +1120,10 @@ BM_MaxPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, SAME, 4, "maxpool1_SAME"); BM_MaxPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, SAME, 4, "maxpool4_SAME"); BM_MaxPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, SAME, 4, "maxpool10_SAME"); -static void BM_MaxPoolBk(int iters, int batch_size, int rows, int cols, - int depth, int kernel_rows, int kernel_cols, - int stride, Padding padding, int num_threads, - bool use_gpu, const string& label) { +static void BM_MaxPoolBk(::testing::benchmark::State& state, int batch_size, + int rows, int cols, int depth, int kernel_rows, + int kernel_cols, int stride, Padding padding, + int num_threads, bool use_gpu, const string& label) { auto root = Scope::NewRootScope().ExitOnError(); int64 out_height, out_width, pad_rows, pad_cols; @@ -1138,11 +1152,11 @@ static void BM_MaxPoolBk(int iters, int batch_size, int rows, int cols, Graph* g = new Graph(OpRegistry::Global()); TF_CHECK_OK(root.ToGraph(g)); string device = use_gpu ? "gpu" : "cpu"; - testing::UseRealTime(); - test::Benchmark(device, g).Run(iters); + test::Benchmark(device, g, /*old_benchmark_api*/ false).Run(state); - testing::ItemsProcessed(batch_size * rows * cols * depth * iters); - testing::SetLabel(label); + state.SetItemsProcessed(batch_size * rows * cols * depth * + state.iterations()); + state.SetLabel(label); } // BS: batch_size @@ -1159,23 +1173,23 @@ static void BM_MaxPoolBk(int iters, int batch_size, int rows, int cols, static void \ BM_MaxPoolBk_GPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_ \ ##PT##_##TH( \ - int iters) { \ - BM_MaxPoolBk(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, true, LABEL); \ + ::testing::benchmark::State& state) { \ + BM_MaxPoolBk(state, BS, IR, IC, ND, KR, KC, ST, PT, TH, true, LABEL); \ } \ BENCHMARK( \ BM_MaxPoolBk_GPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_ \ - ##PT##_##TH) \ + ##PT##_##TH)->UseRealTime() #define BM_MaxPoolBkCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \ static void \ BM_MaxPoolBk_CPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_ \ ##PT##_##TH( \ - int iters) { \ - BM_MaxPoolBk(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, false, LABEL); \ + ::testing::benchmark::State& state) { \ + BM_MaxPoolBk(state, BS, IR, IC, ND, KR, KC, ST, PT, TH, false, LABEL); \ } \ BENCHMARK( \ BM_MaxPoolBk_CPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_ \ - ##PT##_##TH) + ##PT##_##TH)->UseRealTime() // clang-format on // Shapes taken from the 2015/05/16 inception model @@ -1195,9 +1209,9 @@ BM_MaxPoolBkCPU(32, 8, 8, 2048, 3, 3, 2, VALID, 1, "maxpool_grad4_VALID"); Relu Op Run benchmark with: */ -static void BM_ReluFloat(int iters, int batch_size, int rows, int cols, - int depth, int num_threads, const string& label) { - tensorflow::testing::StopTiming(); +static void BM_ReluFloat(::testing::benchmark::State& state, int batch_size, + int rows, int cols, int depth, int num_threads, + const string& label) { std::unique_ptr<Device> device( DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); @@ -1233,27 +1247,25 @@ static void BM_ReluFloat(int iters, int batch_size, int rows, int cols, std::unique_ptr<OpKernelContext> relu_context(new OpKernelContext(¶ms)); op->Compute(relu_context.get()); - testing::UseRealTime(); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { delete relu_context->release_output(0).tensor; op->Compute(relu_context.get()); } - tensorflow::testing::StopTiming(); - testing::ItemsProcessed(relu_context->mutable_output(0)->NumElements() * - iters); - testing::SetLabel(label); + state.SetItemsProcessed(relu_context->mutable_output(0)->NumElements() * + state.iterations()); + state.SetLabel(label); } // BS: batch_size // IR: input_rows // IC: input_cols // ND: node_depth -#define BM_Relu(BS, IR, IC, ND, TH, LABEL) \ - static void BM_ReluFloat_##BS##_##IR##_##IC##_##ND##_##TH(int iters) { \ - BM_ReluFloat(iters, BS, IR, IC, ND, TH, LABEL); \ - } \ - BENCHMARK(BM_ReluFloat_##BS##_##IR##_##IC##_##ND##_##TH) +#define BM_Relu(BS, IR, IC, ND, TH, LABEL) \ + static void BM_ReluFloat_##BS##_##IR##_##IC##_##ND##_##TH( \ + ::testing::benchmark::State& state) { \ + BM_ReluFloat(state, BS, IR, IC, ND, TH, LABEL); \ + } \ + BENCHMARK(BM_ReluFloat_##BS##_##IR##_##IC##_##ND##_##TH)->UseRealTime() BM_Relu(32, 112, 112, 64, 1, "relu0"); BM_Relu(32, 56, 56, 192, 1, "relu1"); @@ -1268,9 +1280,9 @@ BM_Relu(32, 14, 14, 576, 4, "relu10"); Softplus Op Run benchmark with: */ -static void BM_SoftplusFloat(int iters, int batch_size, int rows, int cols, - int depth, int num_threads, const string& label) { - tensorflow::testing::StopTiming(); +static void BM_SoftplusFloat(::testing::benchmark::State& state, int batch_size, + int rows, int cols, int depth, int num_threads, + const string& label) { std::unique_ptr<Device> device( DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); @@ -1307,27 +1319,25 @@ static void BM_SoftplusFloat(int iters, int batch_size, int rows, int cols, new OpKernelContext(¶ms)); op->Compute(softplus_context.get()); - testing::UseRealTime(); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { delete softplus_context->release_output(0).tensor; op->Compute(softplus_context.get()); } - tensorflow::testing::StopTiming(); - testing::ItemsProcessed(softplus_context->mutable_output(0)->NumElements() * - iters); - testing::SetLabel(label); + state.SetItemsProcessed(softplus_context->mutable_output(0)->NumElements() * + state.iterations()); + state.SetLabel(label); } // BS: batch_size // IR: input_rows // IC: input_cols // ND: node_depth -#define BM_Softplus(BS, IR, IC, ND, TH, LABEL) \ - static void BM_SoftplusFloat_##BS##_##IR##_##IC##_##ND##_##TH(int iters) { \ - BM_SoftplusFloat(iters, BS, IR, IC, ND, TH, LABEL); \ - } \ - BENCHMARK(BM_SoftplusFloat_##BS##_##IR##_##IC##_##ND##_##TH) +#define BM_Softplus(BS, IR, IC, ND, TH, LABEL) \ + static void BM_SoftplusFloat_##BS##_##IR##_##IC##_##ND##_##TH( \ + ::testing::benchmark::State& state) { \ + BM_SoftplusFloat(state, BS, IR, IC, ND, TH, LABEL); \ + } \ + BENCHMARK(BM_SoftplusFloat_##BS##_##IR##_##IC##_##ND##_##TH)->UseRealTime() BM_Softplus(32, 112, 112, 64, 1, "softplus0"); BM_Softplus(32, 56, 56, 192, 1, "softplus1"); @@ -1338,7 +1348,8 @@ BM_Softplus(32, 56, 56, 192, 4, "softplus1"); BM_Softplus(32, 28, 28, 352, 4, "softplus4"); BM_Softplus(32, 14, 14, 576, 4, "softplus10"); -static void BM_ImageNetSoftmaxFwd(int iters, int batch_size, int node_depth, +static void BM_ImageNetSoftmaxFwd(::testing::benchmark::State& state, + int batch_size, int node_depth, int num_threads, bool use_gpu, const string& label) { auto root = Scope::NewRootScope().ExitOnError(); @@ -1359,19 +1370,21 @@ static void BM_ImageNetSoftmaxFwd(int iters, int batch_size, int node_depth, opts.config.mutable_graph_options() ->mutable_optimizer_options() ->set_opt_level(OptimizerOptions::L0); - testing::UseRealTime(); - test::Benchmark(device, g, &opts).Run(iters); - testing::ItemsProcessed(batch_size * node_depth * iters); - testing::SetLabel(label); + test::Benchmark(device, g, &opts, nullptr, nullptr, "", + /*old_benchmark_api*/ false) + .Run(state); + state.SetItemsProcessed(batch_size * node_depth * state.iterations()); + state.SetLabel(label); } -#define BM_ImageNetSoftmaxFwd(BATCH_SIZE, NODE_DEPTH, TH, GPU, LABEL) \ - static void \ - BM_ImageNetSoftmaxFwd_##BATCH_SIZE##_##NODE_DEPTH##_##TH##_##GPU( \ - int iters) { \ - BM_ImageNetSoftmaxFwd(iters, BATCH_SIZE, NODE_DEPTH, TH, GPU, LABEL); \ - } \ - BENCHMARK(BM_ImageNetSoftmaxFwd_##BATCH_SIZE##_##NODE_DEPTH##_##TH##_##GPU) +#define BM_ImageNetSoftmaxFwd(BATCH_SIZE, NODE_DEPTH, TH, GPU, LABEL) \ + static void \ + BM_ImageNetSoftmaxFwd_##BATCH_SIZE##_##NODE_DEPTH##_##TH##_##GPU( \ + ::testing::benchmark::State& state) { \ + BM_ImageNetSoftmaxFwd(state, BATCH_SIZE, NODE_DEPTH, TH, GPU, LABEL); \ + } \ + BENCHMARK(BM_ImageNetSoftmaxFwd_##BATCH_SIZE##_##NODE_DEPTH##_##TH##_##GPU) \ + ->UseRealTime() // Labels are taken from the 2014-July-24 version of imagenet BM_ImageNetSoftmaxFwd(32, 1008, 1, false, "softmax32"); @@ -1383,9 +1396,8 @@ BM_ImageNetSoftmaxFwd(128, 1008, 1, true, "softmax128"); BM_ImageNetSoftmaxFwd(8192, 1024, 1, true, "softmax32"); BM_ImageNetSoftmaxFwd(8192, 32768, 1, true, "softmax128"); -static void BM_TopK(int iters, int rows, int cols, int k, int num_threads, - bool use_gpu, const string& label) { - testing::StopTiming(); +static void BM_TopK(::testing::benchmark::State& state, int rows, int cols, + int k, int num_threads, bool use_gpu, const string& label) { auto root = Scope::NewRootScope().ExitOnError(); Tensor input(DT_FLOAT, TensorShape({rows, cols})); @@ -1407,28 +1419,30 @@ static void BM_TopK(int iters, int rows, int cols, int k, int num_threads, opts.config.mutable_graph_options() ->mutable_optimizer_options() ->set_opt_level(OptimizerOptions::L0); - testing::UseRealTime(); - testing::StartTiming(); - test::Benchmark(device, g, &opts).Run(iters); - testing::ItemsProcessed(rows * cols * iters); - testing::SetLabel(label); + test::Benchmark(device, g, &opts, nullptr, nullptr, "", + /*old_benchmark_api=*/false) + .Run(state); + state.SetItemsProcessed(rows * cols * state.iterations()); + state.SetLabel(label); } // IR: input_rows // IC: input_cols // IK: k // TH: number of threads -#define BM_TopKGPU(IR, IC, IK, TH, LABEL) \ - static void BM_TopK_GPU_##IR##_##IC##_##IK##_##TH(int iters) { \ - BM_TopK(iters, IR, IC, IK, TH, true, LABEL); \ - } \ - BENCHMARK(BM_TopK_GPU_##IR##_##IC##_##IK##_##TH) +#define BM_TopKGPU(IR, IC, IK, TH, LABEL) \ + static void BM_TopK_GPU_##IR##_##IC##_##IK##_##TH( \ + ::testing::benchmark::State& state) { \ + BM_TopK(state, IR, IC, IK, TH, true, LABEL); \ + } \ + BENCHMARK(BM_TopK_GPU_##IR##_##IC##_##IK##_##TH)->UseRealTime() -#define BM_TopKCPU(IR, IC, IK, TH, LABEL) \ - static void BM_TopK_CPU_##IR##_##IC##_##IK##_##TH(int iters) { \ - BM_TopK(iters, IR, IC, IK, TH, false, LABEL); \ - } \ - BENCHMARK(BM_TopK_CPU_##IR##_##IC##_##IK##_##TH) +#define BM_TopKCPU(IR, IC, IK, TH, LABEL) \ + static void BM_TopK_CPU_##IR##_##IC##_##IK##_##TH( \ + ::testing::benchmark::State& state) { \ + BM_TopK(state, IR, IC, IK, TH, false, LABEL); \ + } \ + BENCHMARK(BM_TopK_CPU_##IR##_##IC##_##IK##_##TH)->UseRealTime() // clang-format on diff --git a/tensorflow/core/kernels/one_hot_op_test.cc b/tensorflow/core/kernels/one_hot_op_test.cc index 95a9ea11a06..bf50c62fc07 100644 --- a/tensorflow/core/kernels/one_hot_op_test.cc +++ b/tensorflow/core/kernels/one_hot_op_test.cc @@ -56,9 +56,13 @@ static Graph* OneHot(int batch_size, int num_classes, int axis) { } #define BM_OneHot(BATCH, CLASS, AXIS, DEVICE) \ - static void BM_OneHot##_##BATCH##_##CLASS##_##AXIS##_##DEVICE(int iters) { \ - testing::ItemsProcessed(static_cast<int64>(iters) * BATCH * CLASS); \ - test::Benchmark(#DEVICE, OneHot(BATCH, CLASS, AXIS)).Run(iters); \ + static void BM_OneHot##_##BATCH##_##CLASS##_##AXIS##_##DEVICE( \ + ::testing::benchmark::State& state) { \ + test::Benchmark(#DEVICE, OneHot(BATCH, CLASS, AXIS), \ + /*old_benchmark_api*/ false) \ + .Run(state); \ + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * BATCH * \ + CLASS); \ } \ BENCHMARK(BM_OneHot##_##BATCH##_##CLASS##_##AXIS##_##DEVICE); diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op_test.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op_test.cc index 07f2f75ca5a..4180cfba0d3 100644 --- a/tensorflow/core/kernels/parameterized_truncated_normal_op_test.cc +++ b/tensorflow/core/kernels/parameterized_truncated_normal_op_test.cc @@ -107,25 +107,34 @@ static Graph* PTruncatedNormalOneTail(int num_batches, int samples_per_batch) { return g; } -#define BM_PTruncatedNormalDev(DEVICE, B, S) \ - static void BM_PTruncatedNormal_##DEVICE##_##B##_##S(int iters) { \ - test::Benchmark(#DEVICE, PTruncatedNormal(B, S)).Run(iters); \ - testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \ - } \ +#define BM_PTruncatedNormalDev(DEVICE, B, S) \ + static void BM_PTruncatedNormal_##DEVICE##_##B##_##S( \ + ::testing::benchmark::State& state) { \ + test::Benchmark(#DEVICE, PTruncatedNormal(B, S), \ + /*old_benchmark_api*/ false) \ + .Run(state); \ + state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \ + } \ BENCHMARK(BM_PTruncatedNormal_##DEVICE##_##B##_##S); -#define BM_PTruncatedNormalDev_2SD(DEVICE, B, S) \ - static void BM_PTruncatedNormal_2SD_##DEVICE##_##B##_##S(int iters) { \ - test::Benchmark(#DEVICE, PTruncatedNormal2SD(B, S)).Run(iters); \ - testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \ - } \ +#define BM_PTruncatedNormalDev_2SD(DEVICE, B, S) \ + static void BM_PTruncatedNormal_2SD_##DEVICE##_##B##_##S( \ + ::testing::benchmark::State& state) { \ + test::Benchmark(#DEVICE, PTruncatedNormal2SD(B, S), \ + /*old_benchmark_api*/ false) \ + .Run(state); \ + state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \ + } \ BENCHMARK(BM_PTruncatedNormal_2SD_##DEVICE##_##B##_##S); -#define BM_PTruncatedNormalDev_OneTail(DEVICE, B, S) \ - static void BM_PTruncatedNormal_OneTail_##DEVICE##_##B##_##S(int iters) { \ - test::Benchmark(#DEVICE, PTruncatedNormalOneTail(B, S)).Run(iters); \ - testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \ - } \ +#define BM_PTruncatedNormalDev_OneTail(DEVICE, B, S) \ + static void BM_PTruncatedNormal_OneTail_##DEVICE##_##B##_##S( \ + ::testing::benchmark::State& state) { \ + test::Benchmark(#DEVICE, PTruncatedNormalOneTail(B, S), \ + /*old_benchmark_api*/ false) \ + .Run(state); \ + state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \ + } \ BENCHMARK(BM_PTruncatedNormal_OneTail_##DEVICE##_##B##_##S); BM_PTruncatedNormalDev(cpu, 1000, 1000); diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc b/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc index 596ab13590a..a685c1ad0f8 100644 --- a/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc +++ b/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc @@ -759,15 +759,16 @@ TEST_F(QuantizeAndDequantizeTest, Invalid_range_given_V3) { << s; } -#define BM_SIMPLE_QUAN_DEQUAN(DEVICE) \ - static void BM_SIMPLE_QUAN_DEQUAN_##DEVICE(int iters) { \ - auto root = Scope::NewRootScope().ExitOnError(); \ - ops::QuantizeAndDequantizeV2(root, -3.5, -3.5, -3.5); \ - TF_CHECK_OK(root.status()); \ - Graph* g = new Graph(OpRegistry::Global()); \ - TF_CHECK_OK(root.ToGraph(g)); \ - test::Benchmark(#DEVICE, g).Run(iters); \ - } \ +#define BM_SIMPLE_QUAN_DEQUAN(DEVICE) \ + static void BM_SIMPLE_QUAN_DEQUAN_##DEVICE( \ + ::testing::benchmark::State& state) { \ + auto root = Scope::NewRootScope().ExitOnError(); \ + ops::QuantizeAndDequantizeV2(root, -3.5, -3.5, -3.5); \ + TF_CHECK_OK(root.status()); \ + Graph* g = new Graph(OpRegistry::Global()); \ + TF_CHECK_OK(root.ToGraph(g)); \ + test::Benchmark(#DEVICE, g, /*old_benchmark_api*/ false).Run(state); \ + } \ BENCHMARK(BM_SIMPLE_QUAN_DEQUAN_##DEVICE); BM_SIMPLE_QUAN_DEQUAN(cpu); diff --git a/tensorflow/core/kernels/quantized_concat_op_test.cc b/tensorflow/core/kernels/quantized_concat_op_test.cc index 2b7fd248e9e..09cb7f00bfd 100644 --- a/tensorflow/core/kernels/quantized_concat_op_test.cc +++ b/tensorflow/core/kernels/quantized_concat_op_test.cc @@ -248,9 +248,8 @@ void QuantizedConcatTest::TestSecondDim8Bit(float first_min, float first_max, // If <same_limits> is true, then both concatenated dimensions have the same // quantized range; otherwise, they are set to different values. template <typename T> -static void ConcatHelper(int iters, int concat_dimension, bool same_limits, - int dim2) { - testing::StopTiming(); +static void ConcatHelper(::testing::benchmark::State& state, + int concat_dimension, bool same_limits, int dim2) { Graph* g = new Graph(OpRegistry::Global()); DataType dt = DataTypeToEnum<T>::v(); @@ -278,61 +277,111 @@ static void ConcatHelper(int iters, int concat_dimension, bool same_limits, .Attr("T", dt) .Finalize(g, &node)); - testing::BytesProcessed(static_cast<int64>(iters) * + test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state); + state.SetBytesProcessed(static_cast<int64>(state.iterations()) * ((kDim1 * dim2) + (kDim1 * dim2)) * sizeof(T)); - testing::StartTiming(); - test::Benchmark("cpu", g).Run(iters); - testing::UseRealTime(); } -static void BM_QConcatDim0SameLimitQInt32(int iters, int dim2) { - ConcatHelper<qint32>(iters, 0 /* concat_dimension */, true /* same_limits */, +static void BM_QConcatDim0SameLimitQInt32(::testing::benchmark::State& state) { + const int dim2 = state.range(0); + + ConcatHelper<qint32>(state, 0 /* concat_dimension */, true /* same_limits */, dim2); } -static void BM_QConcatDim1SameLimitQInt32(int iters, int dim2) { - ConcatHelper<qint32>(iters, 1 /* concat_dimension */, true /* same_limits */, +static void BM_QConcatDim1SameLimitQInt32(::testing::benchmark::State& state) { + const int dim2 = state.range(0); + + ConcatHelper<qint32>(state, 1 /* concat_dimension */, true /* same_limits */, dim2); } -static void BM_QConcatDim0DifferLimitQInt32(int iters, int dim2) { - ConcatHelper<qint32>(iters, 0 /* concat_dimension */, false /* same_limits */, +static void BM_QConcatDim0DifferLimitQInt32( + ::testing::benchmark::State& state) { + const int dim2 = state.range(0); + + ConcatHelper<qint32>(state, 0 /* concat_dimension */, false /* same_limits */, dim2); } -static void BM_QConcatDim1DifferLimitQInt32(int iters, int dim2) { - ConcatHelper<qint32>(iters, 1 /* concat_dimension */, false /* same_limits */, +static void BM_QConcatDim1DifferLimitQInt32( + ::testing::benchmark::State& state) { + const int dim2 = state.range(0); + + ConcatHelper<qint32>(state, 1 /* concat_dimension */, false /* same_limits */, dim2); } -BENCHMARK(BM_QConcatDim0SameLimitQInt32)->Arg(1000)->Arg(20000)->Arg(100000); -BENCHMARK(BM_QConcatDim1SameLimitQInt32)->Arg(1000)->Arg(20000)->Arg(100000); -BENCHMARK(BM_QConcatDim0DifferLimitQInt32)->Arg(1000)->Arg(20000)->Arg(100000); -BENCHMARK(BM_QConcatDim1DifferLimitQInt32)->Arg(1000)->Arg(20000)->Arg(100000); +BENCHMARK(BM_QConcatDim0SameLimitQInt32) + ->UseRealTime() + ->Arg(1000) + ->Arg(20000) + ->Arg(100000); +BENCHMARK(BM_QConcatDim1SameLimitQInt32) + ->UseRealTime() + ->Arg(1000) + ->Arg(20000) + ->Arg(100000); +BENCHMARK(BM_QConcatDim0DifferLimitQInt32) + ->UseRealTime() + ->Arg(1000) + ->Arg(20000) + ->Arg(100000); +BENCHMARK(BM_QConcatDim1DifferLimitQInt32) + ->UseRealTime() + ->Arg(1000) + ->Arg(20000) + ->Arg(100000); -static void BM_QConcatDim0SameLimitQUint8(int iters, int dim2) { - ConcatHelper<qint32>(iters, 0 /* concat_dimension */, true /* same_limits */, +static void BM_QConcatDim0SameLimitQUint8(::testing::benchmark::State& state) { + const int dim2 = state.range(0); + + ConcatHelper<qint32>(state, 0 /* concat_dimension */, true /* same_limits */, dim2); } -static void BM_QConcatDim1SameLimitQUint8(int iters, int dim2) { - ConcatHelper<qint32>(iters, 1 /* concat_dimension */, true /* same_limits */, +static void BM_QConcatDim1SameLimitQUint8(::testing::benchmark::State& state) { + const int dim2 = state.range(0); + + ConcatHelper<qint32>(state, 1 /* concat_dimension */, true /* same_limits */, dim2); } -static void BM_QConcatDim0DifferLimitQUint8(int iters, int dim2) { - ConcatHelper<qint32>(iters, 0 /* concat_dimension */, false /* same_limits */, +static void BM_QConcatDim0DifferLimitQUint8( + ::testing::benchmark::State& state) { + const int dim2 = state.range(0); + + ConcatHelper<qint32>(state, 0 /* concat_dimension */, false /* same_limits */, dim2); } -static void BM_QConcatDim1DifferLimitQUint8(int iters, int dim2) { - ConcatHelper<qint32>(iters, 1 /* concat_dimension */, false /* same_limits */, +static void BM_QConcatDim1DifferLimitQUint8( + ::testing::benchmark::State& state) { + const int dim2 = state.range(0); + + ConcatHelper<qint32>(state, 1 /* concat_dimension */, false /* same_limits */, dim2); } -BENCHMARK(BM_QConcatDim0SameLimitQUint8)->Arg(1000)->Arg(20000)->Arg(100000); -BENCHMARK(BM_QConcatDim1SameLimitQUint8)->Arg(1000)->Arg(20000)->Arg(100000); -BENCHMARK(BM_QConcatDim0DifferLimitQUint8)->Arg(1000)->Arg(20000)->Arg(100000); -BENCHMARK(BM_QConcatDim1DifferLimitQUint8)->Arg(1000)->Arg(20000)->Arg(100000); +BENCHMARK(BM_QConcatDim0SameLimitQUint8) + ->UseRealTime() + ->Arg(1000) + ->Arg(20000) + ->Arg(100000); +BENCHMARK(BM_QConcatDim1SameLimitQUint8) + ->UseRealTime() + ->Arg(1000) + ->Arg(20000) + ->Arg(100000); +BENCHMARK(BM_QConcatDim0DifferLimitQUint8) + ->UseRealTime() + ->Arg(1000) + ->Arg(20000) + ->Arg(100000); +BENCHMARK(BM_QConcatDim1DifferLimitQUint8) + ->UseRealTime() + ->Arg(1000) + ->Arg(20000) + ->Arg(100000); } // namespace tensorflow diff --git a/tensorflow/core/kernels/random_binomial_op_test.cc b/tensorflow/core/kernels/random_binomial_op_test.cc index 9f8f47ef853..d3d090a47f3 100644 --- a/tensorflow/core/kernels/random_binomial_op_test.cc +++ b/tensorflow/core/kernels/random_binomial_op_test.cc @@ -67,32 +67,44 @@ static Graph* RandomBinomialRejComplement(int num_batches, return RandomBinomialGraph(100., 0.2, num_batches, samples_per_batch); } -#define BM_RandomBinomialInv(DEVICE, B, S) \ - static void BM_RandomBinomialInv_##DEVICE##_##B##_##S(int iters) { \ - test::Benchmark(#DEVICE, RandomBinomialInv(B, S)).Run(iters); \ - testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \ - } \ +#define BM_RandomBinomialInv(DEVICE, B, S) \ + static void BM_RandomBinomialInv_##DEVICE##_##B##_##S( \ + ::testing::benchmark::State& state) { \ + test::Benchmark(#DEVICE, RandomBinomialInv(B, S), \ + /*old_benchmark_api*/ false) \ + .Run(state); \ + state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \ + } \ BENCHMARK(BM_RandomBinomialInv_##DEVICE##_##B##_##S); -#define BM_RandomBinomialRej(DEVICE, B, S) \ - static void BM_RandomBinomialRej_##DEVICE##_##B##_##S(int iters) { \ - test::Benchmark(#DEVICE, RandomBinomialRej(B, S)).Run(iters); \ - testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \ - } \ +#define BM_RandomBinomialRej(DEVICE, B, S) \ + static void BM_RandomBinomialRej_##DEVICE##_##B##_##S( \ + ::testing::benchmark::State& state) { \ + test::Benchmark(#DEVICE, RandomBinomialRej(B, S), \ + /*old_benchmark_api*/ false) \ + .Run(state); \ + state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \ + } \ BENCHMARK(BM_RandomBinomialRej_##DEVICE##_##B##_##S); -#define BM_RandomBinomialInvComplement(DEVICE, B, S) \ - static void BM_RandomBinomialInvComplement_##DEVICE##_##B##_##S(int iters) { \ - test::Benchmark(#DEVICE, RandomBinomialInvComplement(B, S)).Run(iters); \ - testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \ - } \ +#define BM_RandomBinomialInvComplement(DEVICE, B, S) \ + static void BM_RandomBinomialInvComplement_##DEVICE##_##B##_##S( \ + ::testing::benchmark::State& state) { \ + test::Benchmark(#DEVICE, RandomBinomialInvComplement(B, S), \ + /*old_benchmark_api*/ false) \ + .Run(state); \ + state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \ + } \ BENCHMARK(BM_RandomBinomialInvComplement_##DEVICE##_##B##_##S); -#define BM_RandomBinomialRejComplement(DEVICE, B, S) \ - static void BM_RandomBinomialRejComplement_##DEVICE##_##B##_##S(int iters) { \ - test::Benchmark(#DEVICE, RandomBinomialRejComplement(B, S)).Run(iters); \ - testing::ItemsProcessed(static_cast<int64>(B) * S * iters); \ - } \ +#define BM_RandomBinomialRejComplement(DEVICE, B, S) \ + static void BM_RandomBinomialRejComplement_##DEVICE##_##B##_##S( \ + ::testing::benchmark::State& state) { \ + test::Benchmark(#DEVICE, RandomBinomialRejComplement(B, S), \ + /*old_benchmark_api*/ false) \ + .Run(state); \ + state.SetItemsProcessed(static_cast<int64>(B) * S * state.iterations()); \ + } \ BENCHMARK(BM_RandomBinomialRejComplement_##DEVICE##_##B##_##S); BM_RandomBinomialInv(cpu, 1000, 1000); diff --git a/tensorflow/core/kernels/random_op_test.cc b/tensorflow/core/kernels/random_op_test.cc index 47d94ad9028..e32ec11c9b3 100644 --- a/tensorflow/core/kernels/random_op_test.cc +++ b/tensorflow/core/kernels/random_op_test.cc @@ -58,11 +58,14 @@ Graph* TruncatedNormal(int64 n) { return g; } -#define BM_RNG(DEVICE, RNG) \ - void BM_##DEVICE##_##RNG(int iters, int arg) { \ - testing::ItemsProcessed(static_cast<int64>(iters) * arg); \ - test::Benchmark(#DEVICE, RNG(arg)).Run(iters); \ - } \ +#define BM_RNG(DEVICE, RNG) \ + void BM_##DEVICE##_##RNG(::testing::benchmark::State& state) { \ + const int arg = state.range(0); \ + \ + test::Benchmark(#DEVICE, RNG(arg), /*old_benchmark_api*/ false) \ + .Run(state); \ + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * arg); \ + } \ BENCHMARK(BM_##DEVICE##_##RNG)->Range(1 << 20, 8 << 20); BM_RNG(cpu, RandomUniform); @@ -84,60 +87,48 @@ Tensor VecAlphas(int64 n) { return alphas; } -void BM_cpu_RandomGamma(int iters, int nsamp, int nalpha) { - testing::ItemsProcessed(static_cast<int64>(iters) * nsamp * nalpha); +void BM_cpu_RandomGamma(::testing::benchmark::State& state) { + const int nsamp = state.range(0); + const int nalpha = state.range(1); + Graph* g = new Graph(OpRegistry::Global()); test::graph::RandomGamma(g, test::graph::Constant(g, VecShape(nsamp)), test::graph::Constant(g, VecAlphas(nalpha))); - test::Benchmark("cpu", g).Run(iters); + test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state); + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * nsamp * + nalpha); } BENCHMARK(BM_cpu_RandomGamma)->RangePair(1 << 14, 4 << 15, 2, 50); -void BM_PhiloxRandom(int iters) { +void BM_PhiloxRandom(::testing::benchmark::State& state) { // Fill 2M random numbers int count = 2 << 20; - - testing::ItemsProcessed(static_cast<int64>(iters) * count); - random::PhiloxRandom gen(0x12345); - int val = 1; - for (int i = 0; i < iters; ++i) { + for (auto s : state) { for (int j = 0; j < count; j += 4) { /// each invocation of gen() returns 128-bit samples auto samples = gen(); - - // use the result trivially so the compiler does not optimize it away - val ^= samples[0] ^ samples[1] ^ samples[2] ^ samples[3]; + tensorflow::testing::DoNotOptimize(samples); } } - - // A anchor point to make sure the compiler does not cut corners - CHECK(val) << val; + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * count); } BENCHMARK(BM_PhiloxRandom); -void BM_StdMTRandom(int iters) { +void BM_StdMTRandom(::testing::benchmark::State& state) { // Fill 2M random numbers int count = 2 << 20; - - testing::ItemsProcessed(static_cast<int64>(iters) * count); - std::mt19937 gen(0x12345); - uint_fast32_t val = 1; - for (int i = 0; i < iters; ++i) { + for (auto s : state) { for (int j = 0; j < count; ++j) { /// each invocation of gen() returns 32-bit sample uint_fast32_t sample = gen(); - - // use the result trivially so the compiler does not optimize it away - val ^= sample; + tensorflow::testing::DoNotOptimize(sample); } } - - // A anchor point to make sure the compiler does not cut corners - CHECK(val) << val; + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * count); } BENCHMARK(BM_StdMTRandom); diff --git a/tensorflow/core/kernels/reduction_ops_test.cc b/tensorflow/core/kernels/reduction_ops_test.cc index 359d7dbeca5..90666a77de6 100644 --- a/tensorflow/core/kernels/reduction_ops_test.cc +++ b/tensorflow/core/kernels/reduction_ops_test.cc @@ -84,108 +84,167 @@ static Graph* ThreeDXZReduce(const string& reduce, int num_y, int num_z) { // Creates a bench which reduces a 3D tensor with total "num" floats // into a scalar on a "device". Runs the bench for "iters" times. template <typename T> -static void ReduceToScalar(int iters, const string& device, - const string& reduce, int num_x, int num_y) { - testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y); - testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y * - sizeof(T)); - test::Benchmark(device, ToScalar<T>(reduce, num_x, num_y)).Run(iters); +static void ReduceToScalar(::testing::benchmark::State& state, + const string& device, const string& reduce, + int num_x, int num_y) { + test::Benchmark(device, ToScalar<T>(reduce, num_x, num_y), + /*old_benchmark_api*/ false) + .Run(state); + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x * + num_y); + state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x * + num_y * sizeof(T)); } -static void DoRowReduce(int iters, const string& device, const string& reduce, - int num_x, int num_y) { - testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y); - testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y * - sizeof(float)); - test::Benchmark(device, RowReduce(reduce, num_x, num_y)).Run(iters); +static void DoRowReduce(::testing::benchmark::State& state, + const string& device, const string& reduce, int num_x, + int num_y) { + test::Benchmark(device, RowReduce(reduce, num_x, num_y), + /*old_benchmark_api*/ false) + .Run(state); + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x * + num_y); + state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x * + num_y * sizeof(float)); } -static void DoColReduce(int iters, const string& device, const string& reduce, - int num_x, int num_y) { - testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y); - testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y * - sizeof(float)); - test::Benchmark(device, ColReduce(reduce, num_x, num_y)).Run(iters); +static void DoColReduce(::testing::benchmark::State& state, + const string& device, const string& reduce, int num_x, + int num_y) { + test::Benchmark(device, ColReduce(reduce, num_x, num_y), + /*old_benchmark_api*/ false) + .Run(state); + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x * + num_y); + state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x * + num_y * sizeof(float)); } -static void Do3DYReduce(int iters, const string& device, const string& reduce, - int num_x, int num_y) { - testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y); - testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y * - sizeof(float)); - test::Benchmark(device, ThreeDYReduce(reduce, num_x, num_y)).Run(iters); +static void Do3DYReduce(::testing::benchmark::State& state, + const string& device, const string& reduce, int num_x, + int num_y) { + test::Benchmark(device, ThreeDYReduce(reduce, num_x, num_y), + /*old_benchmark_api*/ false) + .Run(state); + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x * + num_y); + state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x * + num_y * sizeof(float)); } -static void Do3DXZReduce(int iters, const string& device, const string& reduce, - int num_x, int num_y) { - testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y); - testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y * - sizeof(float)); - test::Benchmark(device, ThreeDXZReduce(reduce, num_x, num_y)).Run(iters); +static void Do3DXZReduce(::testing::benchmark::State& state, + const string& device, const string& reduce, int num_x, + int num_y) { + test::Benchmark(device, ThreeDXZReduce(reduce, num_x, num_y), + /*old_benchmark_api*/ false) + .Run(state); + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x * + num_y); + state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x * + num_y * sizeof(float)); } -static void BM_Sum2DToScalarGPU(int iters, int num_x, int num_y) { - ReduceToScalar<float>(iters, "gpu", "Sum", num_x, num_y); +static void BM_Sum2DToScalarGPU(::testing::benchmark::State& state) { + const int num_x = state.range(0); + const int num_y = state.range(1); + + ReduceToScalar<float>(state, "gpu", "Sum", num_x, num_y); } BENCHMARK(BM_Sum2DToScalarGPU)->RangePair(1, 8192, 1, 8192); -static void BM_Sum2DToScalarGPUComplex(int iters, int num_x, int num_y) { - ReduceToScalar<std::complex<float>>(iters, "gpu", "Sum", num_x, num_y); +static void BM_Sum2DToScalarGPUComplex(::testing::benchmark::State& state) { + const int num_x = state.range(0); + const int num_y = state.range(1); + + ReduceToScalar<std::complex<float>>(state, "gpu", "Sum", num_x, num_y); } BENCHMARK(BM_Sum2DToScalarGPUComplex)->RangePair(1, 8192, 1, 8192); -static void BM_Sum2DToScalarGPUHalf(int iters, int num_x, int num_y) { - ReduceToScalar<Eigen::half>(iters, "gpu", "Sum", num_x, num_y); +static void BM_Sum2DToScalarGPUHalf(::testing::benchmark::State& state) { + const int num_x = state.range(0); + const int num_y = state.range(1); + + ReduceToScalar<Eigen::half>(state, "gpu", "Sum", num_x, num_y); } BENCHMARK(BM_Sum2DToScalarGPUHalf)->RangePair(1, 8192, 1, 8192); -static void BM_Sum2DRowReduceGPU(int iters, int num_x, int num_y) { - DoRowReduce(iters, "gpu", "Sum", num_x, num_y); +static void BM_Sum2DRowReduceGPU(::testing::benchmark::State& state) { + const int num_x = state.range(0); + const int num_y = state.range(1); + + DoRowReduce(state, "gpu", "Sum", num_x, num_y); } BENCHMARK(BM_Sum2DRowReduceGPU)->RangePair(1, 8192, 1, 8192); -static void BM_Sum2DColumnReduceGPU(int iters, int num_x, int num_y) { - DoColReduce(iters, "gpu", "Sum", num_x, num_y); +static void BM_Sum2DColumnReduceGPU(::testing::benchmark::State& state) { + const int num_x = state.range(0); + const int num_y = state.range(1); + + DoColReduce(state, "gpu", "Sum", num_x, num_y); } BENCHMARK(BM_Sum2DColumnReduceGPU)->RangePair(1, 8192, 1, 8192); -static void BM_Sum3DYReduceGPU(int iters, int num_x, int num_y) { - Do3DYReduce(iters, "gpu", "Sum", num_x, num_y); +static void BM_Sum3DYReduceGPU(::testing::benchmark::State& state) { + const int num_x = state.range(0); + const int num_y = state.range(1); + + Do3DYReduce(state, "gpu", "Sum", num_x, num_y); } BENCHMARK(BM_Sum3DYReduceGPU)->RangePair(64, 4096, 64, 4096); -static void BM_Sum3DXZReduceGPU(int iters, int num_x, int num_y) { - Do3DXZReduce(iters, "gpu", "Sum", num_x, num_y); +static void BM_Sum3DXZReduceGPU(::testing::benchmark::State& state) { + const int num_x = state.range(0); + const int num_y = state.range(1); + + Do3DXZReduce(state, "gpu", "Sum", num_x, num_y); } BENCHMARK(BM_Sum3DXZReduceGPU)->RangePair(64, 4096, 64, 4096); -static void BM_Mean2DToScalarGPU(int iters, int num_x, int num_y) { - ReduceToScalar<float>(iters, "gpu", "Mean", num_x, num_y); +static void BM_Mean2DToScalarGPU(::testing::benchmark::State& state) { + const int num_x = state.range(0); + const int num_y = state.range(1); + + ReduceToScalar<float>(state, "gpu", "Mean", num_x, num_y); } BENCHMARK(BM_Mean2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192); -static void BM_EuclideanNorm2DToScalarGPU(int iters, int num_x, int num_y) { - ReduceToScalar<float>(iters, "gpu", "EuclideanNorm", num_x, num_y); +static void BM_EuclideanNorm2DToScalarGPU(::testing::benchmark::State& state) { + const int num_x = state.range(0); + const int num_y = state.range(1); + + ReduceToScalar<float>(state, "gpu", "EuclideanNorm", num_x, num_y); } BENCHMARK(BM_EuclideanNorm2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192); -static void BM_Max2DToScalarGPU(int iters, int num_x, int num_y) { - ReduceToScalar<float>(iters, "gpu", "Max", num_x, num_y); +static void BM_Max2DToScalarGPU(::testing::benchmark::State& state) { + const int num_x = state.range(0); + const int num_y = state.range(1); + + ReduceToScalar<float>(state, "gpu", "Max", num_x, num_y); } BENCHMARK(BM_Max2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192); -static void BM_Min2DToScalarGPU(int iters, int num_x, int num_y) { - ReduceToScalar<float>(iters, "gpu", "Min", num_x, num_y); +static void BM_Min2DToScalarGPU(::testing::benchmark::State& state) { + const int num_x = state.range(0); + const int num_y = state.range(1); + + ReduceToScalar<float>(state, "gpu", "Min", num_x, num_y); } BENCHMARK(BM_Min2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192); -static void BM_Min2DToScalarGPUHalf(int iters, int num_x, int num_y) { - ReduceToScalar<Eigen::half>(iters, "gpu", "Min", num_x, num_y); +static void BM_Min2DToScalarGPUHalf(::testing::benchmark::State& state) { + const int num_x = state.range(0); + const int num_y = state.range(1); + + ReduceToScalar<Eigen::half>(state, "gpu", "Min", num_x, num_y); } BENCHMARK(BM_Min2DToScalarGPUHalf)->RangePair(2048, 8192, 2048, 8192); -static void BM_Bool2DToScalarGPU(int iters, int num_x, int num_y) { - ReduceToScalar<bool>(iters, "gpu", "All", num_x, num_y); +static void BM_Bool2DToScalarGPU(::testing::benchmark::State& state) { + const int num_x = state.range(0); + const int num_y = state.range(1); + + ReduceToScalar<bool>(state, "gpu", "All", num_x, num_y); } BENCHMARK(BM_Bool2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192); diff --git a/tensorflow/core/kernels/regex_replace_op_test.cc b/tensorflow/core/kernels/regex_replace_op_test.cc index b9e960efecc..7c537b6dbde 100644 --- a/tensorflow/core/kernels/regex_replace_op_test.cc +++ b/tensorflow/core/kernels/regex_replace_op_test.cc @@ -84,17 +84,17 @@ Graph* SetupRegexReplaceGraph(const Tensor& input, const string& input_pattern, return g; } -void BM_RegexReplace(int iters, int batch_size) { - testing::StopTiming(); - testing::ItemsProcessed(static_cast<int64>(iters)); - testing::UseRealTime(); +static void BM_RegexReplace(::testing::benchmark::State& state) { + const int batch_size = state.range(0); + Tensor input = GetTestTensor(batch_size); Graph* g = SetupRegexReplaceGraph(input, kRegExPattern, kRewrite); - testing::StartTiming(); - test::Benchmark("cpu", g).Run(iters); + test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state); + state.SetItemsProcessed(static_cast<int64>(state.iterations())); } BENCHMARK(BM_RegexReplace) + ->UseRealTime() ->Arg(1) ->Arg(8) ->Arg(16) @@ -115,17 +115,17 @@ Graph* SetupStaticGraph(const Tensor& input, const string& input_pattern, .Finalize(g, nullptr /* node */)); return g; } -void BM_StaticRegexReplace(int iters, int batch_size) { - testing::StopTiming(); - testing::ItemsProcessed(static_cast<int64>(iters)); - testing::UseRealTime(); +static void BM_StaticRegexReplace(::testing::benchmark::State& state) { + const int batch_size = state.range(0); + Tensor input = GetTestTensor(batch_size); Graph* g = SetupStaticGraph(input, kRegExPattern, kRewrite); - testing::StartTiming(); - test::Benchmark("cpu", g).Run(iters); + test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state); + state.SetItemsProcessed(static_cast<int64>(state.iterations())); } BENCHMARK(BM_StaticRegexReplace) + ->UseRealTime() ->Arg(1) ->Arg(8) ->Arg(16) diff --git a/tensorflow/core/kernels/requantization_range_op_test.cc b/tensorflow/core/kernels/requantization_range_op_test.cc index dd04da373d8..a9740dd31d7 100644 --- a/tensorflow/core/kernels/requantization_range_op_test.cc +++ b/tensorflow/core/kernels/requantization_range_op_test.cc @@ -67,56 +67,29 @@ TEST_F(RequantizationRangeTest, HandCrafted) { test::ExpectTensorEqual<float>(expected_max, *GetOutput(1)); } -static void BM_RequantizationRange(int iters, int size) { - testing::StopTiming(); - testing::UseRealTime(); - testing::ItemsProcessed(static_cast<int64>(iters) * size); - testing::ItemsProcessed(static_cast<int64>(iters) * size * 4); +static void BM_RequantizationRange(::testing::benchmark::State& state) { + const int size = state.range(0); Tensor quantized_tensor(DT_QINT32, TensorShape({1, size})); test::FillFn<qint32>(&quantized_tensor, [](int n) { return qint32(n); }); qint32 actual_min; qint32 actual_max; - testing::StartTiming(); - for (int iter = 0; iter < iters; ++iter) { + for (auto s : state) { CalculateUsedRange(quantized_tensor, &actual_min, &actual_max); } + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * size); + state.SetBytesProcessed(static_cast<int64>(state.iterations()) * size * 4); } -static void BM_RequantizationRange100(int iters) { - BM_RequantizationRange(100, iters); -} -BENCHMARK(BM_RequantizationRange100); - -static void BM_RequantizationRange1000(int iters) { - BM_RequantizationRange(1000, iters); -} -BENCHMARK(BM_RequantizationRange1000); - -static void BM_RequantizationRange10000(int iters) { - BM_RequantizationRange(10000, iters); -} -BENCHMARK(BM_RequantizationRange10000); - -static void BM_RequantizationRange100000(int iters) { - BM_RequantizationRange(100000, iters); -} -BENCHMARK(BM_RequantizationRange100000); - -static void BM_RequantizationRange1000000(int iters) { - BM_RequantizationRange(1000000, iters); -} -BENCHMARK(BM_RequantizationRange1000000); - -static void BM_RequantizationRange10000000(int iters) { - BM_RequantizationRange(10000000, iters); -} -BENCHMARK(BM_RequantizationRange10000000); - -static void BM_RequantizationRange100000000(int iters) { - BM_RequantizationRange(100000000, iters); -} -BENCHMARK(BM_RequantizationRange100000000); +BENCHMARK(BM_RequantizationRange) + ->UseRealTime() + ->Arg(100) + ->Arg(1000) + ->Arg(10000) + ->Arg(100000) + ->Arg(1000000) + ->Arg(10000000) + ->Arg(100000000); } // end namespace tensorflow diff --git a/tensorflow/core/kernels/reverse_op_test.cc b/tensorflow/core/kernels/reverse_op_test.cc index 62d7d294597..d34e97ea2c2 100644 --- a/tensorflow/core/kernels/reverse_op_test.cc +++ b/tensorflow/core/kernels/reverse_op_test.cc @@ -197,148 +197,187 @@ static Graph* Reverse(const TensorShape& shape, int reverse_axis) { } template <typename T> -static void RunReverseRowsBenchmark(int iters, int outer_dim, int middle_dim, +static void RunReverseRowsBenchmark(::testing::benchmark::State& state, + int outer_dim, int middle_dim, int intra_threads, int channels) { SessionOptions opts = GetOptions(intra_threads); TensorShape shape{outer_dim, middle_dim, channels}; - const int64 num_items = static_cast<int64>(iters) * shape.num_elements(); - testing::ItemsProcessed(num_items); - testing::BytesProcessed(num_items * sizeof(T)); - testing::UseRealTime(); - test::Benchmark("cpu", Reverse<T>(shape, 1), &opts).Run(iters); + test::Benchmark("cpu", Reverse<T>(shape, 1), &opts, nullptr, nullptr, "", + /*old_benchmark_api*/ false) + .Run(state); + const int64 num_items = + static_cast<int64>(state.iterations()) * shape.num_elements(); + state.SetItemsProcessed(num_items); + state.SetBytesProcessed(num_items * sizeof(T)); } -static void BM_ReverseRowsOf1Channel_1T_float(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim, +void BM_ReverseRowsOf1Channel_1T_float(::testing::benchmark::State& state) { + const int outer_dim = state.range(0); + const int middle_dim = state.range(1); + + RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim, 1 /* intra_threads */, 1 /* channels */); } BENCHMARK(BM_ReverseRowsOf1Channel_1T_float) + ->UseRealTime() ->ArgPair(288, 288) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); -static void BM_ReverseRowsOf1Channel_1T_uint8(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim, +void BM_ReverseRowsOf1Channel_1T_uint8(::testing::benchmark::State& state) { + const int outer_dim = state.range(0); + const int middle_dim = state.range(1); + + RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim, 1 /* intra_threads */, 1 /* channels */); } BENCHMARK(BM_ReverseRowsOf1Channel_1T_uint8) + ->UseRealTime() ->ArgPair(288, 288) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); -static void BM_ReverseRowsOf1Channel_4T_float(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim, +void BM_ReverseRowsOf1Channel_4T_float(::testing::benchmark::State& state) { + const int outer_dim = state.range(0); + const int middle_dim = state.range(1); + + RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim, 4 /* intra_threads */, 1 /* channels */); } BENCHMARK(BM_ReverseRowsOf1Channel_4T_float) + ->UseRealTime() ->ArgPair(288, 288) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); -static void BM_ReverseRowsOf1Channel_4T_uint8(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim, +void BM_ReverseRowsOf1Channel_4T_uint8(::testing::benchmark::State& state) { + const int outer_dim = state.range(0); + const int middle_dim = state.range(1); + + RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim, 4 /* intra_threads */, 1 /* channels */); } BENCHMARK(BM_ReverseRowsOf1Channel_4T_uint8) + ->UseRealTime() ->ArgPair(288, 288) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); -static void BM_ReverseRowsOf3Channels_1T_float(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim, +void BM_ReverseRowsOf3Channels_1T_float(::testing::benchmark::State& state) { + const int outer_dim = state.range(0); + const int middle_dim = state.range(1); + + RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim, 1 /* intra_threads */, 3 /* channels */); } BENCHMARK(BM_ReverseRowsOf3Channels_1T_float) + ->UseRealTime() ->ArgPair(288, 288) ->ArgPair(30, 30) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); -static void BM_ReverseRowsOf3Channels_1T_uint8(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim, +void BM_ReverseRowsOf3Channels_1T_uint8(::testing::benchmark::State& state) { + const int outer_dim = state.range(0); + const int middle_dim = state.range(1); + + RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim, 1 /* intra_threads */, 3 /* channels */); } BENCHMARK(BM_ReverseRowsOf3Channels_1T_uint8) + ->UseRealTime() ->ArgPair(288, 288) ->ArgPair(30, 30) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); -static void BM_ReverseRowsOf3Channels_4T_float(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim, +void BM_ReverseRowsOf3Channels_4T_float(::testing::benchmark::State& state) { + const int outer_dim = state.range(0); + const int middle_dim = state.range(1); + + RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim, 4 /* intra_threads */, 3 /* channels */); } BENCHMARK(BM_ReverseRowsOf3Channels_4T_float) + ->UseRealTime() ->ArgPair(288, 288) ->ArgPair(30, 30) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); -static void BM_ReverseRowsOf3Channels_4T_uint8(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim, +void BM_ReverseRowsOf3Channels_4T_uint8(::testing::benchmark::State& state) { + const int outer_dim = state.range(0); + const int middle_dim = state.range(1); + + RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim, 4 /* intra_threads */, 3 /* channels */); } BENCHMARK(BM_ReverseRowsOf3Channels_4T_uint8) + ->UseRealTime() ->ArgPair(288, 288) ->ArgPair(30, 30) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); -static void BM_ReverseRowsOf4Channels_1T_float(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim, +void BM_ReverseRowsOf4Channels_1T_float(::testing::benchmark::State& state) { + const int outer_dim = state.range(0); + const int middle_dim = state.range(1); + + RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim, 1 /* intra_threads */, 4 /* channels */); } BENCHMARK(BM_ReverseRowsOf4Channels_1T_float) + ->UseRealTime() ->ArgPair(288, 288) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); -static void BM_ReverseRowsOf4Channels_1T_uint8(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim, +void BM_ReverseRowsOf4Channels_1T_uint8(::testing::benchmark::State& state) { + const int outer_dim = state.range(0); + const int middle_dim = state.range(1); + + RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim, 1 /* intra_threads */, 4 /* channels */); } BENCHMARK(BM_ReverseRowsOf4Channels_1T_uint8) + ->UseRealTime() ->ArgPair(288, 288) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); -static void BM_ReverseRowsOf4Channels_4T_float(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim, +void BM_ReverseRowsOf4Channels_4T_float(::testing::benchmark::State& state) { + const int outer_dim = state.range(0); + const int middle_dim = state.range(1); + + RunReverseRowsBenchmark<float>(state, outer_dim, middle_dim, 4 /* intra_threads */, 4 /* channels */); } BENCHMARK(BM_ReverseRowsOf4Channels_4T_float) + ->UseRealTime() ->ArgPair(288, 288) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); -static void BM_ReverseRowsOf4Channels_4T_uint8(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim, +void BM_ReverseRowsOf4Channels_4T_uint8(::testing::benchmark::State& state) { + const int outer_dim = state.range(0); + const int middle_dim = state.range(1); + + RunReverseRowsBenchmark<uint8>(state, outer_dim, middle_dim, 4 /* intra_threads */, 4 /* channels */); } BENCHMARK(BM_ReverseRowsOf4Channels_4T_uint8) + ->UseRealTime() ->ArgPair(288, 288) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); diff --git a/tensorflow/core/kernels/roll_op_test.cc b/tensorflow/core/kernels/roll_op_test.cc index 3ee66906139..6e0b638c79d 100644 --- a/tensorflow/core/kernels/roll_op_test.cc +++ b/tensorflow/core/kernels/roll_op_test.cc @@ -450,34 +450,44 @@ static Graph* RollGraph(const TensorShape& shape, int isd) { return g; } -#define BM_ROLL_OUTER(DEVICE) \ - static void BM_##DEVICE##_roll_outer(int iters, int rows, int columns) { \ - TensorShape shape{rows, columns}; \ - const int64 num_items = static_cast<int64>(iters) * shape.num_elements(); \ - testing::ItemsProcessed(num_items); \ - testing::BytesProcessed(num_items * sizeof(float)); \ - testing::UseRealTime(); \ - test::Benchmark(#DEVICE, RollGraph(shape, 0)).Run(iters); \ - } \ - BENCHMARK(BM_##DEVICE##_roll_outer) \ - ->ArgPair(256, 256) \ - ->ArgPair(512, 512) \ - ->ArgPair(1024, 1024) \ +#define BM_ROLL_OUTER(DEVICE) \ + static void BM_##DEVICE##_roll_outer(::testing::benchmark::State& state) { \ + const int rows = state.range(0); \ + const int columns = state.range(1); \ + \ + TensorShape shape{rows, columns}; \ + test::Benchmark(#DEVICE, RollGraph(shape, 0), /*old_benchmark_api*/ false) \ + .Run(state); \ + const int64 num_items = \ + static_cast<int64>(state.iterations()) * shape.num_elements(); \ + state.SetItemsProcessed(num_items); \ + state.SetBytesProcessed(num_items * sizeof(float)); \ + } \ + BENCHMARK(BM_##DEVICE##_roll_outer) \ + ->UseRealTime() \ + ->ArgPair(256, 256) \ + ->ArgPair(512, 512) \ + ->ArgPair(1024, 1024) \ ->ArgPair(2048, 2048) -#define BM_ROLL_ALL(DEVICE) \ - static void BM_##DEVICE##_roll_all(int iters, int rows, int columns) { \ - TensorShape shape{rows, columns}; \ - const int64 num_items = static_cast<int64>(iters) * shape.num_elements(); \ - testing::ItemsProcessed(num_items); \ - testing::BytesProcessed(num_items * sizeof(float)); \ - testing::UseRealTime(); \ - test::Benchmark(#DEVICE, RollGraph(shape, 1)).Run(iters); \ - } \ - BENCHMARK(BM_##DEVICE##_roll_all) \ - ->ArgPair(256, 256) \ - ->ArgPair(512, 512) \ - ->ArgPair(1024, 1024) \ +#define BM_ROLL_ALL(DEVICE) \ + static void BM_##DEVICE##_roll_all(::testing::benchmark::State& state) { \ + const int rows = state.range(0); \ + const int columns = state.range(1); \ + \ + TensorShape shape{rows, columns}; \ + test::Benchmark(#DEVICE, RollGraph(shape, 1), /*old_benchmark_api*/ false) \ + .Run(state); \ + const int64 num_items = \ + static_cast<int64>(state.iterations()) * shape.num_elements(); \ + state.SetItemsProcessed(num_items); \ + state.SetBytesProcessed(num_items * sizeof(float)); \ + } \ + BENCHMARK(BM_##DEVICE##_roll_all) \ + ->UseRealTime() \ + ->ArgPair(256, 256) \ + ->ArgPair(512, 512) \ + ->ArgPair(1024, 1024) \ ->ArgPair(2048, 2048) BM_ROLL_OUTER(cpu); diff --git a/tensorflow/core/kernels/save_op_test.cc b/tensorflow/core/kernels/save_op_test.cc index 1f6d8257bdd..b46609ef193 100644 --- a/tensorflow/core/kernels/save_op_test.cc +++ b/tensorflow/core/kernels/save_op_test.cc @@ -663,8 +663,8 @@ TEST_F(SaveOpSlices2Test, TwoSlices) { // Benchmark-related code below. -static void BM_LargeTensorWrite(int iters, int num_elements) { - testing::StopTiming(); +void BM_LargeTensorWrite(::testing::benchmark::State& state) { + const int num_elements = state.range(0); // 4 * num_elements bytes total , since sizeof(float) == 4. Tensor tensor(DT_FLOAT, TensorShape({num_elements})); @@ -689,8 +689,9 @@ static void BM_LargeTensorWrite(int iters, int num_elements) { VLOG(1) << "Save op's output path: " << temp_filename; VLOG(1) << "# nodes in Graph: " << g->num_nodes(); - testing::StartTiming(); - test::Benchmark("cpu", g, &session_options).Run(iters); + test::Benchmark("cpu", g, &session_options, nullptr, nullptr, "", + /*old_benchmark_api*/ false) + .Run(state); } BENCHMARK(BM_LargeTensorWrite)->Arg((1 << 30) / 4 /* 1GB float tensor */); diff --git a/tensorflow/core/kernels/scan_ops_test.cc b/tensorflow/core/kernels/scan_ops_test.cc index 588b606a99b..88cb351eb53 100644 --- a/tensorflow/core/kernels/scan_ops_test.cc +++ b/tensorflow/core/kernels/scan_ops_test.cc @@ -67,79 +67,120 @@ static Graph* ThreeDYCumsum(int num_y, int num_z, bool reverse = false) { } template <typename T> -static void LargeOneDimensional(int iters, const string& device, int num_x, +static void LargeOneDimensional(::testing::benchmark::State& state, + const string& device, int num_x, bool reverse = false) { - testing::ItemsProcessed(static_cast<int64>(iters) * num_x); - testing::BytesProcessed(static_cast<int64>(iters) * num_x * sizeof(T)); - test::Benchmark(device, LargeOneDCumsum<T>(num_x, reverse)).Run(iters); + test::Benchmark(device, LargeOneDCumsum<T>(num_x, reverse), + /*old_benchmark_api*/ false) + .Run(state); + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x); + state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x * + sizeof(T)); } -static void DoRowCumsum(int iters, const string& device, int num_x, int num_y, +static void DoRowCumsum(::testing::benchmark::State& state, + const string& device, int num_x, int num_y, bool reverse = false) { - testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y); - testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y * - sizeof(float)); - test::Benchmark(device, RowCumsum(num_x, num_y, reverse)).Run(iters); + test::Benchmark(device, RowCumsum(num_x, num_y, reverse), + /*old_benchmark_api*/ false) + .Run(state); + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x * + num_y); + state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x * + num_y * sizeof(float)); } -static void DoColCumsum(int iters, const string& device, int num_x, int num_y, +static void DoColCumsum(::testing::benchmark::State& state, + const string& device, int num_x, int num_y, bool reverse = false) { - testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y); - testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y * - sizeof(float)); - test::Benchmark(device, ColCumsum(num_x, num_y, reverse)).Run(iters); + test::Benchmark(device, ColCumsum(num_x, num_y, reverse), + /*old_benchmark_api*/ false) + .Run(state); + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x * + num_y); + state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x * + num_y * sizeof(float)); } -static void Do3DYCumsum(int iters, const string& device, int num_x, int num_y, +static void Do3DYCumsum(::testing::benchmark::State& state, + const string& device, int num_x, int num_y, bool reverse = false) { - testing::ItemsProcessed(static_cast<int64>(iters) * num_x * num_y); - testing::BytesProcessed(static_cast<int64>(iters) * num_x * num_y * - sizeof(float)); - test::Benchmark(device, ThreeDYCumsum(num_x, num_y, reverse)).Run(iters); + test::Benchmark(device, ThreeDYCumsum(num_x, num_y, reverse), + /*old_benchmark_api*/ false) + .Run(state); + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * num_x * + num_y); + state.SetBytesProcessed(static_cast<int64>(state.iterations()) * num_x * + num_y * sizeof(float)); } -static void BM_OneDCumsumGPU(int iters, int num_x) { - LargeOneDimensional<float>(iters, "gpu", num_x); +static void BM_OneDCumsumGPU(::testing::benchmark::State& state) { + const int num_x = state.range(0); + + LargeOneDimensional<float>(state, "gpu", num_x); } BENCHMARK(BM_OneDCumsumGPU)->Range(1, 1 << 21); -static void BM_OneDCumsumGPUHalf(int iters, int num_x) { - LargeOneDimensional<Eigen::half>(iters, "gpu", num_x); +static void BM_OneDCumsumGPUHalf(::testing::benchmark::State& state) { + const int num_x = state.range(0); + + LargeOneDimensional<Eigen::half>(state, "gpu", num_x); } BENCHMARK(BM_OneDCumsumGPUHalf)->Range(1, 1 << 21); -static void BM_Sum2DRowCumsumGPU(int iters, int num_x, int num_y) { - DoRowCumsum(iters, "gpu", num_x, num_y); +static void BM_Sum2DRowCumsumGPU(::testing::benchmark::State& state) { + const int num_x = state.range(0); + const int num_y = state.range(1); + + DoRowCumsum(state, "gpu", num_x, num_y); } BENCHMARK(BM_Sum2DRowCumsumGPU)->RangePair(1, 8192, 1, 8192); -static void BM_Sum2DColumnCumsumGPU(int iters, int num_x, int num_y) { - DoColCumsum(iters, "gpu", num_x, num_y); +static void BM_Sum2DColumnCumsumGPU(::testing::benchmark::State& state) { + const int num_x = state.range(0); + const int num_y = state.range(1); + + DoColCumsum(state, "gpu", num_x, num_y); } BENCHMARK(BM_Sum2DColumnCumsumGPU)->RangePair(1, 8192, 1, 8192); -static void BM_Sum3DYCumsumGPU(int iters, int num_x, int num_y) { - Do3DYCumsum(iters, "gpu", num_x, num_y); +static void BM_Sum3DYCumsumGPU(::testing::benchmark::State& state) { + const int num_x = state.range(0); + const int num_y = state.range(1); + + Do3DYCumsum(state, "gpu", num_x, num_y); } BENCHMARK(BM_Sum3DYCumsumGPU)->RangePair(64, 4096, 64, 4096); -static void BM_OneDCumsumGPU_reverse(int iters, int num_x) { - LargeOneDimensional<float>(iters, "gpu", num_x, true); +static void BM_OneDCumsumGPU_reverse(::testing::benchmark::State& state) { + const int num_x = state.range(0); + + LargeOneDimensional<float>(state, "gpu", num_x, true); } BENCHMARK(BM_OneDCumsumGPU_reverse)->Range(1, 1 << 21); -static void BM_Sum2DRowCumsumGPU_reverse(int iters, int num_x, int num_y) { - DoRowCumsum(iters, "gpu", num_x, num_y, true); +static void BM_Sum2DRowCumsumGPU_reverse(::testing::benchmark::State& state) { + const int num_x = state.range(0); + const int num_y = state.range(1); + + DoRowCumsum(state, "gpu", num_x, num_y, true); } BENCHMARK(BM_Sum2DRowCumsumGPU_reverse)->RangePair(1, 8192, 1, 8192); -static void BM_Sum2DColumnCumsumGPU_reverse(int iters, int num_x, int num_y) { - DoColCumsum(iters, "gpu", num_x, num_y, true); +static void BM_Sum2DColumnCumsumGPU_reverse( + ::testing::benchmark::State& state) { + const int num_x = state.range(0); + const int num_y = state.range(1); + + DoColCumsum(state, "gpu", num_x, num_y, true); } BENCHMARK(BM_Sum2DColumnCumsumGPU_reverse)->RangePair(1, 8192, 1, 8192); -static void BM_Sum3DYCumsumGPU_reverse(int iters, int num_x, int num_y) { - Do3DYCumsum(iters, "gpu", num_x, num_y, true); +static void BM_Sum3DYCumsumGPU_reverse(::testing::benchmark::State& state) { + const int num_x = state.range(0); + const int num_y = state.range(1); + + Do3DYCumsum(state, "gpu", num_x, num_y, true); } BENCHMARK(BM_Sum3DYCumsumGPU_reverse)->RangePair(32, 2048, 32, 2048); diff --git a/tensorflow/core/kernels/scatter_nd_op_test.cc b/tensorflow/core/kernels/scatter_nd_op_test.cc index 9c31bed784f..b7837e11e73 100644 --- a/tensorflow/core/kernels/scatter_nd_op_test.cc +++ b/tensorflow/core/kernels/scatter_nd_op_test.cc @@ -254,8 +254,8 @@ class ScatterNdUpdateBM : public ScatterNdUpdateOpTest { }; template <typename Index> -static void BM_ScatterNdHelper(int iters, int embedding_size, const char* op) { - testing::StopTiming(); +void BM_ScatterNdHelper(::testing::benchmark::State& state, int embedding_size, + const char* op) { const int kRows = 10000000 / embedding_size; std::vector<float> values; values.reserve(kRows); @@ -280,27 +280,33 @@ static void BM_ScatterNdHelper(int iters, int embedding_size, const char* op) { bm.AddInputFromArray<Index>(TensorShape({kNumUpdates}), indices); bm.AddInputFromArray<float>(TensorShape({kNumUpdates, embedding_size}), updates); - testing::ItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) * - iters); - testing::StartTiming(); - while (iters-- > 0) { + for (auto i : state) { Status s = bm.RunOpKernel(); } - testing::StopTiming(); + state.SetItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) * + state.iterations()); } -static void BM_ScatterNdUpdateInt32(int iters, int embedding_size) { - BM_ScatterNdHelper<int32>(iters, embedding_size, "ScatterNdUpdate"); +void BM_ScatterNdUpdateInt32(::testing::benchmark::State& state) { + const int embedding_size = state.range(0); + + BM_ScatterNdHelper<int32>(state, embedding_size, "ScatterNdUpdate"); } -static void BM_ScatterNdUpdateInt64(int iters, int embedding_size) { - BM_ScatterNdHelper<int64>(iters, embedding_size, "ScatterNdUpdate"); +void BM_ScatterNdUpdateInt64(::testing::benchmark::State& state) { + const int embedding_size = state.range(0); + + BM_ScatterNdHelper<int64>(state, embedding_size, "ScatterNdUpdate"); } -static void BM_ScatterNdAddInt32(int iters, int embedding_size) { - BM_ScatterNdHelper<int32>(iters, embedding_size, "ScatterNdAdd"); +void BM_ScatterNdAddInt32(::testing::benchmark::State& state) { + const int embedding_size = state.range(0); + + BM_ScatterNdHelper<int32>(state, embedding_size, "ScatterNdAdd"); } -static void BM_ScatterNdAddInt64(int iters, int embedding_size) { - BM_ScatterNdHelper<int64>(iters, embedding_size, "ScatterNdAdd"); +void BM_ScatterNdAddInt64(::testing::benchmark::State& state) { + const int embedding_size = state.range(0); + + BM_ScatterNdHelper<int64>(state, embedding_size, "ScatterNdAdd"); } BENCHMARK(BM_ScatterNdUpdateInt32) diff --git a/tensorflow/core/kernels/scatter_op_test.cc b/tensorflow/core/kernels/scatter_op_test.cc index e52f6e74dd5..7febb0e1cb7 100644 --- a/tensorflow/core/kernels/scatter_op_test.cc +++ b/tensorflow/core/kernels/scatter_op_test.cc @@ -280,9 +280,8 @@ class ScatterUpdateBM : public ScatterUpdateOpTest { }; template <typename Index> -static void BM_ScatterHelper(int iters, int embedding_size, const char* op, - bool big_num_updates = false) { - testing::StopTiming(); +void BM_ScatterHelper(::testing::benchmark::State& state, int embedding_size, + const char* op, bool big_num_updates = false) { const int kRows = 10000000 / embedding_size; std::vector<float> values; values.reserve(kRows); @@ -307,59 +306,83 @@ static void BM_ScatterHelper(int iters, int embedding_size, const char* op, bm.AddInputFromArray<Index>(TensorShape({kNumUpdates}), indices); bm.AddInputFromArray<float>(TensorShape({kNumUpdates, embedding_size}), updates); - testing::ItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) * - iters); - testing::StartTiming(); - while (iters-- > 0) { + for (auto i : state) { Status s = bm.RunOpKernel(); } - testing::StopTiming(); + state.SetItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) * + state.iterations()); } -static void BM_ScatterUpdateInt32(int iters, int embedding_size) { - BM_ScatterHelper<int32>(iters, embedding_size, "ScatterUpdate"); +void BM_ScatterUpdateInt32(::testing::benchmark::State& state) { + const int embedding_size = state.range(0); + + BM_ScatterHelper<int32>(state, embedding_size, "ScatterUpdate"); } -static void BM_ScatterUpdateInt64(int iters, int embedding_size) { - BM_ScatterHelper<int64>(iters, embedding_size, "ScatterUpdate"); +void BM_ScatterUpdateInt64(::testing::benchmark::State& state) { + const int embedding_size = state.range(0); + + BM_ScatterHelper<int64>(state, embedding_size, "ScatterUpdate"); } -static void BM_ScatterAddInt32(int iters, int embedding_size) { - BM_ScatterHelper<int32>(iters, embedding_size, "ScatterAdd"); +void BM_ScatterAddInt32(::testing::benchmark::State& state) { + const int embedding_size = state.range(0); + + BM_ScatterHelper<int32>(state, embedding_size, "ScatterAdd"); } -static void BM_ScatterAddInt32Large(int iters, int embedding_size) { - BM_ScatterHelper<int32>(iters, embedding_size, "ScatterAdd", true); +void BM_ScatterAddInt32Large(::testing::benchmark::State& state) { + const int embedding_size = state.range(0); + + BM_ScatterHelper<int32>(state, embedding_size, "ScatterAdd", true); } -static void BM_ScatterAddInt64(int iters, int embedding_size) { - BM_ScatterHelper<int64>(iters, embedding_size, "ScatterAdd"); +void BM_ScatterAddInt64(::testing::benchmark::State& state) { + const int embedding_size = state.range(0); + + BM_ScatterHelper<int64>(state, embedding_size, "ScatterAdd"); } -static void BM_ScatterMulInt32(int iters, int embedding_size) { - BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMul"); +void BM_ScatterMulInt32(::testing::benchmark::State& state) { + const int embedding_size = state.range(0); + + BM_ScatterHelper<int32>(state, embedding_size, "ScatterMul"); } -static void BM_ScatterMulInt64(int iters, int embedding_size) { - BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMul"); +void BM_ScatterMulInt64(::testing::benchmark::State& state) { + const int embedding_size = state.range(0); + + BM_ScatterHelper<int64>(state, embedding_size, "ScatterMul"); } -static void BM_ScatterDivInt32(int iters, int embedding_size) { - BM_ScatterHelper<int32>(iters, embedding_size, "ScatterDiv"); +void BM_ScatterDivInt32(::testing::benchmark::State& state) { + const int embedding_size = state.range(0); + + BM_ScatterHelper<int32>(state, embedding_size, "ScatterDiv"); } -static void BM_ScatterDivInt64(int iters, int embedding_size) { - BM_ScatterHelper<int64>(iters, embedding_size, "ScatterDiv"); +void BM_ScatterDivInt64(::testing::benchmark::State& state) { + const int embedding_size = state.range(0); + + BM_ScatterHelper<int64>(state, embedding_size, "ScatterDiv"); } -static void BM_ScatterMinInt32(int iters, int embedding_size) { - BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMin"); +void BM_ScatterMinInt32(::testing::benchmark::State& state) { + const int embedding_size = state.range(0); + + BM_ScatterHelper<int32>(state, embedding_size, "ScatterMin"); } -static void BM_ScatterMinInt64(int iters, int embedding_size) { - BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMin"); +void BM_ScatterMinInt64(::testing::benchmark::State& state) { + const int embedding_size = state.range(0); + + BM_ScatterHelper<int64>(state, embedding_size, "ScatterMin"); } -static void BM_ScatterMaxInt32(int iters, int embedding_size) { - BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMax"); +void BM_ScatterMaxInt32(::testing::benchmark::State& state) { + const int embedding_size = state.range(0); + + BM_ScatterHelper<int32>(state, embedding_size, "ScatterMax"); } -static void BM_ScatterMaxInt64(int iters, int embedding_size) { - BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMax"); +void BM_ScatterMaxInt64(::testing::benchmark::State& state) { + const int embedding_size = state.range(0); + + BM_ScatterHelper<int64>(state, embedding_size, "ScatterMax"); } BENCHMARK(BM_ScatterUpdateInt32) diff --git a/tensorflow/core/kernels/segment_reduction_ops_test.cc b/tensorflow/core/kernels/segment_reduction_ops_test.cc index 8d7b70878b7..ca8c3db3d42 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_test.cc +++ b/tensorflow/core/kernels/segment_reduction_ops_test.cc @@ -39,10 +39,9 @@ limitations under the License. namespace tensorflow { template <typename Index> -static void BM_SegmentReduction(int iters, const string& reduction, - Index num_rows, Index num_cols, - Index segment_size) { - testing::StopTiming(); +static void BM_SegmentReduction(::testing::benchmark::State& state, + const string& reduction, Index num_rows, + Index num_cols, Index segment_size) { std::unique_ptr<Device> device( DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); @@ -81,24 +80,25 @@ static void BM_SegmentReduction(int iters, const string& reduction, reduction_op->Compute(reduction_context.get()); TF_CHECK_OK(reduction_context->status()); - testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { delete reduction_context->release_output(0).tensor; reduction_op->Compute(reduction_context.get()); } int64 bytes_per_iter = static_cast<int64>(num_rows * num_cols * sizeof(float)); - testing::BytesProcessed(bytes_per_iter * iters); + state.SetBytesProcessed(bytes_per_iter * state.iterations()); } -#define BM_Reduce(O, R, C, S) \ - static void BM_Reduce_##O##_##R##_##C##_##S##_int32(int iters) { \ - BM_SegmentReduction<int32>(iters, #O, R, C, S); \ - } \ - static void BM_Reduce_##O##_##R##_##C##_##S##_int64(int iters) { \ - BM_SegmentReduction<int64>(iters, #O, R, C, S); \ - } \ - BENCHMARK(BM_Reduce_##O##_##R##_##C##_##S##_int32); \ +#define BM_Reduce(O, R, C, S) \ + static void BM_Reduce_##O##_##R##_##C##_##S##_int32( \ + ::testing::benchmark::State & state) { \ + BM_SegmentReduction<int32>(state, #O, R, C, S); \ + } \ + static void BM_Reduce_##O##_##R##_##C##_##S##_int64( \ + ::testing::benchmark::State & state) { \ + BM_SegmentReduction<int64>(state, #O, R, C, S); \ + } \ + BENCHMARK(BM_Reduce_##O##_##R##_##C##_##S##_int32); \ BENCHMARK(BM_Reduce_##O##_##R##_##C##_##S##_int64); #define BM_Reduce_Arg(R, C, S) \ @@ -113,8 +113,8 @@ BM_Reduce_Arg(64, 32, 2); BM_Reduce_Arg(4096, 32, 2); BM_Reduce_Arg(4096, 128, 2); -static void SparseSegmentMeanGradHelper(int iters, float uniqueness, int size) { - testing::StopTiming(); +static void SparseSegmentMeanGradHelper(::testing::benchmark::State& state, + float uniqueness, int size) { Graph* g = new Graph(OpRegistry::Global()); CHECK_LE(uniqueness, 1.0); CHECK_GT(uniqueness, 0.0); @@ -148,22 +148,24 @@ static void SparseSegmentMeanGradHelper(int iters, float uniqueness, int size) { .Attr("T", DT_FLOAT) .Finalize(g, &node)); - testing::UseRealTime(); - testing::BytesProcessed(static_cast<int64>(iters) * (kDim1 * kDim2) * - sizeof(float)); - testing::StartTiming(); - test::Benchmark("cpu", g).Run(iters); + test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state); + state.SetBytesProcessed(static_cast<int64>(state.iterations()) * + (kDim1 * kDim2) * sizeof(float)); } -static void BM_SparseSegmentMeanGrad_Low(int iters, int size) { - return SparseSegmentMeanGradHelper(iters, 1.0, size); +static void BM_SparseSegmentMeanGrad_Low(::testing::benchmark::State& state) { + const int size = state.range(0); + + return SparseSegmentMeanGradHelper(state, 1.0, size); } -static void BM_SparseSegmentMeanGrad_High(int iters, int size) { - return SparseSegmentMeanGradHelper(iters, 0.01, size); +static void BM_SparseSegmentMeanGrad_High(::testing::benchmark::State& state) { + const int size = state.range(0); + + return SparseSegmentMeanGradHelper(state, 0.01, size); } -BENCHMARK(BM_SparseSegmentMeanGrad_Low)->Arg(1000)->Arg(100000); -BENCHMARK(BM_SparseSegmentMeanGrad_High)->Arg(1000)->Arg(100000); +BENCHMARK(BM_SparseSegmentMeanGrad_Low)->UseRealTime()->Arg(1000)->Arg(100000); +BENCHMARK(BM_SparseSegmentMeanGrad_High)->UseRealTime()->Arg(1000)->Arg(100000); } // namespace tensorflow diff --git a/tensorflow/core/kernels/sendrecv_ops_test.cc b/tensorflow/core/kernels/sendrecv_ops_test.cc index 092a29f2f3c..347f7d933d0 100644 --- a/tensorflow/core/kernels/sendrecv_ops_test.cc +++ b/tensorflow/core/kernels/sendrecv_ops_test.cc @@ -54,21 +54,21 @@ static Graph* Recv() { return g; } -static void BM_Send(int iters) { - testing::UseRealTime(); - testing::ItemsProcessed(static_cast<int64>(iters)); - test::Benchmark("cpu", Send(), nullptr, nullptr, new DummyRendezvous) - .Run(iters); +void BM_Send(::testing::benchmark::State& state) { + test::Benchmark("cpu", Send(), nullptr, nullptr, new DummyRendezvous, "", + /*old_benchmark_api*/ false) + .Run(state); + state.SetItemsProcessed(static_cast<int64>(state.iterations())); } -BENCHMARK(BM_Send); +BENCHMARK(BM_Send)->UseRealTime(); -static void BM_Recv(int iters) { - testing::UseRealTime(); - testing::ItemsProcessed(static_cast<int64>(iters)); - test::Benchmark("cpu", Recv(), nullptr, nullptr, new DummyRendezvous) - .Run(iters); +void BM_Recv(::testing::benchmark::State& state) { + test::Benchmark("cpu", Recv(), nullptr, nullptr, new DummyRendezvous, "", + /*old_benchmark_api*/ false) + .Run(state); + state.SetItemsProcessed(static_cast<int64>(state.iterations())); } -BENCHMARK(BM_Recv); +BENCHMARK(BM_Recv)->UseRealTime(); } // namespace } // namespace tensorflow diff --git a/tensorflow/core/kernels/slice_op_test.cc b/tensorflow/core/kernels/slice_op_test.cc index f589a09c4fc..aeb96566da6 100644 --- a/tensorflow/core/kernels/slice_op_test.cc +++ b/tensorflow/core/kernels/slice_op_test.cc @@ -37,8 +37,8 @@ namespace { // For the benchmark, we set up two 2-dimensional tensors, each kDim1 x 'dim' // in size, and concat them together along "concat_dimension" template <typename T> -static void SliceHelper(int iters, int size) { - testing::StopTiming(); +static void SliceHelper(::testing::benchmark::State& state) { + const int size = state.range(0); Graph* g = new Graph(OpRegistry::Global()); DataType dt = DataTypeToEnum<T>::v(); int kDim = 100; @@ -65,26 +65,24 @@ static void SliceHelper(int iters, int size) { .Finalize(g, &node)); FixupSourceAndSinkEdges(g); - testing::BytesProcessed(static_cast<int64>(iters) * kDim * size * sizeof(T)); - testing::StartTiming(); test::Benchmark("cpu", g, nullptr, nullptr, nullptr, - "SINGLE_THREADED_EXECUTOR") - .Run(iters); - - testing::UseRealTime(); + "SINGLE_THREADED_EXECUTOR", /*old_benchmark_api*/ false) + .Run(state); + state.SetBytesProcessed(static_cast<int64>(state.iterations()) * kDim * size * + sizeof(T)); } -static void BM_SliceFloat(int iters, int dim2) { - SliceHelper<float>(iters, dim2); +void BM_SliceFloat(::testing::benchmark::State& state) { + SliceHelper<float>(state); } -BENCHMARK(BM_SliceFloat)->Arg(100)->Arg(1000)->Arg(10000); +BENCHMARK(BM_SliceFloat)->UseRealTime()->Arg(100)->Arg(1000)->Arg(10000); -static void BM_SliceBFloat16(int iters, int dim2) { - SliceHelper<bfloat16>(iters, dim2); +void BM_SliceBFloat16(::testing::benchmark::State& state) { + SliceHelper<bfloat16>(state); } -BENCHMARK(BM_SliceBFloat16)->Arg(100)->Arg(1000)->Arg(10000); +BENCHMARK(BM_SliceBFloat16)->UseRealTime()->Arg(100)->Arg(1000)->Arg(10000); } // namespace } // namespace tensorflow diff --git a/tensorflow/core/kernels/sparse_dense_binary_op_shared_test.cc b/tensorflow/core/kernels/sparse_dense_binary_op_shared_test.cc index e3e9a27f316..4f6c20921ed 100644 --- a/tensorflow/core/kernels/sparse_dense_binary_op_shared_test.cc +++ b/tensorflow/core/kernels/sparse_dense_binary_op_shared_test.cc @@ -276,15 +276,18 @@ static ST MakeSparseTensor(Graph* g, int B, int M, int N, int nnz_inner) { // [8, 4, N{nnz}] cmul [8, 4, N] #define BM_SparseMatCMulDenseMatArgs(N, NNZ_INNER) \ - static void BM_SparseMatCMulDenseMat_##N##_##NNZ_INNER(int iters) { \ + static void BM_SparseMatCMulDenseMat_##N##_##NNZ_INNER( \ + ::testing::benchmark::State& state) { \ Graph* g = new Graph(OpRegistry::Global()); \ Node* dense = MakeTensor(g, 8, 4, N); \ ST sp = MakeSparseTensor(g, 8, 4, N, NNZ_INNER); \ \ - testing::ItemsProcessed(static_cast<int64>(iters * 8 * 4 * N * 2)); \ test::Benchmark( \ - "cpu", SparseMatCMulDenseMat(g, sp.indices, sp.vals, sp.shape, dense)) \ - .Run(iters); \ + "cpu", SparseMatCMulDenseMat(g, sp.indices, sp.vals, sp.shape, dense), \ + /*old_benchmark_api*/ false) \ + .Run(state); \ + state.SetItemsProcessed( \ + static_cast<int64>(state.iterations() * 8 * 4 * N * 2)); \ } \ BENCHMARK(BM_SparseMatCMulDenseMat_##N##_##NNZ_INNER) diff --git a/tensorflow/core/kernels/sparse_to_dense_op_test.cc b/tensorflow/core/kernels/sparse_to_dense_op_test.cc index 84e1e09c219..a1f22e355ec 100644 --- a/tensorflow/core/kernels/sparse_to_dense_op_test.cc +++ b/tensorflow/core/kernels/sparse_to_dense_op_test.cc @@ -198,9 +198,11 @@ TEST_F(SparseToDenseTest, ThreeD_MultValues) { } // namespace -static void BM_SparseToDense(int iters, int NDIM, int N) { +static void BM_SparseToDense(::testing::benchmark::State& state) { + const int NDIM = state.range(0); + const int N = state.range(1); + // TODO(zhifengc): Switch to use kernel_benchmark_testlib.h - tensorflow::testing::StopTiming(); const int IndexDim = (NDIM == 1) ? 0 : 1; @@ -253,18 +255,15 @@ static void BM_SparseToDense(int iters, int NDIM, int N) { std::unique_ptr<OpKernelContext> sparse_context(new OpKernelContext(¶ms)); op->Compute(sparse_context.get()); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { delete sparse_context->release_output(0).tensor; op->Compute(sparse_context.get()); TF_ASSERT_OK(sparse_context->status()); } - tensorflow::testing::StopTiming(); // processing input, mainly int64 bytes_per_iter = static_cast<int64>((N + N * NDIM) * sizeof(float)); - - tensorflow::testing::BytesProcessed(bytes_per_iter * iters); + state.SetBytesProcessed(bytes_per_iter * state.iterations()); } BENCHMARK(BM_SparseToDense) diff --git a/tensorflow/core/kernels/sparse_xent_op_test.cc b/tensorflow/core/kernels/sparse_xent_op_test.cc index 3b252d77d0a..85a5cd3befc 100644 --- a/tensorflow/core/kernels/sparse_xent_op_test.cc +++ b/tensorflow/core/kernels/sparse_xent_op_test.cc @@ -41,11 +41,15 @@ static Graph* SparseXent(int batch_size, int num_classes) { return g; } -#define BM_SparseXentDev(BATCH, CLASS, DEVICE) \ - static void BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE(int iters) { \ - testing::ItemsProcessed(static_cast<int64>(iters) * BATCH * CLASS); \ - test::Benchmark(#DEVICE, SparseXent(BATCH, CLASS)).Run(iters); \ - } \ +#define BM_SparseXentDev(BATCH, CLASS, DEVICE) \ + static void BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE( \ + ::testing::benchmark::State& state) { \ + test::Benchmark(#DEVICE, SparseXent(BATCH, CLASS), \ + /*old_benchmark_api*/ false) \ + .Run(state); \ + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * BATCH * \ + CLASS); \ + } \ BENCHMARK(BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE); /// The representative tests for ptb_word on GPU diff --git a/tensorflow/core/kernels/split_op_test.cc b/tensorflow/core/kernels/split_op_test.cc index ac25b6a710e..2617f36fb2e 100644 --- a/tensorflow/core/kernels/split_op_test.cc +++ b/tensorflow/core/kernels/split_op_test.cc @@ -44,38 +44,34 @@ static Graph* MakeGraph(int split_dim, int num_split, } #define BM_SPLIT_1D(num_split, chunk_size) \ - static void BM_Split_1d_##num_split##_##chunk_size(int iters) { \ - testing::StopTiming(); \ - testing::ItemsProcessed(static_cast<int64>(iters) * num_split * \ - chunk_size); \ + static void BM_Split_1d_##num_split##_##chunk_size( \ + ::testing::benchmark::State& state) { \ auto label = \ strings::Printf("1-D %d chunks of %d each", num_split, chunk_size); \ - testing::SetLabel(label); \ - testing::UseRealTime(); \ + state.SetLabel(label); \ auto g = MakeGraph(/* split_dim = */ 0, num_split, {chunk_size}); \ - testing::StartTiming(); \ - test::Benchmark("cpu", g).Run(iters); \ + test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state); \ + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * \ + num_split * chunk_size); \ } \ - BENCHMARK(BM_Split_1d_##num_split##_##chunk_size); + BENCHMARK(BM_Split_1d_##num_split##_##chunk_size)->UseRealTime(); #define BM_SPLIT_2D(split_dim, num_split, chunk_size0, chunk_size1) \ static void \ BM_Split_2d_##split_dim##_##num_split##_##chunk_size0##_##chunk_size1( \ - int iters) { \ - testing::StopTiming(); \ - testing::ItemsProcessed(static_cast<int64>(iters) * num_split * \ - chunk_size0 * chunk_size1); \ + ::testing::benchmark::State& state) { \ auto label = \ strings::Printf("2-D %d chunks in dim %d of (%d * %d) each", \ num_split, split_dim, chunk_size0, chunk_size1); \ - testing::SetLabel(label); \ - testing::UseRealTime(); \ + state.SetLabel(label); \ auto g = MakeGraph(split_dim, num_split, {chunk_size0, chunk_size1}); \ - testing::StartTiming(); \ - test::Benchmark("cpu", g).Run(iters); \ + test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state); \ + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * \ + num_split * chunk_size0 * chunk_size1); \ } \ BENCHMARK( \ - BM_Split_2d_##split_dim##_##num_split##_##chunk_size0##_##chunk_size1); + BM_Split_2d_##split_dim##_##num_split##_##chunk_size0##_##chunk_size1) \ + ->UseRealTime(); BM_SPLIT_1D(5, 1); BM_SPLIT_1D(262144, 1); diff --git a/tensorflow/core/kernels/split_v_op_test.cc b/tensorflow/core/kernels/split_v_op_test.cc index ea2bdd8c3b1..3ffaae4e0fb 100644 --- a/tensorflow/core/kernels/split_v_op_test.cc +++ b/tensorflow/core/kernels/split_v_op_test.cc @@ -73,43 +73,40 @@ static Graph* MakeGraph(int split_dim, const std::vector<int64>& size_splits, } #define BM_SPLITV_1D(num_split, total_size) \ - static void BM_SplitV_1d_##num_split##_##total_size(int iters) { \ - testing::StopTiming(); \ - testing::ItemsProcessed(static_cast<int64>(iters) * total_size); \ + static void BM_SplitV_1d_##num_split##_##total_size( \ + ::testing::benchmark::State& state) { \ auto label = \ strings::Printf("1-D %d chunks totaling %d", num_split, total_size); \ - testing::SetLabel(label); \ - testing::UseRealTime(); \ + state.SetLabel(label); \ auto g = MakeGraph(/* split_dim = */ 0, \ GenerateRandomIntsWithSum(total_size, num_split), \ {total_size}); \ - testing::StartTiming(); \ - test::Benchmark("cpu", g).Run(iters); \ + test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state); \ + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * \ + total_size); \ } \ - BENCHMARK(BM_SplitV_1d_##num_split##_##total_size); + BENCHMARK(BM_SplitV_1d_##num_split##_##total_size)->UseRealTime(); #define BM_SPLITV_2D(split_dim, num_split, total_size0, total_size1) \ static void \ BM_SplitV_2d_##split_dim##_##num_split##_##total_size0##_##total_size1( \ - int iters) { \ - testing::StopTiming(); \ + ::testing::benchmark::State& state) { \ std::vector<int64> total_size_vec{total_size0, total_size1}; \ - testing::ItemsProcessed(static_cast<int64>(iters) * total_size0 * \ - total_size1); \ auto label = \ strings::Printf("2-D %d chunks in dim %d totaling (%d * %d)", \ num_split, split_dim, total_size0, total_size1); \ - testing::SetLabel(label); \ - testing::UseRealTime(); \ + state.SetLabel(label); \ auto g = MakeGraph( \ split_dim, \ GenerateRandomIntsWithSum(total_size_vec[split_dim], num_split), \ {total_size0, total_size1}); \ - testing::StartTiming(); \ - test::Benchmark("cpu", g).Run(iters); \ + test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state); \ + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * \ + total_size0 * total_size1); \ } \ BENCHMARK( \ - BM_SplitV_2d_##split_dim##_##num_split##_##total_size0##_##total_size1); + BM_SplitV_2d_##split_dim##_##num_split##_##total_size0##_##total_size1) \ + ->UseRealTime(); BM_SPLITV_1D(5, 20); BM_SPLITV_1D(262144, 1000000); diff --git a/tensorflow/core/kernels/strided_slice_op_test.cc b/tensorflow/core/kernels/strided_slice_op_test.cc index 281ca0f58fe..78f0e47c31e 100644 --- a/tensorflow/core/kernels/strided_slice_op_test.cc +++ b/tensorflow/core/kernels/strided_slice_op_test.cc @@ -38,8 +38,8 @@ namespace { // For the benchmark, we set up two 2-dimensional tensors, each kDim1 x 'dim' // in size, and concat them together along "concat_dimension" template <typename T> -static void SliceHelper(int iters, int size) { - testing::StopTiming(); +static void SliceHelper(::testing::benchmark::State& state) { + const int size = state.range(0); Graph* g = new Graph(OpRegistry::Global()); DataType dt = DataTypeToEnum<T>::v(); int kDim = 100; @@ -70,32 +70,30 @@ static void SliceHelper(int iters, int size) { .Attr("T", dt) .Finalize(g, &node)); - testing::BytesProcessed(static_cast<int64>(iters) * kDim * size * sizeof(T)); - testing::StartTiming(); - test::Benchmark("cpu", g).Run(iters); - testing::UseRealTime(); + test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state); + state.SetBytesProcessed(static_cast<int64>(state.iterations()) * kDim * size * + sizeof(T)); } -static void BM_SliceFloat(int iters, int dim2) { - SliceHelper<float>(iters, dim2); +void BM_SliceFloat(::testing::benchmark::State& state) { + SliceHelper<float>(state); } -BENCHMARK(BM_SliceFloat)->Arg(100)->Arg(1000)->Arg(10000); +BENCHMARK(BM_SliceFloat)->UseRealTime()->Arg(100)->Arg(1000)->Arg(10000); -static void BM_SliceComplex64(int iters, int dim2) { - SliceHelper<std::complex<float>>(iters, dim2); +void BM_SliceComplex64(::testing::benchmark::State& state) { + SliceHelper<std::complex<float>>(state); } -BENCHMARK(BM_SliceComplex64)->Arg(100)->Arg(1000)->Arg(10000); +BENCHMARK(BM_SliceComplex64)->UseRealTime()->Arg(100)->Arg(1000)->Arg(10000); -static void BM_SliceBFloat16(int iters, int dim2) { - SliceHelper<bfloat16>(iters, dim2); +void BM_SliceBFloat16(::testing::benchmark::State& state) { + SliceHelper<bfloat16>(state); } -BENCHMARK(BM_SliceBFloat16)->Arg(100)->Arg(1000)->Arg(10000); +BENCHMARK(BM_SliceBFloat16)->UseRealTime()->Arg(100)->Arg(1000)->Arg(10000); -static void BM_ValidateStridedSliceOp(int iters) { - testing::StopTiming(); +void BM_ValidateStridedSliceOp(::testing::benchmark::State& state) { int kDim = 100; int kMaxSize = 15000; int size = 100; @@ -104,8 +102,7 @@ static void BM_ValidateStridedSliceOp(int iters) { Tensor strides = test::AsTensor<int32>({1, 1}); TensorShape input_shape({2 * kDim, kMaxSize}); - testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { TensorShape processing_shape, final_shape; bool is_identity = true, slice_dim0 = true, is_simple_slice = true; gtl::InlinedVector<int64, 4> begin_out, end_out, strides_out; diff --git a/tensorflow/core/kernels/string_split_op_test.cc b/tensorflow/core/kernels/string_split_op_test.cc index 4494cf9dcf3..2aed21db4af 100644 --- a/tensorflow/core/kernels/string_split_op_test.cc +++ b/tensorflow/core/kernels/string_split_op_test.cc @@ -76,17 +76,17 @@ Graph* SetupStringSplitGraph(const Tensor& input) { return g; } -void BM_StringSplit(int iters, int batch_size) { - testing::StopTiming(); - testing::ItemsProcessed(static_cast<int64>(iters)); - testing::UseRealTime(); +static void BM_StringSplit(::testing::benchmark::State& state) { + const int batch_size = state.range(0); + Tensor input = GetTestTensor(batch_size); Graph* g = SetupStringSplitGraph(input); - testing::StartTiming(); - test::Benchmark("cpu", g).Run(iters); + test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state); + state.SetItemsProcessed(static_cast<int64>(state.iterations())); } BENCHMARK(BM_StringSplit) + ->UseRealTime() ->Arg(1) ->Arg(8) ->Arg(16) @@ -107,17 +107,17 @@ Graph* SetupStringSplitV2Graph(const Tensor& input) { return g; } -void BM_StringSplitV2(int iters, int batch_size) { - testing::StopTiming(); - testing::ItemsProcessed(static_cast<int64>(iters)); - testing::UseRealTime(); +static void BM_StringSplitV2(::testing::benchmark::State& state) { + const int batch_size = state.range(0); + Tensor input = GetTestTensor(batch_size); Graph* g = SetupStringSplitV2Graph(input); - testing::StartTiming(); - test::Benchmark("cpu", g).Run(iters); + test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state); + state.SetItemsProcessed(static_cast<int64>(state.iterations())); } BENCHMARK(BM_StringSplitV2) + ->UseRealTime() ->Arg(1) ->Arg(8) ->Arg(16) diff --git a/tensorflow/core/kernels/substr_op_test.cc b/tensorflow/core/kernels/substr_op_test.cc index 3aebfe3a212..02ac6503cae 100644 --- a/tensorflow/core/kernels/substr_op_test.cc +++ b/tensorflow/core/kernels/substr_op_test.cc @@ -149,27 +149,26 @@ Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len, return g; } -void BM_SubstrByte(int iters, int batch_size) { - testing::StopTiming(); - testing::ItemsProcessed(static_cast<int64>(iters)); - testing::UseRealTime(); +static void BM_SubstrByte(::testing::benchmark::State& state) { + const int batch_size = state.range(0); + Tensor input = GetTestTensor(batch_size); Graph* g = SetupSubstrGraph(input, 3, 30, kByteUnit); - testing::StartTiming(); - test::Benchmark("cpu", g).Run(iters); + test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state); + state.SetItemsProcessed(state.iterations()); } -void BM_SubstrUTF8(int iters, int batch_size) { - testing::StopTiming(); - testing::ItemsProcessed(static_cast<int64>(iters)); - testing::UseRealTime(); +static void BM_SubstrUTF8(::testing::benchmark::State& state) { + const int batch_size = state.range(0); + Tensor input = GetTestUTF8Tensor(batch_size); Graph* g = SetupSubstrGraph(input, 3, 30, kUTF8Unit); - testing::StartTiming(); - test::Benchmark("cpu", g).Run(iters); + test::Benchmark("cpu", g, /*old_benchmark_api*/ false).Run(state); + state.SetItemsProcessed(state.iterations()); } BENCHMARK(BM_SubstrByte) + ->UseRealTime() ->Arg(1) ->Arg(8) ->Arg(16) @@ -178,6 +177,7 @@ BENCHMARK(BM_SubstrByte) ->Arg(128) ->Arg(256); BENCHMARK(BM_SubstrUTF8) + ->UseRealTime() ->Arg(1) ->Arg(8) ->Arg(16) diff --git a/tensorflow/core/kernels/training_ops_test.cc b/tensorflow/core/kernels/training_ops_test.cc index a92a7b29984..364fc84c507 100644 --- a/tensorflow/core/kernels/training_ops_test.cc +++ b/tensorflow/core/kernels/training_ops_test.cc @@ -103,14 +103,18 @@ static void SGD(int32 n, Graph** init_g, Graph** train_g) { } } -static void BM_SGD(int iters, int params) { - const int64 tot = static_cast<int64>(iters) * params; - testing::ItemsProcessed(tot); - testing::BytesProcessed(tot * sizeof(float)); +static void BM_SGD(::testing::benchmark::State& state) { + const int params = state.range(0); + Graph* init; Graph* train; SGD(params, &init, &train); - test::Benchmark("cpu", train, GetOptions(), init).Run(iters); + test::Benchmark("cpu", train, GetOptions(), init, nullptr, "", + /*old_benchmark_api*/ false) + .Run(state); + const int64 tot = static_cast<int64>(state.iterations()) * params; + state.SetItemsProcessed(tot); + state.SetBytesProcessed(tot * sizeof(float)); } BENCHMARK(BM_SGD)->Arg(128 << 10)->Arg(256 << 10); @@ -135,14 +139,18 @@ static void Adagrad(int32 n, Graph** init_g, Graph** train_g) { } } -static void BM_Adagrad(int iters, int params) { - const int64 tot = static_cast<int64>(iters) * params; - testing::ItemsProcessed(tot); - testing::BytesProcessed(tot * sizeof(float)); +static void BM_Adagrad(::testing::benchmark::State& state) { + const int params = state.range(0); + Graph* init; Graph* train; Adagrad(params, &init, &train); - test::Benchmark("cpu", train, GetOptions(), init).Run(iters); + test::Benchmark("cpu", train, GetOptions(), init, nullptr, "", + /*old_benchmark_api*/ false) + .Run(state); + const int64 tot = static_cast<int64>(state.iterations()) * params; + state.SetItemsProcessed(tot); + state.SetBytesProcessed(tot * sizeof(float)); } BENCHMARK(BM_Adagrad)->Arg(128 << 10)->Arg(256 << 10); @@ -168,17 +176,22 @@ static void SparseAdagrad(int32 m, int32 n, Graph** init_g, Graph** train_g) { *train_g = g; } } -static void BM_SparseAdagrad(int iters, int m, int n) { - const int64 tot = static_cast<int64>(iters) * m * n; - testing::UseRealTime(); - testing::ItemsProcessed(tot); - testing::BytesProcessed(tot * sizeof(float)); +static void BM_SparseAdagrad(::testing::benchmark::State& state) { + const int m = state.range(0); + const int n = state.range(1); + Graph* init; Graph* train; SparseAdagrad(m, n, &init, &train); - test::Benchmark("cpu", train, GetMultiThreadedOptions(), init).Run(iters); + test::Benchmark("cpu", train, GetMultiThreadedOptions(), init, nullptr, "", + /*old_benchmark_api*/ false) + .Run(state); + const int64 tot = static_cast<int64>(state.iterations()) * m * n; + state.SetItemsProcessed(tot); + state.SetBytesProcessed(tot * sizeof(float)); } BENCHMARK(BM_SparseAdagrad) + ->UseRealTime() ->ArgPair(128, 1 << 10) ->ArgPair(128, 4 << 10) ->ArgPair(128, 8 << 10) @@ -208,14 +221,18 @@ static void Momentum(int32 n, Graph** init_g, Graph** train_g) { } } -static void BM_Momentum(int iters, int params) { - const int64 tot = static_cast<int64>(iters) * params; - testing::ItemsProcessed(tot); - testing::BytesProcessed(tot * sizeof(float)); +static void BM_Momentum(::testing::benchmark::State& state) { + const int params = state.range(0); + Graph* init; Graph* train; Momentum(params, &init, &train); - test::Benchmark("cpu", train, GetOptions(), init).Run(iters); + test::Benchmark("cpu", train, GetOptions(), init, nullptr, "", + /*old_benchmark_api*/ false) + .Run(state); + const int64 tot = static_cast<int64>(state.iterations()) * params; + state.SetItemsProcessed(tot); + state.SetBytesProcessed(tot * sizeof(float)); } BENCHMARK(BM_Momentum)->Arg(128 << 10)->Arg(256 << 10); @@ -251,19 +268,26 @@ static void Adam(int32 n, Graph** init_g, Graph** train_g) { } } -static void BM_Adam(int iters, int params, int is_multi_threaded) { - const int64 tot = static_cast<int64>(iters) * params; - testing::ItemsProcessed(tot); - testing::BytesProcessed(tot * sizeof(float)); +static void BM_Adam(::testing::benchmark::State& state) { + const int params = state.range(0); + const int is_multi_threaded = state.range(1); + Graph* init; Graph* train; Adam(params, &init, &train); if (is_multi_threaded) { // Use max thread number if test performance. - test::Benchmark("cpu", train, nullptr, init).Run(iters); + test::Benchmark("cpu", train, nullptr, init, nullptr, "", + /*old_benchmark_api*/ false) + .Run(state); } else { - test::Benchmark("cpu", train, GetOptions(), init).Run(iters); + test::Benchmark("cpu", train, GetOptions(), init, nullptr, "", + /*old_benchmark_api*/ false) + .Run(state); } + const int64 tot = static_cast<int64>(state.iterations()) * params; + state.SetItemsProcessed(tot); + state.SetBytesProcessed(tot * sizeof(float)); } BENCHMARK(BM_Adam)->ArgPair(128 << 10, 0)->ArgPair(256 << 10, 0); BENCHMARK(BM_Adam)->ArgPair(256 << 5, 1)->ArgPair(256 << 16, 1); @@ -297,14 +321,18 @@ static void RMSProp(int32 n, Graph** init_g, Graph** train_g) { } } -static void BM_RMSProp(int iters, int params) { - const int64 tot = static_cast<int64>(iters) * params; - testing::ItemsProcessed(tot); - testing::BytesProcessed(tot * sizeof(float)); +static void BM_RMSProp(::testing::benchmark::State& state) { + const int params = state.range(0); + Graph* init; Graph* train; RMSProp(params, &init, &train); - test::Benchmark("cpu", train, GetOptions(), init).Run(iters); + test::Benchmark("cpu", train, GetOptions(), init, nullptr, "", + /*old_benhcmark_api*/ false) + .Run(state); + const int64 tot = static_cast<int64>(state.iterations()) * params; + state.SetItemsProcessed(tot); + state.SetBytesProcessed(tot * sizeof(float)); } BENCHMARK(BM_RMSProp)->Arg(128 << 10)->Arg(256 << 10); @@ -334,14 +362,18 @@ static void AddSign(int32 n, Graph** init_g, Graph** train_g) { } } -static void BM_AddSign(int iters, int params) { - const int64 tot = static_cast<int64>(iters) * params; - testing::ItemsProcessed(tot); - testing::BytesProcessed(tot * sizeof(float)); +static void BM_AddSign(::testing::benchmark::State& state) { + const int params = state.range(0); + Graph* init; Graph* train; AddSign(params, &init, &train); - test::Benchmark("cpu", train, GetOptions(), init).Run(iters); + test::Benchmark("cpu", train, GetOptions(), init, nullptr, "", + /*old_benhcmark_api*/ false) + .Run(state); + const int64 tot = static_cast<int64>(state.iterations()) * params; + state.SetItemsProcessed(tot); + state.SetBytesProcessed(tot * sizeof(float)); } BENCHMARK(BM_AddSign)->Arg(128 << 10)->Arg(256 << 10); @@ -371,14 +403,19 @@ static void PowerSign(int32 n, Graph** init_g, Graph** train_g) { } } -static void BM_PowerSign(int iters, int params) { - const int64 tot = static_cast<int64>(iters) * params; - testing::ItemsProcessed(tot); - testing::BytesProcessed(tot * sizeof(float)); +static void BM_PowerSign(::testing::benchmark::State& state) { + const int params = state.range(0); + Graph* init; Graph* train; PowerSign(params, &init, &train); - test::Benchmark("cpu", train, GetOptions(), init).Run(iters); + test::Benchmark("cpu", train, GetOptions(), init, nullptr, "", + /*old_benhcmark_api*/ false) + .Run(state); + + const int64 tot = static_cast<int64>(state.iterations()) * params; + state.SetItemsProcessed(tot); + state.SetBytesProcessed(tot * sizeof(float)); } BENCHMARK(BM_PowerSign)->Arg(128 << 10)->Arg(256 << 10); diff --git a/tensorflow/core/kernels/unary_ops_composition_test.cc b/tensorflow/core/kernels/unary_ops_composition_test.cc index 807dc56e3e7..3110f435038 100644 --- a/tensorflow/core/kernels/unary_ops_composition_test.cc +++ b/tensorflow/core/kernels/unary_ops_composition_test.cc @@ -108,11 +108,15 @@ static Graph* UnaryOpsChain(int tensor_size, int repeat_graph, return g; } -#define BM_UnaryOpsChain(N, R, F, type) \ - static void BM_UnaryOpsChain##_##type##_##N##_##R##_##F(int iters) { \ - testing::ItemsProcessed(static_cast<int64>(iters) * N * R * F); \ - test::Benchmark(#type, UnaryOpsChain(N, R, F)).Run(iters); \ - } \ +#define BM_UnaryOpsChain(N, R, F, type) \ + static void BM_UnaryOpsChain##_##type##_##N##_##R##_##F( \ + ::testing::benchmark::State& state) { \ + test::Benchmark(#type, UnaryOpsChain(N, R, F), \ + /*old_benchmark_api*/ false) \ + .Run(state); \ + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * N * R * \ + F); \ + } \ BENCHMARK(BM_UnaryOpsChain##_##type##_##N##_##R##_##F); // Unary ops fused together. @@ -140,11 +144,15 @@ static Graph* UnaryOpsCompo(int tensor_size, int repeat_graph, return g; } -#define BM_UnaryOpsCompo(N, R, F, type) \ - static void BM_UnaryOpsCompo##_##type##_##N##_##R##_##F(int iters) { \ - testing::ItemsProcessed(static_cast<int64>(iters) * N * R * F); \ - test::Benchmark(#type, UnaryOpsCompo(N, R, F)).Run(iters); \ - } \ +#define BM_UnaryOpsCompo(N, R, F, type) \ + static void BM_UnaryOpsCompo##_##type##_##N##_##R##_##F( \ + ::testing::benchmark::State& state) { \ + test::Benchmark(#type, UnaryOpsCompo(N, R, F), \ + /*old_benchmark_api*/ false) \ + .Run(state); \ + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * N * R * \ + F); \ + } \ BENCHMARK(BM_UnaryOpsCompo##_##type##_##N##_##R##_##F); // BenchmarkName(tensor_size, repeat_graph, num_ops, type) diff --git a/tensorflow/core/kernels/unique_op_test.cc b/tensorflow/core/kernels/unique_op_test.cc index a0249d9bc4c..590bd7f8c39 100644 --- a/tensorflow/core/kernels/unique_op_test.cc +++ b/tensorflow/core/kernels/unique_op_test.cc @@ -64,8 +64,10 @@ TensorProto GetRandomInt32TensorProtoWithRepeat(int dim, int repeat, return tensor_proto; } -static void BM_Unique_INT32(int iters, int dim, int max_int) { - testing::StopTiming(); +void BM_Unique_INT32(::testing::benchmark::State& state) { + const int dim = state.range(0); + const int max_int = state.range(1); + Graph* g = new Graph(OpRegistry::Global()); Tensor input(DT_INT32, TensorShape({dim})); @@ -78,16 +80,17 @@ static void BM_Unique_INT32(int iters, int dim, int max_int) { .Finalize(g, &node)); FixupSourceAndSinkEdges(g); - testing::BytesProcessed(static_cast<int64>(iters) * dim * sizeof(int32)); - testing::UseRealTime(); - testing::StartTiming(); test::Benchmark("cpu", g, nullptr, nullptr, nullptr, - "SINGLE_THREADED_EXECUTOR") - .Run(iters); + "SINGLE_THREADED_EXECUTOR", /*old_benchmark_api*/ false) + .Run(state); + state.SetBytesProcessed(static_cast<int64>(state.iterations()) * dim * + sizeof(int32)); } -static void BM_Unique_INT32_Repeat(int iters, int dim, int max_int) { - testing::StopTiming(); +void BM_Unique_INT32_Repeat(::testing::benchmark::State& state) { + const int dim = state.range(0); + const int max_int = state.range(1); + Graph* g = new Graph(OpRegistry::Global()); Tensor input(DT_INT32, TensorShape({dim * 200})); @@ -101,13 +104,11 @@ static void BM_Unique_INT32_Repeat(int iters, int dim, int max_int) { .Finalize(g, &node)); FixupSourceAndSinkEdges(g); - testing::BytesProcessed(static_cast<int64>(iters) * dim * 200 * - sizeof(int32)); - testing::UseRealTime(); - testing::StartTiming(); test::Benchmark("cpu", g, nullptr, nullptr, nullptr, - "SINGLE_THREADED_EXECUTOR") - .Run(iters); + "SINGLE_THREADED_EXECUTOR", /*old_benchmark_api*/ false) + .Run(state); + state.SetBytesProcessed(static_cast<int64>(state.iterations()) * dim * 200 * + sizeof(int32)); } TensorProto GetRandomStringsTensorProto(int dim, int max_str_len) { @@ -127,8 +128,9 @@ TensorProto GetRandomStringsTensorProto(int dim, int max_str_len) { return tensor_proto; } -static void BM_Unique_STRING(int iters, int dim) { - testing::StopTiming(); +void BM_Unique_STRING(::testing::benchmark::State& state) { + const int dim = state.range(0); + Graph* g = new Graph(OpRegistry::Global()); Tensor input(DT_STRING, TensorShape({dim})); @@ -140,16 +142,15 @@ static void BM_Unique_STRING(int iters, int dim) { .Attr("T", DT_STRING) .Finalize(g, &node)); FixupSourceAndSinkEdges(g); - - testing::BytesProcessed(static_cast<int64>(iters) * dim * sizeof(tstring)); - testing::UseRealTime(); - testing::StartTiming(); test::Benchmark("cpu", g, nullptr, nullptr, nullptr, - "SINGLE_THREADED_EXECUTOR") - .Run(iters); + "SINGLE_THREADED_EXECUTOR", /*old_benchmark_api*/ false) + .Run(state); + state.SetBytesProcessed(static_cast<int64>(state.iterations()) * dim * + sizeof(tstring)); } BENCHMARK(BM_Unique_INT32) + ->UseRealTime() ->ArgPair(32, 1024 * 1024) ->ArgPair(256, 1024 * 1024) ->ArgPair(1024, 1024 * 1024) @@ -168,6 +169,7 @@ BENCHMARK(BM_Unique_INT32) ->ArgPair(4 * 1024 * 1024, 64 * 1024 * 1024); BENCHMARK(BM_Unique_INT32_Repeat) + ->UseRealTime() ->ArgPair(32, 1024 * 1024) ->ArgPair(256, 1024 * 1024) ->ArgPair(1024, 1024 * 1024) @@ -192,6 +194,7 @@ BENCHMARK(BM_Unique_INT32_Repeat) ->ArgPair(1024 * 1024, 64 * 1024 * 1024); BENCHMARK(BM_Unique_STRING) + ->UseRealTime() ->Arg(32) ->Arg(256) ->Arg(1024) diff --git a/tensorflow/core/kernels/variable_ops_test.cc b/tensorflow/core/kernels/variable_ops_test.cc index 7a615788cc9..0a814aab1db 100644 --- a/tensorflow/core/kernels/variable_ops_test.cc +++ b/tensorflow/core/kernels/variable_ops_test.cc @@ -28,8 +28,8 @@ namespace { // Benchmark to simulate the overhead in training and serving workloads from too // many threads grabbing the ResourceMgr lock at the same time because of the // variable and queue ops. -void ManyManyVariablesHelper(int threads, int variables, int iters) { - testing::StopTiming(); +void ManyManyVariablesHelper(int threads, int variables, + ::testing::benchmark::State& state) { Graph g(OpRegistry::Global()); std::vector<string> targets; for (int i = 0; i < variables; ++i) { @@ -50,16 +50,16 @@ void ManyManyVariablesHelper(int threads, int variables, int iters) { Session* sess = NewSession(opts); TF_CHECK_OK(sess->Create(gd)); TF_CHECK_OK(sess->Run({}, {}, targets, nullptr)); - testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { TF_CHECK_OK(sess->Run({}, {}, targets, nullptr)); } - testing::StopTiming(); delete sess; } -void BM_ManyManyVariablesManyThreads(int iters, int threads) { - ManyManyVariablesHelper(threads, 1000, iters); +void BM_ManyManyVariablesManyThreads(::testing::benchmark::State& state) { + const int threads = state.range(0); + + ManyManyVariablesHelper(threads, 1000, state); } BENCHMARK(BM_ManyManyVariablesManyThreads)->Arg(50); diff --git a/tensorflow/core/kernels/xent_op_test.cc b/tensorflow/core/kernels/xent_op_test.cc index b844979adfa..ec87e85e810 100644 --- a/tensorflow/core/kernels/xent_op_test.cc +++ b/tensorflow/core/kernels/xent_op_test.cc @@ -33,11 +33,14 @@ static Graph* Xent(int batch_size, int num_classes) { return g; } -#define BM_XentDev(BATCH, CLASS, DEVICE) \ - static void BM_Xent##_##BATCH##_##CLASS##_##DEVICE(int iters) { \ - testing::ItemsProcessed(static_cast<int64>(iters) * BATCH * CLASS); \ - test::Benchmark(#DEVICE, Xent(BATCH, CLASS)).Run(iters); \ - } \ +#define BM_XentDev(BATCH, CLASS, DEVICE) \ + static void BM_Xent##_##BATCH##_##CLASS##_##DEVICE( \ + ::testing::benchmark::State& state) { \ + test::Benchmark(#DEVICE, Xent(BATCH, CLASS), /*old_benchmark_api*/ false) \ + .Run(state); \ + state.SetItemsProcessed(static_cast<int64>(state.iterations()) * BATCH * \ + CLASS); \ + } \ BENCHMARK(BM_Xent##_##BATCH##_##CLASS##_##DEVICE); /// The representative tests for ptb_word on GPU diff --git a/tensorflow/core/ops/collective_ops.cc b/tensorflow/core/ops/collective_ops.cc index ecaab00c91a..9033c7154ff 100644 --- a/tensorflow/core/ops/collective_ops.cc +++ b/tensorflow/core/ops/collective_ops.cc @@ -111,10 +111,12 @@ REGISTER_OP("CollectiveReduceV2") .Input("group_size: int32") .Input("group_key: int32") .Input("instance_key: int32") + .Input("ordering_token: Nordering_token * resource") .Attr("merge_op: {'Min', 'Max', 'Mul', 'Add'}") .Attr("final_op: {'Id', 'Div'}") .Attr("communication_hint: string = 'auto'") .Attr("timeout_seconds: float = 0") + .Attr("Nordering_token: int >= 0 = 0") .SetIsStateful() .SetShapeFn(shape_inference::UnchangedShape); @@ -125,8 +127,10 @@ REGISTER_OP("CollectiveGatherV2") .Input("group_size: int32") .Input("group_key: int32") .Input("instance_key: int32") + .Input("ordering_token: Nordering_token * resource") .Attr("communication_hint: string = 'auto'") .Attr("timeout_seconds: float = 0") + .Attr("Nordering_token: int >= 0 = 0") .SetIsStateful() .SetShapeFn([](shape_inference::InferenceContext* c) { // Scalar input is not supported. diff --git a/tensorflow/core/ops/compat/ops_history_v2/CollectiveGatherV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/CollectiveGatherV2.pbtxt index 8a081e34d34..4d473ac82e2 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/CollectiveGatherV2.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/CollectiveGatherV2.pbtxt @@ -49,3 +49,67 @@ op { } is_stateful: true } +op { + name: "CollectiveGatherV2" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "group_size" + type: DT_INT32 + } + input_arg { + name: "group_key" + type: DT_INT32 + } + input_arg { + name: "instance_key" + type: DT_INT32 + } + input_arg { + name: "ordering_token" + type: DT_RESOURCE + number_attr: "Nordering_token" + } + output_arg { + name: "data" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_HALF + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "communication_hint" + type: "string" + default_value { + s: "auto" + } + } + attr { + name: "timeout_seconds" + type: "float" + default_value { + f: 0 + } + } + attr { + name: "Nordering_token" + type: "int" + default_value { + i: 0 + } + has_minimum: true + } + is_stateful: true +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/CollectiveReduceV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/CollectiveReduceV2.pbtxt index b2751cc59e8..bdb99f807b6 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/CollectiveReduceV2.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/CollectiveReduceV2.pbtxt @@ -137,3 +137,89 @@ op { } is_stateful: true } +op { + name: "CollectiveReduceV2" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "group_size" + type: DT_INT32 + } + input_arg { + name: "group_key" + type: DT_INT32 + } + input_arg { + name: "instance_key" + type: DT_INT32 + } + input_arg { + name: "ordering_token" + type: DT_RESOURCE + number_attr: "Nordering_token" + } + output_arg { + name: "data" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_HALF + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "merge_op" + type: "string" + allowed_values { + list { + s: "Min" + s: "Max" + s: "Mul" + s: "Add" + } + } + } + attr { + name: "final_op" + type: "string" + allowed_values { + list { + s: "Id" + s: "Div" + } + } + } + attr { + name: "communication_hint" + type: "string" + default_value { + s: "auto" + } + } + attr { + name: "timeout_seconds" + type: "float" + default_value { + f: 0 + } + } + attr { + name: "Nordering_token" + type: "int" + default_value { + i: 0 + } + has_minimum: true + } + is_stateful: true +} diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 9886e92b893..d7d9efead62 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -7569,6 +7569,11 @@ op { name: "instance_key" type: DT_INT32 } + input_arg { + name: "ordering_token" + type: DT_RESOURCE + number_attr: "Nordering_token" + } output_arg { name: "data" type_attr: "T" @@ -7600,6 +7605,14 @@ op { f: 0 } } + attr { + name: "Nordering_token" + type: "int" + default_value { + i: 0 + } + has_minimum: true + } is_stateful: true } op { @@ -7745,6 +7758,11 @@ op { name: "instance_key" type: DT_INT32 } + input_arg { + name: "ordering_token" + type: DT_RESOURCE + number_attr: "Nordering_token" + } output_arg { name: "data" type_attr: "T" @@ -7798,6 +7816,14 @@ op { f: 0 } } + attr { + name: "Nordering_token" + type: "int" + default_value { + i: 0 + } + has_minimum: true + } is_stateful: true } op { diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer.cc b/tensorflow/core/profiler/internal/cpu/host_tracer.cc index ba882067463..e640d0763a6 100644 --- a/tensorflow/core/profiler/internal/cpu/host_tracer.cc +++ b/tensorflow/core/profiler/internal/cpu/host_tracer.cc @@ -82,11 +82,15 @@ Status HostTracer::Start() { if (recording_) { return errors::Internal("TraceMeRecorder already started"); } + + // All TraceMe captured should have a timestamp greater or equal to + // start_timestamp_ns_ to prevent timestamp underflow in XPlane. + // Therefore this have to be done before TraceMeRecorder::Start. + start_timestamp_ns_ = EnvTime::NowNanos(); recording_ = TraceMeRecorder::Start(host_trace_level_); if (!recording_) { return errors::Internal("Failed to start TraceMeRecorder"); } - start_timestamp_ns_ = EnvTime::NowNanos(); return Status::OK(); } diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc b/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc index 4f12776f581..f3d128f4539 100644 --- a/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc +++ b/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc @@ -80,6 +80,7 @@ void ConvertCompleteEventsToXPlane(uint64 start_timestamp_ns, xline.ReserveEvents(thread.events.size()); for (const auto& event : thread.events) { if (!IsCompleteEvent(event)) continue; + if (event.start_time < start_timestamp_ns) continue; Annotation annotation = ParseAnnotation(event.name); XEventMetadata* xevent_metadata = xplane.GetOrCreateEventMetadata(annotation.name); diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index cf6112e3e7d..25899a7f9f9 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 569 // Updated: 2020/10/29 +#define TF_GRAPH_DEF_VERSION 570 // Updated: 2020/10/30 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 47a68aad644..08401cff702 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -19370,7 +19370,7 @@ func CollectiveReduceV2TimeoutSeconds(value float32) CollectiveReduceV2Attr { } // Mutually reduces multiple tensors of identical type and shape. -func CollectiveReduceV2(scope *Scope, input tf.Output, group_size tf.Output, group_key tf.Output, instance_key tf.Output, merge_op string, final_op string, optional ...CollectiveReduceV2Attr) (data tf.Output) { +func CollectiveReduceV2(scope *Scope, input tf.Output, group_size tf.Output, group_key tf.Output, instance_key tf.Output, ordering_token []tf.Output, merge_op string, final_op string, optional ...CollectiveReduceV2Attr) (data tf.Output) { if scope.Err() != nil { return } @@ -19381,7 +19381,7 @@ func CollectiveReduceV2(scope *Scope, input tf.Output, group_size tf.Output, gro opspec := tf.OpSpec{ Type: "CollectiveReduceV2", Input: []tf.Input{ - input, group_size, group_key, instance_key, + input, group_size, group_key, instance_key, tf.OutputList(ordering_token), }, Attrs: attrs, } @@ -50170,7 +50170,7 @@ func CollectiveGatherV2TimeoutSeconds(value float32) CollectiveGatherV2Attr { } // Mutually accumulates multiple tensors of identical type and shape. -func CollectiveGatherV2(scope *Scope, input tf.Output, group_size tf.Output, group_key tf.Output, instance_key tf.Output, optional ...CollectiveGatherV2Attr) (data tf.Output) { +func CollectiveGatherV2(scope *Scope, input tf.Output, group_size tf.Output, group_key tf.Output, instance_key tf.Output, ordering_token []tf.Output, optional ...CollectiveGatherV2Attr) (data tf.Output) { if scope.Err() != nil { return } @@ -50181,7 +50181,7 @@ func CollectiveGatherV2(scope *Scope, input tf.Output, group_size tf.Output, gro opspec := tf.OpSpec{ Type: "CollectiveGatherV2", Input: []tf.Input{ - input, group_size, group_key, instance_key, + input, group_size, group_key, instance_key, tf.OutputList(ordering_token), }, Attrs: attrs, } diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h index 35952090e6c..71cefe62f92 100644 --- a/tensorflow/lite/builtin_ops.h +++ b/tensorflow/lite/builtin_ops.h @@ -156,6 +156,7 @@ typedef enum { kTfLiteBuiltinBatchMatmul = 126, kTfLiteBuiltinPlaceholderForGreaterOpCodes = 127, kTfLiteBuiltinCumsum = 128, + kTfLiteBuiltinCallOnce = 129, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/lite/c/builtin_op_data.h b/tensorflow/lite/c/builtin_op_data.h index a511e51b5bf..5452ef63748 100644 --- a/tensorflow/lite/c/builtin_op_data.h +++ b/tensorflow/lite/c/builtin_op_data.h @@ -470,6 +470,10 @@ typedef struct { bool reverse; } TfLiteCumsumParams; +typedef struct { + int init_subgraph_index; +} TfLiteCallOnceParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index ea381801505..dee2a990761 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -761,6 +761,16 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, *builtin_data = params.release(); return kTfLiteOk; } + case BuiltinOperator_CALL_ONCE: { + auto params = safe_allocator.Allocate<TfLiteCallOnceParams>(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* call_once_params = + op->builtin_options_as_CallOnceOptions()) { + params->init_subgraph_index = call_once_params->init_subgraph_index(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } case BuiltinOperator_CUMSUM: { auto params = safe_allocator.Allocate<TfLiteCumsumParams>(); TF_LITE_ENSURE(error_reporter, params != nullptr); diff --git a/tensorflow/lite/delegates/flex/build_def.bzl b/tensorflow/lite/delegates/flex/build_def.bzl index 5826e1f83cd..53854463627 100644 --- a/tensorflow/lite/delegates/flex/build_def.bzl +++ b/tensorflow/lite/delegates/flex/build_def.bzl @@ -20,6 +20,7 @@ load( "tflite_jni_linkopts", ) load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("//tensorflow/lite:special_rules.bzl", "flex_portable_tensorflow_deps") def generate_flex_kernel_header( name, @@ -130,13 +131,7 @@ def tflite_flex_cc_library( clean_dep("//tensorflow/core/kernels:android_all_ops_textual_hdrs"), ], visibility = visibility, - deps = [ - "@com_google_absl//absl/strings:str_format", - "//third_party/fft2d:fft2d_headers", - "//third_party/eigen3", - "@com_google_absl//absl/types:optional", - "@gemmlowp", - "@icu//:common", + deps = flex_portable_tensorflow_deps() + [ clean_dep("//tensorflow/core:protos_all_cc"), clean_dep("//tensorflow/core:portable_tensorflow_lib_lite"), clean_dep("//tensorflow/core/platform:strong_hash"), diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD index 598a94e2160..f1eca5ac044 100644 --- a/tensorflow/lite/delegates/gpu/cl/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/BUILD @@ -52,7 +52,7 @@ cc_library( srcs = ["arguments.cc"], hdrs = ["arguments.h"], deps = [ - ":gpu_object", + ":gpu_object_desc", ":serialization_cc_fbs", "//tensorflow/lite/delegates/gpu/common:access_type", "//tensorflow/lite/delegates/gpu/common:data_type", @@ -371,10 +371,9 @@ cc_library( cc_library( name = "gpu_object", - srcs = ["gpu_object.cc"], hdrs = ["gpu_object.h"], deps = [ - ":cl_context", + ":gpu_object_desc", ":opencl_wrapper", ":serialization_cc_fbs", "//tensorflow/lite/delegates/gpu/common:access_type", @@ -383,6 +382,17 @@ cc_library( ], ) +cc_library( + name = "gpu_object_desc", + hdrs = ["gpu_object_desc.h"], + deps = [ + ":serialization_cc_fbs", + "//tensorflow/lite/delegates/gpu/common:access_type", + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:status", + ], +) + cc_library( name = "inference_context", srcs = [ @@ -435,6 +445,7 @@ cc_library( srcs = ["linear_storage.cc"], hdrs = ["linear_storage.h"], deps = [ + ":cl_context", ":gpu_object", ":opencl_wrapper", ":tensor_type", @@ -567,9 +578,11 @@ cc_library( srcs = ["tensor_type.cc"], hdrs = ["tensor_type.h"], deps = [ - ":gpu_object", + ":gpu_object_desc", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:tensor", + "//tensorflow/lite/delegates/gpu/common:util", "@com_google_absl//absl/strings", ], ) @@ -623,6 +636,7 @@ cc_library( srcs = ["util.cc"], hdrs = ["util.h"], deps = [ + ":gpu_object_desc", ":opencl_wrapper", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:status", diff --git a/tensorflow/lite/delegates/gpu/cl/arguments.h b/tensorflow/lite/delegates/gpu/cl/arguments.h index 3c4671212ec..66e2ce751b7 100644 --- a/tensorflow/lite/delegates/gpu/cl/arguments.h +++ b/tensorflow/lite/delegates/gpu/cl/arguments.h @@ -20,7 +20,7 @@ limitations under the License. #include <string> #include <vector> -#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h" +#include "tensorflow/lite/delegates/gpu/cl/gpu_object_desc.h" #include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h" #include "tensorflow/lite/delegates/gpu/common/access_type.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -36,8 +36,6 @@ class ArgumentsBinder { virtual absl::Status SetInt(const std::string& name, int value) = 0; virtual absl::Status SetFloat(const std::string& name, float value) = 0; virtual absl::Status SetHalf(const std::string& name, half value) = 0; - virtual absl::Status SetObjectRef(const std::string& name, - const GPUObject* object) = 0; virtual ~ArgumentsBinder() = default; }; diff --git a/tensorflow/lite/delegates/gpu/cl/cl_arguments.h b/tensorflow/lite/delegates/gpu/cl/cl_arguments.h index 5ce4b8dde23..4291771425e 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_arguments.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_arguments.h @@ -58,8 +58,7 @@ class CLArguments : public ArgumentsBinder { absl::Status SetInt(const std::string& name, int value) override; absl::Status SetFloat(const std::string& name, float value) override; absl::Status SetHalf(const std::string& name, half value) override; - absl::Status SetObjectRef(const std::string& name, - const GPUObject* object) override; + absl::Status SetObjectRef(const std::string& name, const GPUObject* object); absl::Status Bind(cl_kernel kernel, int offset = 0); diff --git a/tensorflow/lite/delegates/gpu/cl/device_info.cc b/tensorflow/lite/delegates/gpu/cl/device_info.cc index 43d050e8371..1e90bd2d673 100644 --- a/tensorflow/lite/delegates/gpu/cl/device_info.cc +++ b/tensorflow/lite/delegates/gpu/cl/device_info.cc @@ -150,7 +150,13 @@ int AdrenoInfo::GetRegisterMemorySizePerComputeUnit() const { } else if (gpu_version >= 500 && gpu_version < 600) { return -1; // Adreno 5xx does not support it currently } else if (gpu_version >= 600 && gpu_version < 700) { - return gpu_version == 640 ? 128 * 144 * 16 : 128 * 96 * 16; + if (gpu_version == 640) { + return 128 * 144 * 16; + } else if (gpu_version == 650) { + return 128 * 64 * 16; + } else { + return 128 * 96 * 16; + } } else { return -1; // Adreno 7xx and higher does not exist yet } diff --git a/tensorflow/lite/delegates/gpu/cl/gpu_object.h b/tensorflow/lite/delegates/gpu/cl/gpu_object.h index 414e14256bd..d4f86de8bf2 100644 --- a/tensorflow/lite/delegates/gpu/cl/gpu_object.h +++ b/tensorflow/lite/delegates/gpu/cl/gpu_object.h @@ -21,7 +21,7 @@ limitations under the License. #include <string> #include <vector> -#include "tensorflow/lite/delegates/gpu/cl/cl_context.h" +#include "tensorflow/lite/delegates/gpu/cl/gpu_object_desc.h" #include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h" #include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h" #include "tensorflow/lite/delegates/gpu/common/access_type.h" @@ -32,78 +32,6 @@ namespace tflite { namespace gpu { namespace cl { -struct GPUImage2DDescriptor { - DataType data_type; - AccessType access_type; -}; - -struct GPUImage3DDescriptor { - DataType data_type; - AccessType access_type; -}; - -struct GPUImage2DArrayDescriptor { - DataType data_type; - AccessType access_type; -}; - -struct GPUImageBufferDescriptor { - DataType data_type; - AccessType access_type; -}; - -struct GPUCustomMemoryDescriptor { - std::string type_name; -}; - -enum class MemoryType { GLOBAL, CONSTANT, LOCAL }; - -std::string MemoryTypeToCLType(MemoryType type); - -struct GPUBufferDescriptor { - DataType data_type; - AccessType access_type; - int element_size; - MemoryType memory_type = MemoryType::GLOBAL; - std::vector<std::string> attributes; -}; - -struct GPUResources { - std::vector<std::string> ints; - std::vector<std::string> floats; - std::vector<std::pair<std::string, GPUBufferDescriptor>> buffers; - std::vector<std::pair<std::string, GPUImage2DDescriptor>> images2d; - std::vector<std::pair<std::string, GPUImage2DArrayDescriptor>> image2d_arrays; - std::vector<std::pair<std::string, GPUImage3DDescriptor>> images3d; - std::vector<std::pair<std::string, GPUImageBufferDescriptor>> image_buffers; - std::vector<std::pair<std::string, GPUCustomMemoryDescriptor>> - custom_memories; - - std::vector<std::string> GetNames() const { - std::vector<std::string> names = ints; - names.insert(names.end(), floats.begin(), floats.end()); - for (const auto& obj : buffers) { - names.push_back(obj.first); - } - for (const auto& obj : images2d) { - names.push_back(obj.first); - } - for (const auto& obj : image2d_arrays) { - names.push_back(obj.first); - } - for (const auto& obj : images3d) { - names.push_back(obj.first); - } - for (const auto& obj : image_buffers) { - names.push_back(obj.first); - } - for (const auto& obj : custom_memories) { - names.push_back(obj.first); - } - return names; - } -}; - struct GPUResourcesWithValue { std::vector<std::pair<std::string, int>> ints; std::vector<std::pair<std::string, float>> floats; @@ -115,56 +43,6 @@ struct GPUResourcesWithValue { std::vector<std::pair<std::string, cl_mem>> custom_memories; }; -class GPUObject; - -class GPUObjectDescriptor { - public: - GPUObjectDescriptor() = default; - GPUObjectDescriptor(const GPUObjectDescriptor&) = default; - GPUObjectDescriptor& operator=(const GPUObjectDescriptor&) = default; - GPUObjectDescriptor(GPUObjectDescriptor&& obj_desc) - : state_vars_(std::move(obj_desc.state_vars_)) {} - GPUObjectDescriptor& operator=(GPUObjectDescriptor&& obj_desc) { - if (this != &obj_desc) { - state_vars_ = std::move(obj_desc.state_vars_); - } - return *this; - } - virtual ~GPUObjectDescriptor() = default; - - void SetStateVar(const std::string& key, const std::string& value) const { - state_vars_[key] = value; - } - - virtual std::string PerformConstExpr(const std::string& const_expr) const { - return ""; - } - - virtual absl::Status PerformSelector( - const std::string& selector, const std::vector<std::string>& args, - const std::vector<std::string>& template_args, - std::string* result) const { - *result = ""; - return absl::OkStatus(); - } - virtual GPUResources GetGPUResources() const { return GPUResources(); } - - virtual void Release() {} - - void SetAccess(AccessType access_type) { access_type_ = access_type; } - AccessType GetAccess() const { return access_type_; } - - protected: - friend flatbuffers::Offset<data::GPUObjectDescriptor> Encode( - const GPUObjectDescriptor& desc, flatbuffers::FlatBufferBuilder* builder); - friend void Decode(const data::GPUObjectDescriptor* fb_obj, - GPUObjectDescriptor* obj); - mutable std::map<std::string, std::string> state_vars_; - AccessType access_type_; -}; - -using GPUObjectDescriptorPtr = std::unique_ptr<GPUObjectDescriptor>; - class GPUObject { public: GPUObject() = default; diff --git a/tensorflow/lite/delegates/gpu/cl/gpu_object_desc.h b/tensorflow/lite/delegates/gpu/cl/gpu_object_desc.h new file mode 100644 index 00000000000..e620487b06b --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/gpu_object_desc.h @@ -0,0 +1,149 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_GPU_OBJECT_DESC_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_CL_GPU_OBJECT_DESC_H_ + +#include <map> +#include <memory> +#include <string> +#include <vector> + +#include "tensorflow/lite/delegates/gpu/cl/serialization_generated.h" +#include "tensorflow/lite/delegates/gpu/common/access_type.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { +namespace cl { + +struct GPUImage2DDescriptor { + DataType data_type; + AccessType access_type; +}; + +struct GPUImage3DDescriptor { + DataType data_type; + AccessType access_type; +}; + +struct GPUImage2DArrayDescriptor { + DataType data_type; + AccessType access_type; +}; + +struct GPUImageBufferDescriptor { + DataType data_type; + AccessType access_type; +}; + +struct GPUCustomMemoryDescriptor { + std::string type_name; +}; + +enum class MemoryType { GLOBAL, CONSTANT, LOCAL }; + +struct GPUBufferDescriptor { + DataType data_type; + AccessType access_type; + int element_size; + MemoryType memory_type = MemoryType::GLOBAL; + std::vector<std::string> attributes; +}; + +struct GPUResources { + std::vector<std::string> ints; + std::vector<std::string> floats; + std::vector<std::pair<std::string, GPUBufferDescriptor>> buffers; + std::vector<std::pair<std::string, GPUImage2DDescriptor>> images2d; + std::vector<std::pair<std::string, GPUImage2DArrayDescriptor>> image2d_arrays; + std::vector<std::pair<std::string, GPUImage3DDescriptor>> images3d; + std::vector<std::pair<std::string, GPUImageBufferDescriptor>> image_buffers; + std::vector<std::pair<std::string, GPUCustomMemoryDescriptor>> + custom_memories; + + std::vector<std::string> GetNames() const { + std::vector<std::string> names = ints; + names.insert(names.end(), floats.begin(), floats.end()); + for (const auto& obj : buffers) { + names.push_back(obj.first); + } + for (const auto& obj : images2d) { + names.push_back(obj.first); + } + for (const auto& obj : image2d_arrays) { + names.push_back(obj.first); + } + for (const auto& obj : images3d) { + names.push_back(obj.first); + } + for (const auto& obj : image_buffers) { + names.push_back(obj.first); + } + for (const auto& obj : custom_memories) { + names.push_back(obj.first); + } + return names; + } +}; + +class GPUObjectDescriptor { + public: + GPUObjectDescriptor() = default; + GPUObjectDescriptor(const GPUObjectDescriptor&) = default; + GPUObjectDescriptor& operator=(const GPUObjectDescriptor&) = default; + GPUObjectDescriptor(GPUObjectDescriptor&& obj_desc) = default; + GPUObjectDescriptor& operator=(GPUObjectDescriptor&& obj_desc) = default; + virtual ~GPUObjectDescriptor() = default; + + void SetStateVar(const std::string& key, const std::string& value) const { + state_vars_[key] = value; + } + + virtual std::string PerformConstExpr(const std::string& const_expr) const { + return ""; + } + + virtual absl::Status PerformSelector( + const std::string& selector, const std::vector<std::string>& args, + const std::vector<std::string>& template_args, + std::string* result) const { + *result = ""; + return absl::OkStatus(); + } + virtual GPUResources GetGPUResources() const { return GPUResources(); } + + virtual void Release() {} + + void SetAccess(AccessType access_type) { access_type_ = access_type; } + AccessType GetAccess() const { return access_type_; } + + protected: + friend flatbuffers::Offset<data::GPUObjectDescriptor> Encode( + const GPUObjectDescriptor& desc, flatbuffers::FlatBufferBuilder* builder); + friend void Decode(const data::GPUObjectDescriptor* fb_obj, + GPUObjectDescriptor* obj); + mutable std::map<std::string, std::string> state_vars_; + AccessType access_type_; +}; + +using GPUObjectDescriptorPtr = std::unique_ptr<GPUObjectDescriptor>; + +} // namespace cl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_GPU_OBJECT_DESC_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc index 0f389361724..b92846cb794 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc @@ -80,10 +80,6 @@ std::string ConvolutionTransposed4x4::GenerateConvolutionTransposedCode( args_.AddInt("filter_offset"); - const auto src_tensor_type = op_def.src_tensors[0].storage_type; - const bool manual_clamp = src_tensor_type == TensorStorageType::BUFFER || - src_tensor_type == TensorStorageType::IMAGE_BUFFER; - const bool need_local_mem = weights_upload_type == ConvolutionTransposed4x4::WeightsUploadType::LOCAL_MEM_BY_THREADS || @@ -150,24 +146,42 @@ std::string ConvolutionTransposed4x4::GenerateConvolutionTransposedCode( ConvolutionTransposed4x4::WeightsUploadType::LOCAL_MEM_BY_THREADS) { c += " int local_id = (int)(get_local_id(1) * 8 + get_local_id(0));\n"; } - if (manual_clamp) { - const std::string prev_x = "X - " + pixel_stride; + const std::string prev_x = "X - " + pixel_stride; + if (!src_desc.SupportsZeroClamp(Axis::WIDTH)) { c += " bool in_x0 = " + prev_x + " >= 0 && " + prev_x + " < args.src_tensor.Width();\n"; c += " bool in_x1 = X >= 0 && X < args.src_tensor.Width();\n"; + } + if (!src_desc.SupportsZeroClamp(Axis::HEIGHT)) { c += " bool in_y0 = Y - 1 >= 0 && Y - 1 < args.src_tensor.Height();\n"; c += " bool in_y1 = Y >= 0 && Y < args.src_tensor.Height();\n"; - if (src_tensor_type == TensorStorageType::IMAGE_BUFFER) { - c += " int addr_0 = select(-1, (Y - 1) * args.src_tensor.Width() + " + - prev_x + ", (in_x0 && in_y0));\n"; - c += " int addr_1 = select(-1, (Y - 1) * args.src_tensor.Width() + X, " - "(in_x1 && " - "in_y0));\n"; - c += " int addr_2 = select(-1, Y * args.src_tensor.Width() + " + prev_x + - ", (in_x0 && in_y1));\n"; - c += " int addr_3 = select(-1, Y * args.src_tensor.Width() + X, (in_x1 " - "&& " - "in_y1));\n"; + } + auto generate_check = [&](int x, int y) { + std::string check; + const std::vector<Axis> axes{Axis::WIDTH, Axis::HEIGHT}; + const std::vector<std::string> names{"in_x" + std::to_string(x), + "in_y" + std::to_string(y)}; + for (int i = 0; i < axes.size(); ++i) { + const auto& axis = axes[i]; + if (src_desc.HasAxis(axis) && !src_desc.SupportsZeroClamp(axis)) { + if (!check.empty()) { + check += " && "; + } + check += names[i]; + } + } + return check; + }; + if (src_desc.IsLinear()) { + if (src_desc.ReturnsZeroForNegOneRead()) { + c += " args.src_tensor.GetAddress(addr_0, " + prev_x + ", Y - 1, 0);\n"; + c += " args.src_tensor.GetAddress(addr_1, X, Y - 1, 0);\n"; + c += " args.src_tensor.GetAddress(addr_2, " + prev_x + ", Y, 0);\n"; + c += " args.src_tensor.GetAddress(addr_3, X, Y, 0);\n"; + c += " addr_0 = select(-1, addr_0, (in_x0 && in_y0));\n"; + c += " addr_1 = select(-1, addr_1, (in_x1 && in_y0));\n"; + c += " addr_2 = select(-1, addr_2, (in_x0 && in_y1));\n"; + c += " addr_3 = select(-1, addr_3, (in_x1 && in_y1));\n"; c += " int dz_0 = select(0, args.src_tensor.SliceStride(), (in_x0 && " "in_y0));\n"; c += " int dz_1 = select(0, args.src_tensor.SliceStride(), (in_x1 && " @@ -176,25 +190,24 @@ std::string ConvolutionTransposed4x4::GenerateConvolutionTransposedCode( "in_y1));\n"; c += " int dz_3 = select(0, args.src_tensor.SliceStride(), (in_x1 && " "in_y1));\n"; - } - if (src_tensor_type == TensorStorageType::BUFFER) { + } else { c += " int xc0 = clamp(" + prev_x + ", 0, args.src_tensor.Width() - 1);\n"; c += " int xc1 = clamp(X, 0, args.src_tensor.Width() - 1);\n"; c += " int yc0 = clamp(Y - 1, 0, args.src_tensor.Height() - 1);\n"; c += " int yc1 = clamp(Y, 0, args.src_tensor.Height() - 1);\n"; - c += " int addr_0 = yc0 * args.src_tensor.Width() + xc0;\n"; - c += " int addr_1 = yc0 * args.src_tensor.Width() + xc1;\n"; - c += " int addr_2 = yc1 * args.src_tensor.Width() + xc0;\n"; - c += " int addr_3 = yc1 * args.src_tensor.Width() + xc1;\n"; + c += " args.src_tensor.GetAddress(addr_0, xc0, yc0, 0);\n"; + c += " args.src_tensor.GetAddress(addr_1, xc1, yc0, 0);\n"; + c += " args.src_tensor.GetAddress(addr_2, xc0, yc1, 0);\n"; + c += " args.src_tensor.GetAddress(addr_3, xc1, yc1, 0);\n"; c += " int dz = args.src_tensor.SliceStride();\n"; } } auto read_src = [&](int x, int y) { - if (manual_clamp) { + if (src_desc.IsLinear()) { const std::string id = std::to_string(y * 2 + x); const std::string addr = "addr_" + std::to_string(y * 2 + x); - if (src_tensor_type == TensorStorageType::IMAGE_BUFFER) { + if (src_desc.ReturnsZeroForNegOneRead()) { return "args.src_tensor.Read(" + addr + "); " + addr + " += dz_" + id + ";"; } else { @@ -203,8 +216,13 @@ std::string ConvolutionTransposed4x4::GenerateConvolutionTransposedCode( addr + " += dz;"; } } else { + std::string check = generate_check(x, y); + if (!check.empty()) { + check = " * (FLT)(" + check + ")"; + } return "args.src_tensor.Read(X + " + std::to_string(x - 1) + " * " + - pixel_stride + ", Y + " + std::to_string(y - 1) + ", s);"; + pixel_stride + ", Y + " + std::to_string(y - 1) + ", s)" + check + + ";"; } }; c += " for (int s = 0; s < args.src_tensor.Slices(); ++s) {\n"; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.cc b/tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.cc index f451d09d32d..52f197e4ca6 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.cc @@ -128,11 +128,6 @@ std::string GenerateCode(const OperationDef& op_def, result->args_.AddInt("padding_y", -dw_attr.padding.prepended.h); result->args_.AddInt("dilation_y", dw_attr.dilations.h); - const auto src_tensor_type = op_def.src_tensors[0].storage_type; - - const bool manual_clamp = src_tensor_type == TensorStorageType::BUFFER || - src_tensor_type == TensorStorageType::IMAGE_BUFFER; - std::string c = GetCommonDefines(op_def.precision); c += "__kernel void main_function(\n"; c += "$0) {\n"; @@ -160,29 +155,54 @@ std::string GenerateCode(const OperationDef& op_def, c += " int x_offseted = X * args.stride_x + args.padding_x;\n"; c += " int y_offseted = Y * args.stride_y + args.padding_y;\n"; c += " int x_c, y_c;\n"; - if (manual_clamp) { - c += " bool x_in, y_in;\n"; + + auto generate_check = [&]() { + std::string check; + const std::vector<Axis> axes{Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH}; + const std::vector<std::string> names{"x_in", "y_in", "z_in"}; + for (int i = 0; i < axes.size(); ++i) { + const auto& axis = axes[i]; + if (src_desc.HasAxis(axis) && !src_desc.SupportsZeroClamp(axis)) { + if (!check.empty()) { + check += " && "; + } + check += names[i]; + } + } + return check; + }; + const std::string check = generate_check(); + if (!src_desc.SupportsZeroClamp(Axis::HEIGHT)) { + c += " bool y_in;\n"; } + if (!src_desc.SupportsZeroClamp(Axis::WIDTH)) { + c += " bool x_in;\n"; + } + + const std::string postfixes[] = {".x", ".xy", ".xyz", ""}; c += " FLT4 src;\n"; for (int ky = 0; ky < dw_attr.weights.shape.h; ++ky) { c += " y_c = y_offseted + " + std::to_string(ky) + " * args.dilation_y;\n"; - if (manual_clamp) { + if (!src_desc.SupportsZeroClamp(Axis::HEIGHT)) { c += " y_in = y_c >= 0 && y_c < args.src_tensor.Height();\n"; c += " y_c = clamp(y_c, 0, args.src_tensor.Height() - 1);\n"; } for (int kx = 0; kx < dw_attr.weights.shape.w; ++kx) { c += " x_c = x_offseted + " + std::to_string(kx) + " * args.dilation_x;\n"; - if (manual_clamp) { + if (!src_desc.SupportsZeroClamp(Axis::WIDTH)) { c += " x_in = x_c >= 0 && x_c < args.src_tensor.Width();\n"; c += " x_c = clamp(x_c, 0, args.src_tensor.Width() - 1);\n"; } for (int d = 0; d < intermediate_depth; ++d) { - std::string multiplier = manual_clamp ? "* (FLT)(x_in && y_in)" : ""; - c += " src = args.src_tensor.Read(x_c, y_c, " + std::to_string(d) + - ")" + multiplier + ";\n"; - c += " dw_res_" + std::to_string(d) + " += src * constants[" + - std::to_string(weights_counter++) + "];\n"; + const int src_ch_count = std::min(4, dw_attr.weights.shape.i - d * 4); + const std::string s_postfix = postfixes[src_ch_count - 1]; + std::string multiplier = check.empty() ? "" : " * (FLT)(" + check + ")"; + c += " src" + s_postfix + " = args.src_tensor.Read(x_c, y_c, " + + std::to_string(d) + ")" + s_postfix + multiplier + ";\n"; + c += " dw_res_" + std::to_string(d) + s_postfix + " += src" + + s_postfix + " * constants[" + std::to_string(weights_counter++) + + "]" + s_postfix + ";\n"; } } } diff --git a/tensorflow/lite/delegates/gpu/cl/linear_storage.h b/tensorflow/lite/delegates/gpu/cl/linear_storage.h index 1f68e9928c2..dcd947e9e08 100644 --- a/tensorflow/lite/delegates/gpu/cl/linear_storage.h +++ b/tensorflow/lite/delegates/gpu/cl/linear_storage.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "tensorflow/lite/delegates/gpu/cl/cl_context.h" #include "tensorflow/lite/delegates/gpu/cl/gpu_object.h" #include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h" #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" diff --git a/tensorflow/lite/delegates/gpu/cl/tensor_type.cc b/tensorflow/lite/delegates/gpu/cl/tensor_type.cc index f31df43539e..d297e7cc53d 100644 --- a/tensorflow/lite/delegates/gpu/cl/tensor_type.cc +++ b/tensorflow/lite/delegates/gpu/cl/tensor_type.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/substitute.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/cl/tensor_type.h b/tensorflow/lite/delegates/gpu/cl/tensor_type.h index 5a59c3b0d96..d14d8ab5a8d 100644 --- a/tensorflow/lite/delegates/gpu/cl/tensor_type.h +++ b/tensorflow/lite/delegates/gpu/cl/tensor_type.h @@ -19,9 +19,10 @@ limitations under the License. #include <cstddef> #include <string> -#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h" +#include "tensorflow/lite/delegates/gpu/cl/gpu_object_desc.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/cl/util.cc b/tensorflow/lite/delegates/gpu/cl/util.cc index d0e65537519..901e2c30a65 100644 --- a/tensorflow/lite/delegates/gpu/cl/util.cc +++ b/tensorflow/lite/delegates/gpu/cl/util.cc @@ -241,6 +241,19 @@ absl::Status CreateRGBAImage2D(cl_context context, int width, int height, return absl::OkStatus(); } +std::string MemoryTypeToCLType(MemoryType type) { + switch (type) { + case MemoryType::GLOBAL: + return "__global"; + case MemoryType::CONSTANT: + return "__constant"; + break; + case MemoryType::LOCAL: + return "__local"; + } + return ""; +} + } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/util.h b/tensorflow/lite/delegates/gpu/cl/util.h index 54a6c74a3ff..2eb5cc33a23 100644 --- a/tensorflow/lite/delegates/gpu/cl/util.h +++ b/tensorflow/lite/delegates/gpu/cl/util.h @@ -19,6 +19,7 @@ limitations under the License. #include <string> #include "absl/types/span.h" +#include "tensorflow/lite/delegates/gpu/cl/gpu_object_desc.h" #include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -57,6 +58,8 @@ absl::Status CreateRGBAImage2D(cl_context context, int width, int height, cl_channel_type channel_type, void* data, cl_mem* result); +std::string MemoryTypeToCLType(MemoryType type); + } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/data_type.cc b/tensorflow/lite/delegates/gpu/common/data_type.cc index 05a61f86f29..f393c877cd4 100644 --- a/tensorflow/lite/delegates/gpu/common/data_type.cc +++ b/tensorflow/lite/delegates/gpu/common/data_type.cc @@ -105,5 +105,36 @@ std::string ToCLDataType(DataType data_type, int vec_size) { return "undefined"; } +std::string ToMetalDataType(DataType data_type, int vec_size) { + const std::string postfix = vec_size == 1 ? "" : std::to_string(vec_size); + switch (data_type) { + case DataType::FLOAT16: + return "half" + postfix; + case DataType::FLOAT32: + return "float" + postfix; + case DataType::FLOAT64: + return "double" + postfix; + case DataType::INT16: + return "short" + postfix; + case DataType::INT32: + return "int" + postfix; + case DataType::INT64: + return "long" + postfix; + case DataType::INT8: + return "char" + postfix; + case DataType::UINT16: + return "ushort" + postfix; + case DataType::UINT32: + return "uint" + postfix; + case DataType::UINT64: + return "ulong" + postfix; + case DataType::UINT8: + return "uchar" + postfix; + case DataType::UNKNOWN: + return "unknown"; + } + return "undefined"; +} + } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/data_type.h b/tensorflow/lite/delegates/gpu/common/data_type.h index 82d55ec9d4e..8ad3d635dd7 100644 --- a/tensorflow/lite/delegates/gpu/common/data_type.h +++ b/tensorflow/lite/delegates/gpu/common/data_type.h @@ -43,6 +43,8 @@ std::string ToString(DataType t); std::string ToCLDataType(DataType data_type, int vec_size = 1); +std::string ToMetalDataType(DataType data_type, int vec_size = 1); + } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/BUILD b/tensorflow/lite/delegates/gpu/metal/BUILD index 9f694b55cdb..cfefe53abb9 100644 --- a/tensorflow/lite/delegates/gpu/metal/BUILD +++ b/tensorflow/lite/delegates/gpu/metal/BUILD @@ -43,6 +43,7 @@ cc_library( srcs = ["arguments.cc"], hdrs = ["arguments.h"], deps = [ + ":gpu_object_desc", "//tensorflow/lite/delegates/gpu/common:status", ], ) @@ -183,6 +184,17 @@ objc_library( ], ) +cc_library( + name = "gpu_object_desc", + srcs = ["gpu_object_desc.cc"], + hdrs = ["gpu_object_desc.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:access_type", + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:status", + ], +) + objc_library( name = "inference_context", srcs = ["inference_context.mm"], @@ -232,6 +244,7 @@ objc_library( sdk_frameworks = ["Metal"], deps = [ ":arguments", + ":gpu_object_desc", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:util", "@com_google_absl//absl/strings", diff --git a/tensorflow/lite/delegates/gpu/metal/arguments.cc b/tensorflow/lite/delegates/gpu/metal/arguments.cc index f4a308e59be..d9c9c3cc22b 100644 --- a/tensorflow/lite/delegates/gpu/metal/arguments.cc +++ b/tensorflow/lite/delegates/gpu/metal/arguments.cc @@ -50,6 +50,12 @@ void Arguments::AddInt(const std::string& name, int value) { int_values_[name].value = value; } +void Arguments::AddObject(const std::string& name, + GPUObjectDescriptorPtr&& descriptor_ptr) { + descriptor_ptr->SetAccess(AccessType::READ); + objects_[name] = {std::move(descriptor_ptr)}; +} + void Arguments::GetActiveArguments(const std::string& code) { for (auto& float_val : float_values_) { float_val.second.active = HasWord(kArgsPrefix + float_val.first, code); @@ -59,6 +65,12 @@ void Arguments::GetActiveArguments(const std::string& code) { } } +void Arguments::ReleaseCPURepresentation() { + for (auto& t : objects_) { + t.second->Release(); + } +} + } // namespace metal } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/arguments.h b/tensorflow/lite/delegates/gpu/metal/arguments.h index fbdcfef1358..47eca6dc783 100644 --- a/tensorflow/lite/delegates/gpu/metal/arguments.h +++ b/tensorflow/lite/delegates/gpu/metal/arguments.h @@ -19,6 +19,7 @@ limitations under the License. #include <string> #include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/metal/gpu_object_desc.h" namespace tflite { namespace gpu { @@ -36,6 +37,10 @@ class Arguments { void AddFloat(const std::string& name, float value = 0.0f); void AddInt(const std::string& name, int value = 0); + void AddObject(const std::string& name, + GPUObjectDescriptorPtr&& descriptor_ptr); + + void ReleaseCPURepresentation(); private: friend class MetalArguments; @@ -61,6 +66,8 @@ class Arguments { bool active = false; }; std::map<std::string, FloatValue> float_values_; + + std::map<std::string, GPUObjectDescriptorPtr> objects_; }; class ArgumentsSetter { diff --git a/tensorflow/lite/delegates/gpu/cl/gpu_object.cc b/tensorflow/lite/delegates/gpu/metal/gpu_object_desc.cc similarity index 79% rename from tensorflow/lite/delegates/gpu/cl/gpu_object.cc rename to tensorflow/lite/delegates/gpu/metal/gpu_object_desc.cc index 277d711ff63..09fd19aa355 100644 --- a/tensorflow/lite/delegates/gpu/cl/gpu_object.cc +++ b/tensorflow/lite/delegates/gpu/metal/gpu_object_desc.cc @@ -13,25 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h" +#include "tensorflow/lite/delegates/gpu/metal/gpu_object_desc.h" namespace tflite { namespace gpu { -namespace cl { +namespace metal { -std::string MemoryTypeToCLType(MemoryType type) { +std::string MemoryTypeToMetalType(MemoryType type) { switch (type) { case MemoryType::GLOBAL: - return "__global"; + return "device"; case MemoryType::CONSTANT: - return "__constant"; + return "constant"; break; case MemoryType::LOCAL: - return "__local"; + return "threadgroup"; } return ""; } -} // namespace cl +} // namespace metal } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/gpu_object_desc.h b/tensorflow/lite/delegates/gpu/metal/gpu_object_desc.h new file mode 100644 index 00000000000..e33bbcde703 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/gpu_object_desc.h @@ -0,0 +1,102 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_GPU_OBJECT_DESC_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_GPU_OBJECT_DESC_H_ + +#include <map> +#include <memory> +#include <string> +#include <vector> + +#include "tensorflow/lite/delegates/gpu/common/access_type.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { +namespace metal { + +enum class MemoryType { GLOBAL, CONSTANT, LOCAL }; + +std::string MemoryTypeToMetalType(MemoryType type); + +struct GPUBufferDescriptor { + DataType data_type; + AccessType access_type; + int element_size; + MemoryType memory_type = MemoryType::GLOBAL; + std::vector<std::string> attributes; +}; + +struct GPUResources { + std::vector<std::string> ints; + std::vector<std::string> floats; + std::vector<std::pair<std::string, GPUBufferDescriptor>> buffers; + + std::vector<std::string> GetNames() const { + std::vector<std::string> names = ints; + names.insert(names.end(), floats.begin(), floats.end()); + for (const auto& obj : buffers) { + names.push_back(obj.first); + } + return names; + } +}; + +class GPUObjectDescriptor { + public: + GPUObjectDescriptor() = default; + GPUObjectDescriptor(const GPUObjectDescriptor&) = default; + GPUObjectDescriptor& operator=(const GPUObjectDescriptor&) = default; + GPUObjectDescriptor(GPUObjectDescriptor&& obj_desc) = default; + GPUObjectDescriptor& operator=(GPUObjectDescriptor&& obj_desc) = default; + + virtual ~GPUObjectDescriptor() = default; + + void SetStateVar(const std::string& key, const std::string& value) const { + state_vars_[key] = value; + } + + virtual std::string PerformConstExpr(const std::string& const_expr) const { + return ""; + } + + virtual absl::Status PerformSelector( + const std::string& selector, const std::vector<std::string>& args, + const std::vector<std::string>& template_args, + std::string* result) const { + *result = ""; + return absl::OkStatus(); + } + virtual GPUResources GetGPUResources() const { return GPUResources(); } + + virtual void Release() {} + + void SetAccess(AccessType access_type) { access_type_ = access_type; } + AccessType GetAccess() const { return access_type_; } + + protected: + mutable std::map<std::string, std::string> state_vars_; + AccessType access_type_; +}; + +using GPUObjectDescriptorPtr = std::unique_ptr<GPUObjectDescriptor>; + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_GPU_OBJECT_DESC_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD index c877d4eeb5c..a79627125ec 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD @@ -370,6 +370,7 @@ ios_unit_test( minimum_os_version = "11.0", runner = tflite_ios_lab_runner("IOS_LATEST"), tags = [ + "no_mac", # TODO(b/171882133) "notap", "tflite_not_portable_android", ], diff --git a/tensorflow/lite/delegates/gpu/metal/metal_arguments.h b/tensorflow/lite/delegates/gpu/metal/metal_arguments.h index 496287c8ff0..df66bd915bb 100644 --- a/tensorflow/lite/delegates/gpu/metal/metal_arguments.h +++ b/tensorflow/lite/delegates/gpu/metal/metal_arguments.h @@ -23,11 +23,18 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/metal/arguments.h" +#include "tensorflow/lite/delegates/gpu/metal/gpu_object_desc.h" namespace tflite { namespace gpu { namespace metal { +struct GPUResourcesWithValue { + std::vector<std::pair<std::string, int>> ints; + std::vector<std::pair<std::string, float>> floats; + std::vector<std::pair<std::string, id<MTLBuffer>>> buffers; +}; + class MetalArguments : public ArgumentsSetter { public: MetalArguments() = default; @@ -46,6 +53,15 @@ class MetalArguments : public ArgumentsSetter { void Encode(id<MTLComputeCommandEncoder> encoder, int buffer_offset) const; private: + std::string GetListOfArgs(int buffer_offset); + + absl::Status SetGPUResources(const std::string& name, + const GPUResourcesWithValue& resources); + + void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc); + + absl::Status SetBuffer(const std::string& name, id<MTLBuffer> handle); + static constexpr char kArgsPrefix[] = "args."; struct IntValue { int value; @@ -71,6 +87,12 @@ class MetalArguments : public ArgumentsSetter { }; std::map<std::string, FloatValue> float_values_; std::vector<uint8_t> const_data_; + + struct MetalBufferDescriptor { + GPUBufferDescriptor desc; + id<MTLBuffer> handle; + }; + std::map<std::string, MetalBufferDescriptor> buffers_; }; } // namespace metal diff --git a/tensorflow/lite/delegates/gpu/metal/metal_arguments.mm b/tensorflow/lite/delegates/gpu/metal/metal_arguments.mm index 0e9be1b1aeb..b82a5613c88 100644 --- a/tensorflow/lite/delegates/gpu/metal/metal_arguments.mm +++ b/tensorflow/lite/delegates/gpu/metal/metal_arguments.mm @@ -43,6 +43,13 @@ void ReplaceAllWords(const std::string& old_word, const std::string& new_word, position = str->find(old_word, position + new_word.size()); } } + +void AppendArgument(const std::string& arg, std::string* args) { + if (!args->empty()) { + absl::StrAppend(args, ",\n"); + } + absl::StrAppend(args, arg); +} } // namespace // Static @@ -51,7 +58,6 @@ constexpr char MetalArguments::kArgsPrefix[]; absl::Status MetalArguments::Init(int buffer_offset, Arguments* args, std::string* code) { args->GetActiveArguments(*code); std::string struct_desc = "struct uniforms_buffer {\n"; - std::string struct_decl; int pos = 0; for (auto& fvalue : args->float_values_) { auto& new_val = float_values_[fvalue.first]; @@ -76,7 +82,6 @@ absl::Status MetalArguments::Init(int buffer_offset, Arguments* args, std::strin } } if (pos != 0) { - struct_decl = "constant uniforms_buffer& U[[buffer(" + std::to_string(buffer_offset) + ")]],\n"; int aligned_pos = AlignByN(pos, 4); for (int i = pos; i < aligned_pos; i++) { struct_desc += " int dummy" + std::to_string(i - pos) + ";\n"; @@ -97,9 +102,8 @@ absl::Status MetalArguments::Init(int buffer_offset, Arguments* args, std::strin } } else { struct_desc = ""; - struct_decl = ""; } - *code = absl::Substitute(*code, struct_desc, struct_decl); + *code = absl::Substitute(*code, struct_desc, GetListOfArgs(buffer_offset)); return absl::OkStatus(); } @@ -136,6 +140,61 @@ void MetalArguments::Encode(id<MTLComputeCommandEncoder> encoder, int buffer_off } } +std::string MetalArguments::GetListOfArgs(int buffer_offset) { + std::string result; + for (auto& t : buffers_) { + std::string attributes; + for (const auto& attr : t.second.desc.attributes) { + attributes += absl::StrCat(" __attribute__((", attr, "))"); + } + AppendArgument( + absl::StrCat( + MemoryTypeToMetalType(t.second.desc.memory_type), " ", + ToMetalDataType(t.second.desc.data_type, t.second.desc.element_size), + "* ", t.first, "[[buffer(", buffer_offset, ")]]", attributes), + &result); + buffer_offset++; + } + if (!const_data_.empty()) { + AppendArgument( + absl::StrCat("constant uniforms_buffer& U[[buffer(", buffer_offset, ")]]"), + &result); + buffer_offset++; + } + if (!result.empty()) { + result += ",\n"; + } + return result; +} + +absl::Status MetalArguments::SetGPUResources( + const std::string& name, const GPUResourcesWithValue& resources) { + for (const auto& r : resources.ints) { + RETURN_IF_ERROR(SetInt(absl::StrCat(name, "_", r.first), r.second)); + } + for (const auto& r : resources.floats) { + RETURN_IF_ERROR(SetFloat(absl::StrCat(name, "_", r.first), r.second)); + } + for (const auto& r : resources.buffers) { + RETURN_IF_ERROR(SetBuffer(absl::StrCat(name, "_", r.first), r.second)); + } + return absl::OkStatus(); +} + +void MetalArguments::AddBuffer(const std::string& name, const GPUBufferDescriptor& desc) { + buffers_[name].desc = desc; +} + +absl::Status MetalArguments::SetBuffer(const std::string& name, id<MTLBuffer> handle) { + auto it = buffers_.find(name); + if (it == buffers_.end()) { + return absl::NotFoundError( + absl::StrCat("No buffer argument with name - ", name)); + } + it->second.handle = handle; + return absl::OkStatus(); +} + } // namespace metal } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/nnapi/BUILD b/tensorflow/lite/delegates/nnapi/BUILD index 7a34b0846f2..b3c6eff4741 100644 --- a/tensorflow/lite/delegates/nnapi/BUILD +++ b/tensorflow/lite/delegates/nnapi/BUILD @@ -189,6 +189,30 @@ cc_test( ], ) +cc_test( + name = "nnapi_delegate_nnapi_failure_handling_test", + size = "small", + srcs = [ + "nnapi_delegate_nnapi_failure_handling_test.cc", + ], + tags = [ + "no_mac", + "no_windows", + "tflite_not_portable_ios", + ], + deps = [ + ":nnapi_delegate", + ":nnapi_delegate_mock_test", + "//tensorflow/lite:framework", + "//tensorflow/lite:kernel_api", + "//tensorflow/lite/c:common", + "//tensorflow/lite/kernels:test_util", + "//tensorflow/lite/nnapi:nnapi_implementation", + "//tensorflow/lite/nnapi:nnapi_lib", + "@com_google_googletest//:gtest", + ], +) + cc_test( name = "nnapi_delegate_signed_quantization_test", size = "small", diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc index 913e35cb9d9..ef384830d0a 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc @@ -4421,7 +4421,8 @@ TfLiteStatus NNAPIDelegateKernel::AddOpsAndTensors( AddDequantizeOperatorsWhereNeeded(context, reg->builtin_code, node, node_index, &builder, nnapi_errno); - builder.FinalizeAddOperation(nn_op_type, node_index); + TF_LITE_ENSURE_OK(context_, + builder.FinalizeAddOperation(nn_op_type, node_index)); } return kTfLiteOk; } diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_nnapi_failure_handling_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_nnapi_failure_handling_test.cc new file mode 100644 index 00000000000..3f3d6229290 --- /dev/null +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_nnapi_failure_handling_test.cc @@ -0,0 +1,142 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <sys/mman.h> + +#include <algorithm> +#include <array> +#include <cstdint> +#include <iterator> +#include <memory> +#include <numeric> +#include <ostream> +#include <unordered_set> +#include <vector> + +#include <gtest/gtest.h> +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" +#include "tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/nnapi/NeuralNetworksTypes.h" +#include "tensorflow/lite/nnapi/nnapi_implementation.h" + +namespace tflite { +namespace { + +struct NnApiFailureHandlingTest + : ::tflite::delegate::nnapi::NnApiDelegateMockTest {}; + +// This is a model with two ops: +// +// input1 ----> +// ADD -- +// input2 --> | +// --> +// SUB --> output +// input3 ----------------> +// +class AddSubOpsAcceleratedModel : public MultiOpModel { + public: + AddSubOpsAcceleratedModel(const TensorData& input1, const TensorData& input2, + const TensorData& input3, const TensorData& output, + ActivationFunctionType activation_type, + const NnApi* nnapi, + const std::string& accelerator_name, + bool allow_fp32_relax_to_fp16 = false) + : MultiOpModel() { + StatefulNnApiDelegate::Options options; + options.accelerator_name = accelerator_name.c_str(); + stateful_delegate_.reset(new StatefulNnApiDelegate(nnapi, options)); + SetDelegate(stateful_delegate_.get()); + Init(input1, input2, input3, output, activation_type, + allow_fp32_relax_to_fp16); + } + + int input1() { return input1_; } + int input2() { return input2_; } + int input3() { return input3_; } + + std::vector<float> GetOutput() { return ExtractVector<float>(output_); } + + protected: + int input1_; + int input2_; + int input3_; + int output_; + + private: + std::unique_ptr<StatefulNnApiDelegate> stateful_delegate_; + + // Performs initialization logic shared across all constructors. + void Init(const TensorData& input1, const TensorData& input2, + const TensorData& input3, const TensorData& output, + ActivationFunctionType activation_type, + bool allow_fp32_relax_to_fp16 = false) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + input3_ = AddInput(input3); + const int add_output = AddInnerTensor<float>(output); + output_ = AddOutput(output); + AddBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions, + CreateAddOptions(builder_, activation_type).Union(), + {input1_, input2_}, {add_output}); + AddBuiltinOp(BuiltinOperator_SUB, BuiltinOptions_SubOptions, + CreateSubOptions(builder_, activation_type).Union(), + {add_output, input3_}, {output_}); + BuildInterpreter({GetShape(input1_), GetShape(input2_), GetShape(input3_)}, + /*num_threads=*/-1, allow_fp32_relax_to_fp16, + /*apply_delegate=*/false); + ApplyDelegate(); + } +}; + +TEST_F(NnApiFailureHandlingTest, DelegateShouldFailImmediatelyIfUnableToAddOp) { + static int add_op_invocation_count = 0; + nnapi_mock_->SetNnapiSupportedDevice("test-device"); + + nnapi_mock_->StubAddOperationWith( + [](ANeuralNetworksModel* model, ANeuralNetworksOperationType type, + uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, + const uint32_t* outputs) -> int { + ++add_op_invocation_count; + return ANEURALNETWORKS_BAD_DATA; + }); + + AddSubOpsAcceleratedModel m( + {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {}}, + ActivationFunctionType_NONE, nnapi_mock_->GetNnApi(), + /*accelerator_name=*/"test-device"); + std::vector<float> input1{-2.0, 0.2, 0.7, 0.9}; + std::vector<float> input2{0.1, 0.2, 0.3, 0.5}; + m.PopulateTensor<float>(m.input1(), input1); + m.PopulateTensor<float>(m.input2(), input2); + m.PopulateTensor<float>(m.input3(), input2); + m.Invoke(); + + EXPECT_EQ(add_op_invocation_count, 1); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/examples/python/README.md b/tensorflow/lite/examples/python/README.md index 82b7ad690fc..8f870468d08 100644 --- a/tensorflow/lite/examples/python/README.md +++ b/tensorflow/lite/examples/python/README.md @@ -5,18 +5,16 @@ TensorFlow Lite model and use it to recognize objects in images. The Python script accepts arguments specifying the model to use, the corresponding labels file, and the image to process. -**Tip:** -If you're using a Raspberry Pi, instead try the [classify_picamera.py example]( -https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/raspberry_pi). - -Before you begin, -make sure you [have TensorFlow installed](https://www.tensorflow.org/install). +**Tip:** If you're using a Raspberry Pi, instead try the +[classify_picamera.py example](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/raspberry_pi). +Before you begin, make sure you +[have TensorFlow installed](https://www.tensorflow.org/install). ## Download sample model and image -You can use any compatible model, but the following MobileNet v1 model offers -a good demonstration of a model trained to recognize 1,000 different objects. +You can use any compatible model, but the following MobileNet v1 model offers a +good demonstration of a model trained to recognize 1,000 different objects. ```sh # Get photo @@ -31,8 +29,6 @@ mv /tmp/mobilenet_v1_1.0_224/labels.txt /tmp/ ## Run the sample -Note: Instead use `python` if you're using Python 2.x. - ```sh python3 label_image.py \ --model_file /tmp/mobilenet_v1_1.0_224.tflite \ diff --git a/tensorflow/lite/experimental/ios/BUILD.apple b/tensorflow/lite/experimental/ios/BUILD.apple index 11868fe044d..cce0c4df883 100644 --- a/tensorflow/lite/experimental/ios/BUILD.apple +++ b/tensorflow/lite/experimental/ios/BUILD.apple @@ -71,6 +71,7 @@ tflite_ios_static_framework( # bazel build -c opt --config=ios --ios_multi_cpus=armv7,arm64,x86_64 //tensorflow/lite/experimental/ios:TensorFlowLiteSelectTfOps_framework ios_static_framework( name = "TensorFlowLiteSelectTfOps_framework", + avoid_deps = ["//tensorflow/lite/c:common"], bundle_name = "TensorFlowLiteSelectTfOps", minimum_os_version = TFL_MINIMUM_OS_VERSION, deps = [ diff --git a/tensorflow/lite/experimental/writer/option_writer_generator.cc b/tensorflow/lite/experimental/writer/option_writer_generator.cc index 14d7219f304..8efbd51b969 100644 --- a/tensorflow/lite/experimental/writer/option_writer_generator.cc +++ b/tensorflow/lite/experimental/writer/option_writer_generator.cc @@ -81,6 +81,7 @@ static const char* param_structs[] = {"TfLiteAddParams", "TfLiteReverseSequenceParams", "TfLiteWhileParams", "TfLiteCumsumParams", + "TfLiteCallOnceParams", nullptr}; } // namespace diff --git a/tensorflow/lite/interpreter_builder.cc b/tensorflow/lite/interpreter_builder.cc index 6eaf3eaadc8..4249c85238e 100644 --- a/tensorflow/lite/interpreter_builder.cc +++ b/tensorflow/lite/interpreter_builder.cc @@ -47,6 +47,20 @@ limitations under the License. #endif #endif +// TODO(b/139446230): Move to portable platform header. +#if defined(__ANDROID__) +#define TFLITE_IS_MOBILE_PLATFORM +#endif // defined(__ANDROID__) + +#if defined(__APPLE__) +#include "TargetConditionals.h" +#if TARGET_IPHONE_SIMULATOR +#define TFLITE_IS_MOBILE_PLATFORM +#elif TARGET_OS_IPHONE +#define TFLITE_IS_MOBILE_PLATFORM +#endif +#endif // defined(__APPLE__) + namespace tflite { namespace { @@ -129,9 +143,16 @@ const char* kEmptyTensorName = ""; // For flex delegate, see also the strong override in // lite/delegates/flex/delegate.cc. TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() { -#if !defined(__ANDROID__) - // If _pywrap_tensorflow_internal.so is available, use - // TF_AcquireFlexDelegate() to initialize flex delegate. + auto acquire_flex_delegate_func = + reinterpret_cast<Interpreter::TfLiteDelegatePtr (*)()>( + SharedLibrary::GetSymbol("TF_AcquireFlexDelegate")); + if (acquire_flex_delegate_func) { + return acquire_flex_delegate_func(); + } + +#if !defined(TFLITE_IS_MOBILE_PLATFORM) + // Load TF_AcquireFlexDelegate() from _pywrap_tensorflow_internal.so if it is + // available. const char* filename_pywrap_tensorflow_internal = #if defined(_WIN32) "_pywrap_tensorflow_internal.pyd"; @@ -143,15 +164,16 @@ TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() { void* lib_tf_internal = SharedLibrary::LoadLibrary(filename_pywrap_tensorflow_internal); if (lib_tf_internal) { - auto TF_AcquireFlexDelegate = + acquire_flex_delegate_func = reinterpret_cast<Interpreter::TfLiteDelegatePtr (*)()>( SharedLibrary::GetLibrarySymbol(lib_tf_internal, "TF_AcquireFlexDelegate")); - if (TF_AcquireFlexDelegate) { - return TF_AcquireFlexDelegate(); + if (acquire_flex_delegate_func) { + return acquire_flex_delegate_func(); } } -#endif +#endif // !defined(TFLITE_IS_MOBILE_PLATFORM) + return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); } diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index d27589d06cc..9cc5d0452ec 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -544,6 +544,7 @@ BUILTIN_KERNEL_SRCS = [ "batch_to_space_nd.cc", "bidirectional_sequence_lstm.cc", "bidirectional_sequence_rnn.cc", + "call_once.cc", "cast.cc", "ceil.cc", "comparisons.cc", @@ -2100,6 +2101,21 @@ cc_test( ], ) +cc_test( + name = "call_once_test", + size = "small", + srcs = ["call_once_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":kernel_util", + ":subgraph_test_util", + ":test_main", + ":variable_op_kernels", + "//tensorflow/lite:framework", + "@com_google_googletest//:gtest", + ], +) + cc_test( name = "if_test", size = "small", @@ -2225,6 +2241,7 @@ cc_library( ":builtin_ops", ":kernel_util", ":test_util", + ":variable_op_kernels", "//tensorflow/lite:builtin_op_data", "//tensorflow/lite:framework", "//tensorflow/lite/c:common", diff --git a/tensorflow/lite/kernels/builtin_op_kernels.h b/tensorflow/lite/kernels/builtin_op_kernels.h index b6e73c2d7a1..f1ba36e59b8 100644 --- a/tensorflow/lite/kernels/builtin_op_kernels.h +++ b/tensorflow/lite/kernels/builtin_op_kernels.h @@ -39,6 +39,7 @@ TfLiteRegistration* Register_BATCH_TO_SPACE_ND(); TfLiteRegistration* Register_BATCH_MATMUL(); TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM(); TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN(); +TfLiteRegistration* Register_CALL_ONCE(); TfLiteRegistration* Register_CAST(); TfLiteRegistration* Register_CEIL(); TfLiteRegistration* Register_CONCATENATION(); diff --git a/tensorflow/lite/kernels/call_once.cc b/tensorflow/lite/kernels/call_once.cc new file mode 100644 index 00000000000..2e56f5d8511 --- /dev/null +++ b/tensorflow/lite/kernels/call_once.cc @@ -0,0 +1,109 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <stddef.h> + +#include <cstring> +#include <memory> +#include <vector> + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/subgraph.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace call_once_kernel { + +// CallOnce operator is a control flow op to invoke other subgraph in the graph +// in order to conduct the given graph's initialization tasks, for example, hash +// table initialization and variable initialization. +// +// This operator will invoke the subgraph for initialization in the first run +// and become no-op after the first run in an interpreter's life cycle. + +struct OpData { + // Subgraph index to be invoked once in a life cycle by this CallOnce op. + int init_subgraph_index; + // Boolean storage to store whether the subgraph for initialization is invoked + // successfully once in an interpreter's life cycle. + bool init_subgraph_invoked; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* op_data = new OpData; + const auto* params = reinterpret_cast<const TfLiteCallOnceParams*>(buffer); + op_data->init_subgraph_index = params->init_subgraph_index; + op_data->init_subgraph_invoked = false; + return op_data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast<OpData*>(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const OpData* op_data = reinterpret_cast<OpData*>(node->user_data); + + // Return early if the initialization graph is already invoked. + if (op_data->init_subgraph_invoked) return kTfLiteOk; + + TF_LITE_ENSURE_EQ(context, node->inputs->size, 0); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 0); + + Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_); + auto* subgraphs = this_subgraph->GetSubgraphs(); + TF_LITE_ENSURE(context, op_data->init_subgraph_index < subgraphs->size()); + + // Ensures that there are no input and output tensors in the subgraph. + Subgraph* init_subgraph = (*subgraphs)[op_data->init_subgraph_index].get(); + TF_LITE_ENSURE_EQ(context, init_subgraph->inputs().size(), 0); + TF_LITE_ENSURE_EQ(context, init_subgraph->outputs().size(), 0); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + OpData* op_data = reinterpret_cast<OpData*>(node->user_data); + + // The initialization graph should be invoked once in a life cycle. + if (op_data->init_subgraph_invoked) return kTfLiteOk; + + Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_); + auto* subgraphs = this_subgraph->GetSubgraphs(); + Subgraph& init_subgraph = *(*subgraphs)[op_data->init_subgraph_index]; + + TF_LITE_ENSURE_OK(context, init_subgraph.AllocateTensors()); + TF_LITE_ENSURE_OK(context, init_subgraph.Invoke()); + TF_LITE_ENSURE_OK(context, init_subgraph.ReleaseNonPersistentMemory()); + + // Mark the invocation completed. + op_data->init_subgraph_invoked = true; + return kTfLiteOk; +} + +} // namespace call_once_kernel + +TfLiteRegistration* Register_CALL_ONCE() { + static TfLiteRegistration r = {call_once_kernel::Init, call_once_kernel::Free, + call_once_kernel::Prepare, + call_once_kernel::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/call_once_test.cc b/tensorflow/lite/kernels/call_once_test.cc new file mode 100644 index 00000000000..29917d60c61 --- /dev/null +++ b/tensorflow/lite/kernels/call_once_test.cc @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <stdint.h> + +#include <memory> +#include <vector> + +#include <gtest/gtest.h> +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/subgraph_test_util.h" + +namespace tflite { + +using subgraph_test_util::ControlFlowOpTest; + +namespace { + +class CallOnceTest : public ControlFlowOpTest { + protected: + void SetUp() override { + interpreter_->AddSubgraphs(1); + builder_->BuildCallOnceAndReadVariableSubgraph( + &interpreter_->primary_subgraph()); + builder_->BuildAssignRandomValueToVariableSubgraph( + interpreter_->subgraph(1)); + + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + } +}; + +TEST_F(CallOnceTest, TestSimple) { + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + + TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]); + ASSERT_EQ(output->dims->size, 1); + ASSERT_EQ(output->dims->data[0], 1); + ASSERT_EQ(output->type, kTfLiteInt32); + ASSERT_EQ(NumElements(output), 1); + + // The value of the variable must be non-zero, which will be assigned by the + // initialization subgraph. + EXPECT_GT(output->data.i32[0], 0); +} + +TEST_F(CallOnceTest, TestInvokeMultipleTimes) { + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + + TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]); + ASSERT_EQ(output->dims->size, 1); + ASSERT_EQ(output->dims->data[0], 1); + ASSERT_EQ(output->type, kTfLiteInt32); + ASSERT_EQ(NumElements(output), 1); + + // The value of the variable must be non-zero, which will be assigned by the + // initialization subgraph. + int value = output->data.i32[0]; + EXPECT_GT(value, 0); + + for (int i = 0; i < 3; ++i) { + // Make sure that no more random value assignment in the initialization + // subgraph. + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]); + ASSERT_EQ(output->dims->size, 1); + ASSERT_EQ(output->dims->data[0], 1); + ASSERT_EQ(output->type, kTfLiteInt32); + ASSERT_EQ(NumElements(output), 1); + ASSERT_EQ(output->data.i32[0], value); + } +} + +} // namespace +} // namespace tflite diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index d2bb6dfd632..53d4c5c5e38 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -295,6 +295,8 @@ BuiltinOpResolver::BuiltinOpResolver() { /* min_version = */ 1, /* max_version = */ 3); AddBuiltin(BuiltinOperator_CUMSUM, Register_CUMSUM()); + AddBuiltin(BuiltinOperator_CALL_ONCE, + tflite::ops::builtin::Register_CALL_ONCE()); AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/lite/kernels/subgraph_test_util.cc b/tensorflow/lite/kernels/subgraph_test_util.cc index 8f1964ad10f..6cf3e89b8c1 100644 --- a/tensorflow/lite/kernels/subgraph_test_util.cc +++ b/tensorflow/lite/kernels/subgraph_test_util.cc @@ -19,6 +19,7 @@ limitations under the License. #include <stdint.h> #include <stdlib.h> +#include <random> #include <vector> #include <gtest/gtest.h> @@ -29,6 +30,48 @@ limitations under the License. #include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { + +// Forward declaration for op kernels. +namespace ops { +namespace custom { + +TfLiteRegistration* Register_ASSIGN_VARIABLE(); +TfLiteRegistration* Register_READ_VARIABLE(); + +namespace random_int { + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 0); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1); + outputSize->data[0] = 1; + // TODO(jaesung): Make output size be changeable depending on user's input to + // make it generic. + return context->ResizeTensor(context, output, outputSize); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor& output = context->tensors[node->outputs->data[0]]; + + std::random_device rd; + std::uniform_int_distribution<int> dist(1, 32768); + output.data.i32[0] = dist(rd); + return kTfLiteOk; +} + +} // namespace random_int + +TfLiteRegistration* Register_RANDOM_INT() { + static TfLiteRegistration r = {nullptr, nullptr, random_int::Prepare, + random_int::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops + namespace subgraph_test_util { namespace { @@ -328,6 +371,65 @@ void SubgraphBuilder::BuildWhileSubgraph(Subgraph* subgraph) { &node_index); } +void SubgraphBuilder::BuildAssignRandomValueToVariableSubgraph( + Subgraph* subgraph) { + const int kConstResourceId = 0; + const int kRandomValue = 1; + const int kTensorCount = 3; + + // Construct a graph like ths: + // %1 = random_int() + // variable_assign(%0, %1) + + int first_new_tensor_index; + ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index), + kTfLiteOk); + ASSERT_EQ(subgraph->SetInputs({}), kTfLiteOk); + ASSERT_EQ(subgraph->SetOutputs({}), kTfLiteOk); + + SetupTensor(subgraph, kRandomValue, kTfLiteInt32); + CreateConstantInt32Tensor(subgraph, kConstResourceId, {1}, {1024}); + + int node_index; + subgraph->AddNodeWithParameters({}, {kRandomValue}, {}, nullptr, 0, nullptr, + ::tflite::ops::custom::Register_RANDOM_INT(), + &node_index); + subgraph->AddNodeWithParameters( + {kConstResourceId, kRandomValue}, {}, {}, nullptr, 0, nullptr, + ::tflite::ops::custom::Register_ASSIGN_VARIABLE(), &node_index); +} + +void SubgraphBuilder::BuildCallOnceAndReadVariableSubgraph(Subgraph* subgraph) { + const int kConstResourceId = 0; + const int kOutput = 1; + const int kTensorCount = 2; + + // Construct a graph like ths: + // Output: %1 + // %1 = read_variable(%0) + + int first_new_tensor_index; + ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index), + kTfLiteOk); + ASSERT_EQ(subgraph->SetInputs({}), kTfLiteOk); + ASSERT_EQ(subgraph->SetOutputs({kOutput}), kTfLiteOk); + + SetupTensor(subgraph, kOutput, kTfLiteInt32); + CreateConstantInt32Tensor(subgraph, kConstResourceId, {1}, {1024}); + + TfLiteCallOnceParams* params = reinterpret_cast<TfLiteCallOnceParams*>( + malloc(sizeof(TfLiteCallOnceParams))); + params->init_subgraph_index = 1; + + int node_index; + subgraph->AddNodeWithParameters({}, {}, {}, nullptr, 0, params, + ::tflite::ops::builtin::Register_CALL_ONCE(), + &node_index); + subgraph->AddNodeWithParameters( + {kConstResourceId}, {kOutput}, {}, nullptr, 0, nullptr, + ::tflite::ops::custom::Register_READ_VARIABLE(), &node_index); +} + void SubgraphBuilder::CreateConstantInt32Tensor(Subgraph* subgraph, int tensor_index, const std::vector<int>& shape, diff --git a/tensorflow/lite/kernels/subgraph_test_util.h b/tensorflow/lite/kernels/subgraph_test_util.h index 7306f82344d..e2de12b5434 100644 --- a/tensorflow/lite/kernels/subgraph_test_util.h +++ b/tensorflow/lite/kernels/subgraph_test_util.h @@ -85,6 +85,14 @@ class SubgraphBuilder { // 2 inputs, 2 outputs. void BuildWhileSubgraph(Subgraph* subgraph); + // Build a subgraph that assigns a random value to a variable. + // No input/output. + void BuildAssignRandomValueToVariableSubgraph(Subgraph* graph); + + // Build a subgraph with CallOnce op and ReadVariable op. + // No input and 1 output. + void BuildCallOnceAndReadVariableSubgraph(Subgraph* graph); + private: void CreateConstantInt32Tensor(Subgraph* subgraph, int tensor_index, const std::vector<int>& shape, diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/add.cc b/tensorflow/lite/micro/kernels/cmsis-nn/add.cc index 6db88839073..2816e118271 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/add.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/add.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/add.h" -#include "cmsis/CMSIS/NN/Include/arm_nnfunctions.h" +#include "CMSIS/NN/Include/arm_nnfunctions.h" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h" diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/conv.cc b/tensorflow/lite/micro/kernels/cmsis-nn/conv.cc index 80a0a2ae748..65e94fcec05 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/conv.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/conv.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/conv.h" -#include "cmsis/CMSIS/NN/Include/arm_nn_types.h" -#include "cmsis/CMSIS/NN/Include/arm_nnfunctions.h" +#include "CMSIS/NN/Include/arm_nn_types.h" +#include "CMSIS/NN/Include/arm_nnfunctions.h" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/common.h" diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/depthwise_conv.cc b/tensorflow/lite/micro/kernels/cmsis-nn/depthwise_conv.cc index 3a59b71c985..7715dbe465d 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/depthwise_conv.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/depthwise_conv.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h" -#include "cmsis/CMSIS/NN/Include/arm_nnfunctions.h" +#include "CMSIS/NN/Include/arm_nnfunctions.h" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/common.h" diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc b/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc index 9f901d436a1..11a0f0bdc23 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/fully_connected.h" -#include "cmsis/CMSIS/NN/Include/arm_nnfunctions.h" +#include "CMSIS/NN/Include/arm_nnfunctions.h" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/common.h" diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/mul.cc b/tensorflow/lite/micro/kernels/cmsis-nn/mul.cc index e7e23818f5e..20686500ac8 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/mul.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/mul.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/mul.h" -#include "cmsis/CMSIS/NN/Include/arm_nnfunctions.h" +#include "CMSIS/NN/Include/arm_nnfunctions.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/mul.h" #include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h" diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/pooling.cc b/tensorflow/lite/micro/kernels/cmsis-nn/pooling.cc index cd2d799e734..e1ac2b595a3 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/pooling.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/pooling.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/kernels/internal/reference/pooling.h" -#include "cmsis/CMSIS/NN/Include/arm_nnfunctions.h" +#include "CMSIS/NN/Include/arm_nnfunctions.h" #include "flatbuffers/base.h" // from @flatbuffers #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h" diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/softmax.cc b/tensorflow/lite/micro/kernels/cmsis-nn/softmax.cc index 60e1a9a88b0..9ca08abe862 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/softmax.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/softmax.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/softmax.h" -#include "cmsis/CMSIS/NN/Include/arm_nnfunctions.h" +#include "CMSIS/NN/Include/arm_nnfunctions.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/svdf.cc b/tensorflow/lite/micro/kernels/cmsis-nn/svdf.cc index 16358e62e10..f4ee0c73ccf 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/svdf.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/svdf.cc @@ -16,8 +16,8 @@ limitations under the License. #include <cmath> #include <cstdint> -#include "cmsis/CMSIS/NN/Include/arm_nn_types.h" -#include "cmsis/CMSIS/NN/Include/arm_nnfunctions.h" +#include "CMSIS/NN/Include/arm_nn_types.h" +#include "CMSIS/NN/Include/arm_nnfunctions.h" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/common.h" diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index b28fd19d15e..e24279b5da6 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -51,7 +51,6 @@ ALL_TAGS := $(TAGS) $(TARGET) # include directories from one source. INCLUDES := \ -I. \ --I$(MAKEFILE_DIR)/downloads/ \ -I$(MAKEFILE_DIR)/downloads/gemmlowp \ -I$(MAKEFILE_DIR)/downloads/flatbuffers/include \ -I$(MAKEFILE_DIR)/downloads/ruy @@ -232,7 +231,6 @@ tensorflow/lite/core/api/op_resolver.cc \ tensorflow/lite/core/api/tensor_utils.cc \ tensorflow/lite/kernels/internal/quantization_util.cc \ tensorflow/lite/kernels/kernel_util.cc \ -tensorflow/lite/schema/schema_conversion_utils.cc \ tensorflow/lite/schema/schema_utils.cc MICROLITE_CC_SRCS := $(filter-out $(MICROLITE_TEST_SRCS), $(MICROLITE_CC_BASE_SRCS)) @@ -309,7 +307,6 @@ tensorflow/lite/kernels/op_macros.h \ tensorflow/lite/kernels/padding.h \ tensorflow/lite/portable_type_to_tflitetype.h \ tensorflow/lite/schema/schema_generated.h \ -tensorflow/lite/schema/schema_conversion_utils.h \ tensorflow/lite/schema/schema_utils.h \ tensorflow/lite/version.h diff --git a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn.inc b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn.inc index 053d4584300..b4caadf9252 100644 --- a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn.inc +++ b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn.inc @@ -118,11 +118,15 @@ ifneq ($(filter cmsis-nn,$(ALL_TAGS)),) $(CMSIS_PATH)CMSIS/NN/Include/arm_nnfunctions.h \ $(CMSIS_PATH)CMSIS/NN/Include/arm_nnsupportfunctions.h - # Need to add the CMSIS Core includes path. - # All other CMSIS header files are included with their relative path - # in the CMSIS-NN micro kernel source files in - # tensorflow/lite/micro/kernels/cmsis-nn + # We add -I$(CMSIS_PATH) to enable the code in the TFLM repo (mostly in the + # tensorflow/lite/micro/kernels/cmsis-nn) to use include paths relative to + # the CMSIS code-base. + # + # The CMSIS code itself uses includes such as #include "arm_math.h" and so + # we add $(CMSIS_PATH)/CMSIS/Core/Include etc. to be able to build the CMSIS + # code without any modifications. INCLUDES += \ + -I$(CMSIS_PATH) \ -I$(CMSIS_PATH)/CMSIS/Core/Include \ -I$(CMSIS_PATH)/CMSIS/DSP/Include \ -I$(CMSIS_PATH)/CMSIS/NN/Include diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index 62045344755..ef1592193f7 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -352,7 +352,8 @@ enum BuiltinOperator : int32 { SEGMENT_SUM = 125, BATCH_MATMUL = 126, PLACEHOLDER_FOR_GREATER_OP_CODES = 127, - CUMSUM = 128 + CUMSUM = 128, + CALL_ONCE = 129 } @@ -460,6 +461,7 @@ union BuiltinOptions { SegmentSumOptions, BatchMatMulOptions, CumsumOptions, + CallOnceOptions } enum Padding : byte { SAME, VALID } @@ -955,6 +957,10 @@ table IfOptions { else_subgraph_index:int; } +table CallOnceOptions { + init_subgraph_index:int; +} + table WhileOptions { cond_subgraph_index:int; body_subgraph_index:int; diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index e7d91a93a99..dd9b655c6e6 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -325,6 +325,9 @@ struct MatrixSetDiagOptionsT; struct IfOptions; struct IfOptionsT; +struct CallOnceOptions; +struct CallOnceOptionsT; + struct WhileOptions; struct WhileOptionsT; @@ -792,11 +795,12 @@ enum BuiltinOperator { BuiltinOperator_BATCH_MATMUL = 126, BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES = 127, BuiltinOperator_CUMSUM = 128, + BuiltinOperator_CALL_ONCE = 129, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_CUMSUM + BuiltinOperator_MAX = BuiltinOperator_CALL_ONCE }; -inline const BuiltinOperator (&EnumValuesBuiltinOperator())[129] { +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[130] { static const BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -926,13 +930,14 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[129] { BuiltinOperator_SEGMENT_SUM, BuiltinOperator_BATCH_MATMUL, BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES, - BuiltinOperator_CUMSUM + BuiltinOperator_CUMSUM, + BuiltinOperator_CALL_ONCE }; return values; } inline const char * const *EnumNamesBuiltinOperator() { - static const char * const names[130] = { + static const char * const names[131] = { "ADD", "AVERAGE_POOL_2D", "CONCATENATION", @@ -1062,13 +1067,14 @@ inline const char * const *EnumNamesBuiltinOperator() { "BATCH_MATMUL", "PLACEHOLDER_FOR_GREATER_OP_CODES", "CUMSUM", + "CALL_ONCE", nullptr }; return names; } inline const char *EnumNameBuiltinOperator(BuiltinOperator e) { - if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_CUMSUM)) return ""; + if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_CALL_ONCE)) return ""; const size_t index = static_cast<size_t>(e); return EnumNamesBuiltinOperator()[index]; } @@ -1177,11 +1183,12 @@ enum BuiltinOptions { BuiltinOptions_SegmentSumOptions = 100, BuiltinOptions_BatchMatMulOptions = 101, BuiltinOptions_CumsumOptions = 102, + BuiltinOptions_CallOnceOptions = 103, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_CumsumOptions + BuiltinOptions_MAX = BuiltinOptions_CallOnceOptions }; -inline const BuiltinOptions (&EnumValuesBuiltinOptions())[103] { +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[104] { static const BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -1285,13 +1292,14 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[103] { BuiltinOptions_DensifyOptions, BuiltinOptions_SegmentSumOptions, BuiltinOptions_BatchMatMulOptions, - BuiltinOptions_CumsumOptions + BuiltinOptions_CumsumOptions, + BuiltinOptions_CallOnceOptions }; return values; } inline const char * const *EnumNamesBuiltinOptions() { - static const char * const names[104] = { + static const char * const names[105] = { "NONE", "Conv2DOptions", "DepthwiseConv2DOptions", @@ -1395,13 +1403,14 @@ inline const char * const *EnumNamesBuiltinOptions() { "SegmentSumOptions", "BatchMatMulOptions", "CumsumOptions", + "CallOnceOptions", nullptr }; return names; } inline const char *EnumNameBuiltinOptions(BuiltinOptions e) { - if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_CumsumOptions)) return ""; + if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_CallOnceOptions)) return ""; const size_t index = static_cast<size_t>(e); return EnumNamesBuiltinOptions()[index]; } @@ -1818,6 +1827,10 @@ template<> struct BuiltinOptionsTraits<tflite::CumsumOptions> { static const BuiltinOptions enum_value = BuiltinOptions_CumsumOptions; }; +template<> struct BuiltinOptionsTraits<tflite::CallOnceOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_CallOnceOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -2666,6 +2679,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_CumsumOptions ? reinterpret_cast<const tflite::CumsumOptionsT *>(value) : nullptr; } + tflite::CallOnceOptionsT *AsCallOnceOptions() { + return type == BuiltinOptions_CallOnceOptions ? + reinterpret_cast<tflite::CallOnceOptionsT *>(value) : nullptr; + } + const tflite::CallOnceOptionsT *AsCallOnceOptions() const { + return type == BuiltinOptions_CallOnceOptions ? + reinterpret_cast<const tflite::CallOnceOptionsT *>(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -8992,6 +9013,60 @@ inline flatbuffers::Offset<IfOptions> CreateIfOptions( flatbuffers::Offset<IfOptions> CreateIfOptions(flatbuffers::FlatBufferBuilder &_fbb, const IfOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct CallOnceOptionsT : public flatbuffers::NativeTable { + typedef CallOnceOptions TableType; + int32_t init_subgraph_index; + CallOnceOptionsT() + : init_subgraph_index(0) { + } +}; + +struct CallOnceOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef CallOnceOptionsT NativeTableType; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_INIT_SUBGRAPH_INDEX = 4 + }; + int32_t init_subgraph_index() const { + return GetField<int32_t>(VT_INIT_SUBGRAPH_INDEX, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT_INIT_SUBGRAPH_INDEX) && + verifier.EndTable(); + } + CallOnceOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(CallOnceOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<CallOnceOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const CallOnceOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct CallOnceOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_init_subgraph_index(int32_t init_subgraph_index) { + fbb_.AddElement<int32_t>(CallOnceOptions::VT_INIT_SUBGRAPH_INDEX, init_subgraph_index, 0); + } + explicit CallOnceOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + CallOnceOptionsBuilder &operator=(const CallOnceOptionsBuilder &); + flatbuffers::Offset<CallOnceOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<CallOnceOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<CallOnceOptions> CreateCallOnceOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t init_subgraph_index = 0) { + CallOnceOptionsBuilder builder_(_fbb); + builder_.add_init_subgraph_index(init_subgraph_index); + return builder_.Finish(); +} + +flatbuffers::Offset<CallOnceOptions> CreateCallOnceOptions(flatbuffers::FlatBufferBuilder &_fbb, const CallOnceOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct WhileOptionsT : public flatbuffers::NativeTable { typedef WhileOptions TableType; int32_t cond_subgraph_index; @@ -9886,6 +9961,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const tflite::CumsumOptions *builtin_options_as_CumsumOptions() const { return builtin_options_type() == tflite::BuiltinOptions_CumsumOptions ? static_cast<const tflite::CumsumOptions *>(builtin_options()) : nullptr; } + const tflite::CallOnceOptions *builtin_options_as_CallOnceOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_CallOnceOptions ? static_cast<const tflite::CallOnceOptions *>(builtin_options()) : nullptr; + } const flatbuffers::Vector<uint8_t> *custom_options() const { return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS); } @@ -10330,6 +10408,10 @@ template<> inline const tflite::CumsumOptions *Operator::builtin_options_as<tfli return builtin_options_as_CumsumOptions(); } +template<> inline const tflite::CallOnceOptions *Operator::builtin_options_as<tflite::CallOnceOptions>() const { + return builtin_options_as_CallOnceOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -13810,6 +13892,32 @@ inline flatbuffers::Offset<IfOptions> CreateIfOptions(flatbuffers::FlatBufferBui _else_subgraph_index); } +inline CallOnceOptionsT *CallOnceOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new CallOnceOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void CallOnceOptions::UnPackTo(CallOnceOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = init_subgraph_index(); _o->init_subgraph_index = _e; } +} + +inline flatbuffers::Offset<CallOnceOptions> CallOnceOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CallOnceOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateCallOnceOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<CallOnceOptions> CreateCallOnceOptions(flatbuffers::FlatBufferBuilder &_fbb, const CallOnceOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CallOnceOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _init_subgraph_index = _o->init_subgraph_index; + return tflite::CreateCallOnceOptions( + _fbb, + _init_subgraph_index); +} + inline WhileOptionsT *WhileOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new WhileOptionsT(); UnPackTo(_o, _resolver); @@ -14918,6 +15026,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast<const tflite::CumsumOptions *>(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_CallOnceOptions: { + auto ptr = reinterpret_cast<const tflite::CallOnceOptions *>(obj); + return verifier.VerifyTable(ptr); + } default: return true; } } @@ -15344,6 +15456,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast<const tflite::CumsumOptions *>(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_CallOnceOptions: { + auto ptr = reinterpret_cast<const tflite::CallOnceOptions *>(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -15758,6 +15874,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast<const tflite::CumsumOptionsT *>(value); return CreateCumsumOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_CallOnceOptions: { + auto ptr = reinterpret_cast<const tflite::CallOnceOptionsT *>(value); + return CreateCallOnceOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -16172,6 +16292,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new tflite::CumsumOptionsT(*reinterpret_cast<tflite::CumsumOptionsT *>(u.value)); break; } + case BuiltinOptions_CallOnceOptions: { + value = new tflite::CallOnceOptionsT(*reinterpret_cast<tflite::CallOnceOptionsT *>(u.value)); + break; + } default: break; } @@ -16689,6 +16813,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_CallOnceOptions: { + auto ptr = reinterpret_cast<tflite::CallOnceOptionsT *>(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/lite/special_rules.bzl b/tensorflow/lite/special_rules.bzl index 68143c976f4..2b80b2ef9c1 100644 --- a/tensorflow/lite/special_rules.bzl +++ b/tensorflow/lite/special_rules.bzl @@ -77,3 +77,16 @@ def tflite_schema_utils_friends(): # Its usage should be rare, and is often abused by tools that are doing # Flatbuffer creation/manipulation in unofficially supported ways." return ["//..."] + +def flex_portable_tensorflow_deps(): + """Returns dependencies for building portable tensorflow in Flex delegate.""" + + return [ + "//third_party/fft2d:fft2d_headers", + "//third_party/eigen3", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/strings:str_format", + "@gemmlowp", + "@icu//:common", + "//third_party/icu/data:conversion_data", + ] diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc index 8dfd41f9b9d..5366b46ca2b 100644 --- a/tensorflow/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/lite/tools/versioning/runtime_version.cc @@ -320,6 +320,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_RANK, 1}, "1.14.0"}, {{BuiltinOperator_WHILE, 1}, "1.15.0"}, {{BuiltinOperator_CUMSUM, 1}, kPendingReleaseVersion}, + {{BuiltinOperator_CALL_ONCE, 1}, kPendingReleaseVersion}, }); std::pair<BuiltinOperator, int> version_key = {op_code, op_version}; diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 8d4b60fcb90..210bb260269 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 10, 29) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 10, 30) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py index 5a6879b3f0e..51bd5fbcbaf 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py @@ -208,6 +208,26 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): dataset = dataset.with_options(options) self.assertDatasetProduces(dataset, expected_output=[[0]]) + @combinations.generate(test_base.default_test_combinations()) + def testOptimizationDoubleOptimizeDatasetNested(self): + def flat_map_fn(_): + dataset = dataset_ops.Dataset.from_tensors(0) + dataset = dataset.apply(testing.assert_next(["MapAndBatch"])) + dataset = dataset.skip(0) + # Should be fused by map and batch fusion + dataset = dataset.map(lambda x: x) + dataset = dataset.batch(1) + return dataset + + dataset = dataset_ops.Dataset.from_tensors(0) + dataset = dataset.flat_map(flat_map_fn) + dataset = dataset_ops._OptimizeDataset(dataset, ["map_and_batch_fusion"], + [], []) + dataset = dataset_ops._OptimizeDataset(dataset, ["noop_elimination"], [], + []) + + self.assertDatasetProduces(dataset, expected_output=[[0]]) + @combinations.generate( combinations.times( test_base.default_test_combinations(), @@ -542,5 +562,6 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual(cpu_budget, 1000) self.assertEqual(ram_budget, 999999999) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py index aea4934260e..44fe30f6729 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py @@ -39,8 +39,7 @@ from tensorflow.python.util import nest def remove_variants(get_next_op): - # TODO(b/72408568): Remove this once session.run can get - # variant tensors. + # TODO(b/72408568): Remove this once session.run can get variant tensors. """Remove variants from a nest structure, so sess.run will execute.""" def _remove_variant(x): @@ -61,7 +60,7 @@ class DatasetSerializationTestBase(test.TestCase): # TODO(b/72657739): Remove sparse_tensor argument, which is to test the # (deprecated) saveable `SparseTensorSliceDataset`, once the API - # `from_sparse_tensor_slices()`and related tests are deleted. + # `from_sparse_tensor_slices()` and related tests are deleted. def run_core_tests(self, ds_fn, num_outputs, sparse_tensors=False): """Runs the core tests. diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 71b1303ecf4..4915a5b8fe3 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -1111,7 +1111,13 @@ class GradientTape(object): Note: Unless you set `persistent=True` a GradientTape can only be used to compute one set of gradients (or jacobians). - See[wikipedia article](http://en.wikipedia.org/wiki/jacobian_matrix_and_determinant) + Note: By default the jacobian implementation uses parallel for (pfor), which + creates a tf.function under the hood for each jacobian call. For better + performance, and to avoid recompilation and vectorization rewrites on each + call, enclose GradientTape code in @tf.function. + + See[wikipedia + article](http://en.wikipedia.org/wiki/jacobian_matrix_and_determinant) for the definition of a Jacobian. Example usage: @@ -1243,6 +1249,12 @@ class GradientTape(object): Note: Unless you set `persistent=True` a GradientTape can only be used to compute one set of gradients (or jacobians). + Note: By default the batch_jacobian implementation uses parallel for (pfor), + which creates a tf.function under the hood for each batch_jacobian call. + For better performance, and to avoid recompilation and vectorization + rewrites on each call, enclose GradientTape code in @tf.function. + + Example usage: ```python diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 584fed73158..87fad9b9b1f 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -43,6 +43,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import gradient_checker_v2 from tensorflow.python.ops import gradients from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn @@ -1713,6 +1714,35 @@ class JacobianTest(test.TestCase): dy_xx_answer = [[[2., 0], [0, 2.]]] * 10 self.assertAllClose(dy_xx_answer, self.evaluate(dy_xx)) + def test_nested_batch_jacobian_foldl(self): + def _grad(f): + def _grad_function(primal): + with backprop.GradientTape() as tape: + tape.watch(primal) + primal_out = f(primal) + return tape.batch_jacobian(primal_out, primal) + return _grad_function + + def _func(x): + return array_ops.reshape( + functional_ops.foldl_v2(lambda a, b: math_ops.cos(a + b), + array_ops.transpose(x)), + [1, 1]) + + f = _func + x = constant_op.constant([[1., 2.]]) + for _ in range(2): + theoretical, numerical = gradient_checker_v2.compute_gradient(f, [x]) + self.assertAllClose(theoretical, numerical, rtol=1e-3) + f = _grad(f) + expected_flat = array_ops.reshape(numerical, [-1]) + self.assertAllClose(expected_flat, + array_ops.reshape(f(x), [-1]), + rtol=1e-3) + self.assertAllClose(expected_flat, + array_ops.reshape(def_function.function(f)(x), [-1]), + rtol=1e-3) + @test_util.run_in_graph_and_eager_modes def test_indexed_slices(self): with backprop.GradientTape(persistent=True) as g: diff --git a/tensorflow/python/eager/backprop_util.py b/tensorflow/python/eager/backprop_util.py index 117b05e0956..e1c719d4a9d 100644 --- a/tensorflow/python/eager/backprop_util.py +++ b/tensorflow/python/eager/backprop_util.py @@ -19,12 +19,35 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import handle_data_util + + +def _DTypeFromTensor(tensor): + """Extract either `tensor.dtype` or the unanimous sub-type of a variant.""" + dtype = tensor.dtype + if dtype.base_dtype == dtypes.variant: + # If we know statically that the data a variant points to is non-trainable + # then the variant itself is non-trainable. + if isinstance(tensor, ops.EagerTensor): + handle_data = tensor._handle_data # pylint: disable=protected-access + else: + handle_data = handle_data_util.get_resource_handle_data(tensor) + if (handle_data is not None + and handle_data.is_set + and handle_data.shape_and_type): + first_type = handle_data.shape_and_type[0].dtype + if all(shape_and_type.dtype == first_type + for shape_and_type in handle_data.shape_and_type): + return first_type + return dtype def IsTrainable(tensor_or_dtype): + """Determines whether a tensor or dtype supports infinitesimal changes.""" if tensor_util.is_tensor(tensor_or_dtype): - dtype = tensor_or_dtype.dtype + dtype = _DTypeFromTensor(tensor_or_dtype) else: dtype = tensor_or_dtype dtype = dtypes.as_dtype(dtype) diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index 37ab60918c2..c567fcb762c 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -1085,13 +1085,12 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._run(func, num_iters=self._num_iters_2_by_2) def _benchmark_tf_dropout_2_by_2(self, + rate=0.5, is_rate_tensor=True, noise_shape=None, device=CPU): if is_rate_tensor: - rate = constant_op.constant(0.5, dtype=dtypes.float32) - else: - rate = 0.5 + rate = constant_op.constant(rate, dtype=dtypes.float32) with context.device(device): def func(): @@ -1112,6 +1111,19 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): def benchmark_tf_dropout_2_by_2_GPU(self): self._benchmark_tf_dropout_2_by_2(device=GPU) + def benchmark_tf_dropout_scalar_rate_2_by_2_CPU_rate_0(self): + self._benchmark_tf_dropout_2_by_2(rate=0, is_rate_tensor=False) + + def benchmark_tf_dropout_scalar_rate_2_by_2_GPU_rate_0(self): + self._benchmark_tf_dropout_2_by_2(rate=0.0, + is_rate_tensor=False, device=GPU) + + def benchmark_tf_dropout_2_by_2_CPU_rate_0(self): + self._benchmark_tf_dropout_2_by_2(rate=0.0) + + def benchmark_tf_dropout_2_by_2_GPU_rate_0(self): + self._benchmark_tf_dropout_2_by_2(rate=0, device=GPU) + def _benchmark_transpose(self, m, num_iters, diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 9d9cf0b50c3..d92440e9594 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -1387,6 +1387,7 @@ class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions): gradients with respect to the inputs. """ outputs = [] + iteration_count = 0 # First we need to figure out how many side outputs from the forward pass # will be required. We do this in a temporary graph to avoid actually # running multiple copies of the backward pass (one per _GradientsHelper @@ -1401,15 +1402,42 @@ class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions): # all of the forward op's outputs: symbolic gradients with tf.gradients # instead rely on regenerating backward functions when higher-order # gradients are requested. - while len(outputs) < len(self._func_graph.outputs): + while (len(outputs) < len(self._func_graph.outputs) + # It's possible for gradient generation to add new ops to the forward + # pass. If all of the new outputs are non-trainable, there's no + # reason to continue. + and any(backprop_util.IsTrainable(output) + for output in self._func_graph.outputs[len(outputs):])): + iteration_count += 1 + if iteration_count >= 20 and iteration_count % 5 == 0: + new_op_with_trainable_output = None + num_new_trainable_outputs = 0 + for output in self._func_graph.outputs[len(outputs):]: + if backprop_util.IsTrainable(output): + num_new_trainable_outputs += 1 + new_op_with_trainable_output = output.op + logging.warning( + ("Determining side outputs for the function '{}' is taking longer " + "than expected ({} iterations, typically this converges in 5 or " + "so). This could indicate that a gradient registration is adding " + "new ops to the forward pass every time gradients are generated. " + "{} new trainable output(s) were added this iteration, one from " + "the following op:\n {}\nThis may indicate a TensorFlow bug, or " + "an issue in a tf.custom_gradient.") + .format( + self._func_graph.name, iteration_count, + num_new_trainable_outputs, new_op_with_trainable_output)) outputs = list(self._func_graph.outputs) self._build_functions_for_outputs( outputs, inference_args, input_tangents) + (forward_function, forward_graph, backward_function, output_indices, num_output_tangents) = ( self._build_functions_for_outputs( outputs, inference_args, input_tangents)) - if len(self._func_graph.outputs) != len(outputs): + if (len(self._func_graph.outputs) > len(outputs) + and any(backprop_util.IsTrainable(output) + for output in self._func_graph.outputs[len(outputs):])): raise AssertionError( ("Unexpectedly added new outputs to the forward function when " "building the backward function: {}").format( diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py index f2dc99a5ba0..f70bd75e36b 100644 --- a/tensorflow/python/framework/func_graph.py +++ b/tensorflow/python/framework/func_graph.py @@ -443,6 +443,11 @@ class FuncGraph(ops.Graph): return self._fallback_outer_graph return current + @outer_graph.setter + def outer_graph(self, new_outer_graph): + """Sets `outer_graph` to `new_outer_graph`.""" + self._weak_outer_graph = weakref.ref(new_outer_graph) + @property def output_types(self): return [t.dtype for t in self.outputs] diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index bee2874b294..ca95908802d 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -782,6 +782,45 @@ class LayoutOptimizerTest(test.TestCase): self.assertIn('concat-2-LayoutOptimizer', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) + @test_util.deprecated_graph_mode_only + def testConcatWithControlDependencyFor5DTensor(self): + if not test.is_gpu_available(cuda_only=True): + self.skipTest('GPU required') + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([2, 2, 14, 14, 1], seed=0) + w = random_ops.truncated_normal([2, 2, 2, 1, 2], seed=0) + strides = [1, 1, 1, 1, 1] + y = gen_nn_ops.conv3d(x, w, strides, 'SAME') + axis = constant_op.constant(4) + var = variables.Variable(3) + assign = state_ops.assign(var, 6) + with ops.control_dependencies([assign]): + concat = array_ops.concat([y, y], axis) + output = array_ops.identity(concat) + + with session.Session(config=_get_config(False)) as sess: + output_val_ref = self.evaluate(output) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + # Four transposes were initially added in the Expand phase of + # LayoutOptimizer; two of them are cancelled out in the Collapse phase. + expected_num_transposes = 2 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes) + self._assert_trans_ncdhw_to_ndhwc('concat-0-0', nodes) + self._assert_map_ndhwc_to_ncdhw('concat-2', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + @test_util.deprecated_graph_mode_only def testFill(self): if test.is_gpu_available(cuda_only=True): @@ -1397,107 +1436,167 @@ class LayoutOptimizerTest(test.TestCase): @test_util.deprecated_graph_mode_only def testConv3D(self): - if test.is_gpu_available(cuda_only=True): - random_seed.set_random_seed(0) - x = random_ops.truncated_normal([1, 784], seed=0) - conv = _two_layer_model(x) - filters = random_ops.truncated_normal([2, 2, 2, 1, 2], seed=0) - strides_val = [1, 1, 1, 1, 1] - x_3d = array_ops.reshape(conv, [-1, 4, 14, 14, 1]) - conv3d = gen_nn_ops.conv3d(x_3d, filters, strides_val, 'VALID') - output = array_ops.identity(conv3d) + if not test.is_gpu_available(cuda_only=True): + self.skipTest('GPU required') + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([2, 2, 14, 14, 1], seed=0) + w = random_ops.truncated_normal([2, 2, 2, 1, 2], seed=0) + strides = [1, 1, 1, 1, 1] + y = gen_nn_ops.conv3d(x, w, strides, 'SAME') + output = array_ops.identity(y) - with session.Session(config=_get_config(False)) as sess: - output_val_ref = sess.run(output) + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output) - with session.Session(config=_get_config()) as sess: - metadata = config_pb2.RunMetadata() - output_val = sess.run(output, run_metadata=metadata) + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) - nodes = [] - num_transposes = 0 - for node in metadata.cost_graph.node: - if _is_transpose(node.name): - num_transposes += 1 - nodes.append(node.name) + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) - expected_num_transposes = 2 - self.assertEqual(expected_num_transposes, num_transposes) - self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) - self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes) - self._assert_trans_ncdhw_to_ndhwc('Conv3D-0-0', nodes) - self.assertAllClose(output_val_ref, output_val, atol=1e-3) + expected_num_transposes = 2 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes) + self._assert_trans_ncdhw_to_ndhwc('Conv3D-0-0', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) @test_util.deprecated_graph_mode_only def testConv3DBackpropInput(self): - if test.is_gpu_available(cuda_only=True): - random_seed.set_random_seed(0) - x = random_ops.truncated_normal([1, 784], seed=0) - conv = _two_layer_model(x) - x_3d = array_ops.reshape(conv, [-1, 4, 14, 14, 1]) - filters = random_ops.truncated_normal([2, 2, 2, 1, 1], seed=0) - strides_val = [1, 1, 1, 1, 1] - shape = array_ops.shape(x_3d) - conv3d_grad = gen_nn_ops.conv3d_backprop_input_v2(shape, filters, x_3d, - strides_val, 'SAME') - output = array_ops.identity(conv3d_grad) + if not test.is_gpu_available(cuda_only=True): + self.skipTest('GPU required') + random_seed.set_random_seed(0) + dy = random_ops.truncated_normal([2, 2, 14, 14, 1], seed=0) + w = random_ops.truncated_normal([2, 2, 2, 1, 1], seed=0) + strides = [1, 1, 1, 1, 1] + x_shape = array_ops.shape(dy) + dx = gen_nn_ops.conv3d_backprop_input_v2(x_shape, w, dy, strides, 'SAME') + output = array_ops.identity(dx) - with session.Session(config=_get_config(False)) as sess: - output_val_ref = sess.run(output) + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output) - with session.Session(config=_get_config()) as sess: - metadata = config_pb2.RunMetadata() - output_val = sess.run(output, run_metadata=metadata) + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) - nodes = [] - num_transposes = 0 - for node in metadata.cost_graph.node: - if _is_transpose(node.name): - num_transposes += 1 - nodes.append(node.name) + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) - expected_num_transposes = 2 - self.assertEqual(expected_num_transposes, num_transposes) - self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) - self._assert_vec_ndhwc_to_ncdhw('Conv3DBackpropInputV2-0', nodes) - self._assert_trans_ndhwc_to_ncdhw('Conv3DBackpropInputV2-2', nodes) - self._assert_trans_ncdhw_to_ndhwc('Conv3DBackpropInputV2-0-0', nodes) - self.assertAllClose(output_val_ref, output_val, atol=1e-3) + expected_num_transposes = 2 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_vec_ndhwc_to_ncdhw('Conv3DBackpropInputV2-0', nodes) + self._assert_trans_ndhwc_to_ncdhw('Conv3DBackpropInputV2-2', nodes) + self._assert_trans_ncdhw_to_ndhwc('Conv3DBackpropInputV2-0-0', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) @test_util.deprecated_graph_mode_only def testConv3DBackpropFilter(self): - if test.is_gpu_available(cuda_only=True): - random_seed.set_random_seed(0) - x = random_ops.truncated_normal([1, 784], seed=0) - conv = _two_layer_model(x) - x_3d = array_ops.reshape(conv, [-1, 4, 14, 14, 1]) - filters = random_ops.truncated_normal([2, 2, 2, 1, 1], seed=0) - strides_val = [1, 1, 1, 1, 1] - shape = constant_op.constant([2, 2, 2, 1, 1], shape=[5]) - conv3d_grad = gen_nn_ops.conv3d_backprop_filter_v2( - x_3d, shape, x_3d, strides_val, 'SAME') - output = array_ops.identity(conv3d_grad) + if not test.is_gpu_available(cuda_only=True): + self.skipTest('GPU required') + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([2, 2, 14, 14, 1], seed=0) + dy = random_ops.truncated_normal([2, 2, 14, 14, 1], seed=0) + strides = [1, 1, 1, 1, 1] + w_shape = constant_op.constant([2, 2, 2, 1, 1], shape=[5]) + dw = gen_nn_ops.conv3d_backprop_filter_v2(x, w_shape, dy, strides, 'SAME') + output = array_ops.identity(dw) - with session.Session(config=_get_config(False)) as sess: - output_val_ref = sess.run(output) + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output) - with session.Session(config=_get_config()) as sess: - metadata = config_pb2.RunMetadata() - output_val = sess.run(output, run_metadata=metadata) + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) - nodes = [] - num_transposes = 0 - for node in metadata.cost_graph.node: - if _is_transpose(node.name): - num_transposes += 1 - nodes.append(node.name) + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) - expected_num_transposes = 2 - self.assertEqual(expected_num_transposes, num_transposes) - self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) - self._assert_trans_ndhwc_to_ncdhw('Conv3DBackpropFilterV2-0', nodes) - self._assert_trans_ndhwc_to_ncdhw('Conv3DBackpropFilterV2-2', nodes) - self.assertAllClose(output_val_ref, output_val, atol=1e-3) + expected_num_transposes = 2 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_ndhwc_to_ncdhw('Conv3DBackpropFilterV2-0', nodes) + self._assert_trans_ndhwc_to_ncdhw('Conv3DBackpropFilterV2-2', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + + @test_util.deprecated_graph_mode_only + def testBiasAddFor5DTensor(self): + if not test.is_gpu_available(cuda_only=True): + self.skipTest('GPU required') + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([2, 2, 14, 14, 1], seed=0) + w = random_ops.truncated_normal([2, 2, 2, 1, 2], seed=0) + b = random_ops.truncated_normal([2], seed=0) + strides = [1, 1, 1, 1, 1] + y = gen_nn_ops.conv3d(x, w, strides, 'SAME') + y = gen_nn_ops.bias_add(y, b, 'NHWC') + output = array_ops.identity(y) + + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + expected_num_transposes = 2 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes) + self._assert_trans_ncdhw_to_ndhwc('BiasAdd-0-0', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + + @test_util.deprecated_graph_mode_only + def testBiasAddGradFor5DTensor(self): + if not test.is_gpu_available(cuda_only=True): + self.skipTest('GPU required') + random_seed.set_random_seed(0) + dy = random_ops.truncated_normal([2, 2, 14, 14, 1], seed=0) + w = random_ops.truncated_normal([2, 2, 2, 1, 1], seed=0) + strides = [1, 1, 1, 1, 1] + dy_shape = array_ops.shape(dy) + dx = gen_nn_ops.conv3d_backprop_input_v2(dy_shape, w, dy, strides, 'SAME') + db = gen_nn_ops.bias_add_grad(dx, 'NHWC') + output = array_ops.identity(db) + + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + # The output of Conv3DBackpropInputV2 won't be converted back to NDHWC + # because of the BiasAddGrad. + expected_num_transposes = 1 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_vec_ndhwc_to_ncdhw('Conv3DBackpropInputV2-0', nodes) + self._assert_trans_ndhwc_to_ncdhw('Conv3DBackpropInputV2-2', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) @test_util.deprecated_graph_mode_only def testSliceWithNonConstAxis(self): @@ -1536,6 +1635,44 @@ class LayoutOptimizerTest(test.TestCase): self._assert_vec_nhwc_to_nchw('Slice-2', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) + @test_util.deprecated_graph_mode_only + def testSliceWithNonConstAxisFor5DTensor(self): + if not test.is_gpu_available(cuda_only=True): + self.skipTest('GPU required') + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([2, 2, 14, 14, 1], seed=0) + w = random_ops.truncated_normal([2, 2, 2, 1, 2], seed=0) + strides = [1, 1, 1, 1, 1] + y = gen_nn_ops.conv3d(x, w, strides, 'SAME') + size = array_ops.placeholder(dtype='int32') + s = array_ops.slice(y, [0, 0, 0, 0, 0], size) + output = array_ops.identity(s) + + size_val = [1, 1, 2, 2, 1] + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output, feed_dict={size: size_val}) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run( + output, run_metadata=metadata, feed_dict={size: size_val}) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + # Four transposes were initially added in the Expand phase of + # LayoutOptimizer; two of them are cancelled out in the Collapse phase. + expected_num_transposes = 2 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes) + self._assert_trans_ncdhw_to_ndhwc('Slice-0-0', nodes) + self._assert_vec_ndhwc_to_ncdhw('Slice-2', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + @test_util.deprecated_graph_mode_only def testStridedSliceWithNonConstAxis(self): if test.is_gpu_available(cuda_only=True): @@ -1722,6 +1859,79 @@ class LayoutOptimizerTest(test.TestCase): self._assert_vec_nchw_to_nhwc('ShapeN-0-0', nodes) self.assertAllEqual(output_val_ref, output_val) + @test_util.deprecated_graph_mode_only + def testShapeNFor5DTensor(self): + if not test.is_gpu_available(cuda_only=True): + self.skipTest('GPU required') + h = array_ops.placeholder(dtype='float32') + x = array_ops.reshape(h, [-1, 2, 14, 14, 1]) + w = random_ops.truncated_normal([2, 2, 2, 1, 2], seed=0) + strides = [1, 1, 1, 1, 1] + y = gen_nn_ops.conv3d(x, w, strides, 'SAME') + shapen = array_ops.shape_n([y, y]) + output = math_ops.add(shapen[0], shapen[1]) + + x_val = [1.7] * 784 + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output, feed_dict={h: x_val}) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata, feed_dict={h: x_val}) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + expected_num_transposes = 1 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes) + self._assert_vec_ncdhw_to_ndhwc('ShapeN-0-0', nodes) + self._assert_vec_ncdhw_to_ndhwc('ShapeN-1-0', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + + @test_util.deprecated_graph_mode_only + def testIdentityNFor4DAnd5DTensors(self): + if not test.is_gpu_available(cuda_only=True): + self.skipTest('GPU required') + h = array_ops.placeholder(dtype='float32') + x = array_ops.reshape(h, [-1, 2, 14, 14, 1]) + w = random_ops.truncated_normal([2, 2, 2, 1, 4], seed=0) + strides = [1, 1, 1, 1, 1] + y = gen_nn_ops.conv3d(x, w, strides, 'SAME') + x1 = array_ops.reshape(h, [-1, 784]) + y1 = _two_layer_model(x1) + outputs = array_ops.identity_n([y1, y]) + new_x0 = array_ops.reshape(outputs[0], [-1, 2, 14, 14, 1]) + new_x1 = array_ops.reshape(outputs[1], [-1, 2, 14, 14, 1]) + output = math_ops.add(new_x0, new_x1) + + x_val = [1.7] * 784 + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output, feed_dict={h: x_val}) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata, feed_dict={h: x_val}) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + expected_num_transposes = 4 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes) + self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) + self._assert_trans_ncdhw_to_ndhwc('IdentityN-1-0', nodes) + self._assert_trans_nchw_to_nhwc('IdentityN-0-0', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + @test_util.deprecated_graph_mode_only def testShapeNFollowedByNotConvertibleNodeReshape(self): if test.is_gpu_available(cuda_only=True): diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 3705da54716..8206231753f 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -138,10 +138,15 @@ class Layer(module.Module, version_utils.LayerVersionSelector): Attributes: name: The name of the layer (string). - dtype: The dtype of the layer's computations and weights. If mixed - precision is used with a `tf.keras.mixed_precision.Policy`, this is - instead just the dtype of the layer's weights, as the computations are - done in a different dtype. + dtype: The dtype of the layer's weights. + variable_dtype: Alias of `dtype`. + compute_dtype: The dtype of the layer's computations. Layers automatically + cast inputs to this dtype which causes the computations and output to also + be in this dtype. When mixed precision is used with a + `tf.keras.mixed_precision.Policy`, this will be different than + `variable_dtype`. + dtype_policy: The layer's dtype policy. See the + `tf.keras.mixed_precision.Policy` documentation for details. trainable_weights: List of variables to be included in backprop. non_trainable_weights: List of variables that should not be included in backprop. @@ -517,7 +522,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): Arguments: name: Variable name. shape: Variable shape. Defaults to scalar if unspecified. - dtype: The type of the variable. Defaults to `self.dtype` or `float32`. + dtype: The type of the variable. Defaults to `self.dtype`. initializer: Initializer instance (callable). regularizer: Regularizer instance (callable). trainable: Boolean, whether the variable should be part of the layer's @@ -2373,6 +2378,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector): mixed precision is used, this is the same as `Layer.dtype`, the dtype of the weights. + Layers automatically cast their inputs to the compute dtype, which causes + computations and the output to be in the compute dtype as well. This is done + by the base Layer class in `Layer.__call__`, so you do not have to insert + these casts if implementing your own layer. + Layers often perform certain internal computations in higher precision when `compute_dtype` is float16 or bfloat16 for numeric stability. The output will still typically be float16 or bfloat16 in such cases. diff --git a/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py b/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py index f1ca255133e..0802669e471 100644 --- a/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py +++ b/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py @@ -395,53 +395,35 @@ _DEFAULT_GROWTH_STEPS = 2000 # pylint: disable=g-classes-have-attributes @keras_export('keras.mixed_precision.LossScaleOptimizer') class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): - """An optimizer that applies loss scaling. + """An optimizer that applies loss scaling to prevent numeric underflow. - Loss scaling is a process that multiplies the loss by a multiplier called the - loss scale, and divides each gradient by the same multiplier. The pseudocode - for this process is: - - ``` - loss = ... - loss *= loss_scale - grads = gradients(loss, vars) - grads /= loss_scale - ``` - - Mathematically, loss scaling has no effect, but can help avoid numerical - underflow in intermediate gradients when float16 tensors are used. By - multiplying the loss, each intermediate gradient will have the same multiplier - applied. - - The loss scale can either be a fixed constant, chosen by the user, or be - dynamically determined. Using a dynamic loss scale is highly recommend and is - the default behavior, as choosing a specific fixed loss scale is difficult. - Every step, the dynamic loss scale is potentially updated to a new value. - Dynamic loss scaling sometimes causes the loss scale to be too high and cause - the gradients to overflow, in which case gradients are not applied to - variables that step. + Loss scaling is a technique to prevent numeric underflow in intermediate + gradients when float16 is used. To prevent underflow, the loss is multiplied + (or "scaled") by a certain factor called the "loss scale", which causes + intermediate gradients to be scaled by the loss scale as well. The final + gradients are divided (or "unscaled") by the loss scale to bring them back to + their original value. `LossScaleOptimizer` wraps another optimizer and applies loss scaling to it. - Loss scaling is applied whenever gradients are computed, either through - `minimize()` or `get_gradients()`. If dynamic, the loss scale is updated - whenever gradients are applied, either through `minimize()` or - `apply_gradients()`. For example: + By default, the loss scale is dynamically updated over time so you do not have + to choose the loss scale. The `minimize` method automatically scales the loss, + unscales the gradients, and updates the loss scale so all you have to do is + wrap your optimizer with a `LossScaleOptimizer` if you use `minimize`. For + example: >>> opt = tf.keras.optimizers.SGD(0.25) >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt) >>> var = tf.Variable(1.) >>> loss_fn = lambda: var ** 2 - >>> # 'minimize' applies loss scaling to the loss and updates the loss sale. + >>> # 'minimize' applies loss scaling and updates the loss sale. >>> opt.minimize(loss_fn, var_list=var) >>> var.numpy() 0.5 - If a `tf.GradientTape` is used to compute gradients instead of - `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, the loss - and gradients must be scaled manually. This can be done by calling - `LossScaleOptimizer.get_scaled_loss` before passing the loss to - `tf.GradientTape`, and `LossScaleOptimizer.get_unscaled_gradients` after - computing the gradients with `tf.GradientTape`. For example: + If a `tf.GradientTape` is used to compute gradients instead of `minimize`, you + must scale the loss and gradients manually. This can be done with the + `LossScaleOptimizer.get_scaled_loss` and + `LossScaleOptimizer.get_unscaled_gradients` methods. For example: >>> with tf.GradientTape() as tape: ... loss = loss_fn() @@ -452,8 +434,18 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): >>> var.numpy() 0.25 + Warning: If you forget to call `get_scaled_loss` or `get_unscaled_gradients` + (or both) when using a `tf.GradientTape`, the model will likely converge to a + worse quality. Please make sure you call each function exactly once. + + When mixed precision with float16 is used, there is typically no risk of + underflow affecting model quality if loss scaling is properly used. See + [the mixed precision guide]( + https://www.tensorflow.org/guide/keras/mixed_precision) for more information + on how to use mixed precision. + Args: - inner_optimizer: The Optimizer instance to wrap. + inner_optimizer: The `tf.keras.optimizers.Optimizer` instance to wrap. dynamic: Bool indicating whether dynamic loss scaling is used. Defaults to True. If True, the loss scale will be dynamically updated over time using an algorithm that keeps the loss scale at approximately its optimal value. @@ -463,11 +455,11 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): performance overhead to dynamic loss scaling compared to fixed loss scaling. initial_scale: The initial loss scale. If `dynamic` is True, this defaults - to 2 ** 15. If `dynamic` is False, this must be specified and acts as the - sole loss scale, as the loss scale does not change over time. When dynamic - loss scaling is used, is better for this to be a very high number, because - a loss scale that is too high gets lowered far more quickly than a loss - scale that is too low gets raised. + to `2 ** 15`. If `dynamic` is False, this must be specified and acts as + the sole loss scale, as the loss scale does not change over time. When + dynamic loss scaling is used, is better for this to be a very high number, + because a loss scale that is too high gets lowered far more quickly than a + loss scale that is too low gets raised. dynamic_growth_steps: With dynamic loss scaling, every `dynamic_growth_steps` steps with finite gradients, the loss scale is doubled. Defaults to 2000. If a nonfinite gradient is encountered, the @@ -476,27 +468,33 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): `LossScaleOptimizer.dynamic_counter`. This argument can only be specified if `dynamic` is True. - To use a fixed loss scale instead of dynamic loss scale, pass `dynamic=False` - and pass the loss scale to `initial_scale`. For example: + `LossScaleOptimizer` will occasionally skip applying gradients to the + variables, in which case the trainable variables will not change that step. + This is done because the dynamic loss scale will sometimes be raised too + high, causing overflow in the gradients. Typically, the first 2 to 15 steps of + the model are skipped as the initial loss scale is very high, but afterwards + steps will only be skipped on average 0.05% of the time (the fraction of steps + skipped is `1 / dynamic_growth_steps`). - >>> opt = tf.keras.mixed_precision.LossScaleOptimizer( - ... tf.keras.optimizers.SGD(), dynamic=False, initial_scale=1024) - >>> opt.loss_scale.numpy() - 1024. + `LossScaleOptimizer` delegates all public `Optimizer` methods to the inner + optimizer. Additionally, in methods `minimize` and `get_gradients, it scales + the loss and unscales the gradients. In methods `minimize` and + `apply_gradients`, it additionally updates the loss scale and skips applying + gradients if any gradient has a nonfinite value. + + ### Hyperparameters Hyperparameters can be accessed and set on the LossScaleOptimizer, which will be delegated to the wrapped optimizer. >>> opt = tf.keras.optimizers.Adam(beta_1=0.8, epsilon=1e-5) - >>> lso = tf.keras.mixed_precision.LossScaleOptimizer(opt) - >>> opt.beta_1 + >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt) + >>> opt.beta_1 # Equivalent to `opt.inner_optimizer.beta_1` 0.8 - >>> lso.beta_1 # Equivalent to `opt.beta_1` - 0.8 - >>> lso.beta_1 = 0.7 # Equivalent to `opt.beta_1 = 0.7` + >>> opt.beta_1 = 0.7 # Equivalent to `opt.inner_optimizer.beta_1 = 0.7` >>> opt.beta_1 0.7 - >>> lso.beta_1 + >>> opt.inner_optimizer.beta_1 0.7 However, accessing or setting non-hyperparameters is not delegated to the @@ -504,19 +502,19 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): `epsilon` is not, as the Adam optimizer only calls `Optimizer._set_hyper` on `beta_1`. - >>> opt.epsilon + >>> opt.inner_optimizer.epsilon 1e-5 - >>> lso.epsilon + >>> opt.epsilon Traceback (most recent call last): ... AttributeError: 'LossScaleOptimizer' object has no attribute 'epsilon' - >>> lso.epsilon = 1e-4 - >>> opt.epsilon + >>> opt.epsilon = 1e-4 # This does NOT set epsilon on `opt.inner_optimizer` + >>> opt.inner_optimizer.epsilon >>> 1e-5 In the above example, despite epsilon being set on the LossScaleOptimizer, the old epsilon value will still be used when training as epsilon was not set on - the Adam optimizer. + the inner optimizer. """ _HAS_AGGREGATE_GRAD = True @@ -562,6 +560,7 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): @property def dynamic(self): + """Bool indicating whether dynamic loss scaling is used.""" return isinstance(self._loss_scale, _DynamicLossScaleState) @property @@ -593,7 +592,8 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): def initial_scale(self): """The initial loss scale. - This is None if `LossScaleOptimizer.dynamic` is False. + If `LossScaleOptimizer.dynamic` is False, this is the same number as + `LossScaleOptimizer.loss_scale`, as the loss scale never changes. """ if isinstance(self._loss_scale, _DynamicLossScaleState): return self._loss_scale.initial_loss_scale @@ -982,6 +982,24 @@ class LossScaleOptimizerV1(LossScaleOptimizer): ... dynamic_growth_steps=500) >>> assert opt1.get_config() == opt2.get_config() + Make sure to also switch from this class to the non-experimental class in + isinstance checks, if you have any. If you do not do this, your model may run + into hard-to-debug issues, as the experimental `LossScaleOptimizer` subclasses + the non-experimental `LossScaleOptimizer`, but not vice versa. It is safe to + switch isinstance checks to the non-experimental `LossScaleOptimizer` even + before using the non-experimental `LossScaleOptimizer`. + + >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer( + ... tf.keras.optimizers.SGD(), loss_scale='dynamic') + >>> # The experimental class subclasses the non-experimental class + >>> isinstance(opt1, tf.keras.mixed_precision.LossScaleOptimizer) + True + >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer( + ... tf.keras.optimizers.SGD()) + >>> # The non-experimental class does NOT subclass the experimental class. + >>> isinstance(opt2, tf.keras.mixed_precision.experimental.LossScaleOptimizer) + False + Args: optimizer: The Optimizer instance to wrap. loss_scale: The loss scale to scale the loss and gradients. This can diff --git a/tensorflow/python/keras/mixed_precision/policy.py b/tensorflow/python/keras/mixed_precision/policy.py index 6b4e2417c35..e0b75b0a1d1 100644 --- a/tensorflow/python/keras/mixed_precision/policy.py +++ b/tensorflow/python/keras/mixed_precision/policy.py @@ -32,6 +32,7 @@ from tensorflow.python.training.experimental import mixed_precision_global_state from tensorflow.python.util.tf_export import keras_export +# pylint: disable=g-classes-have-attributes @keras_export('keras.mixed_precision.Policy', v1=[]) class Policy(object): """A dtype policy for a Keras layer. @@ -39,106 +40,57 @@ class Policy(object): A dtype policy determines a layer's computation and variable dtypes. Each layer has a policy. Policies can be passed to the `dtype` argument of layer constructors, or a global policy can be set with - `tf.keras.mixed_precision.experimental.set_policy`. A layer will default to - the global policy if no policy is passed to it's constructor. + `tf.keras.mixed_precision.set_global_policy`. - For many models, each layer's policy will have the same compute dtype and - variable dtype, which will typically be float32. In this case, we refer to the - singular dtype as the layer's dtype, which can be queried by the property - `tf.keras.layers.Layer.dtype`. + Args: + name: The policy name, which determines the compute and variable dtypes. Can + be any dtype name, such as `'float32'` or `'float64'`, which causes both + the compute and variable dtypes will be that dtype. Can also be the string + `'mixed_float16'` or `'mixed_bfloat16'`, which causes the compute dtype to + be float16 or bfloat16 and the variable dtype to be float32. - When mixed precision training is used, most layers will instead have a float16 - or bfloat16 compute dtype and a float32 variable dtype, and so the layer does - not have a single dtype. When the variable dtype does not match the compute - dtype, variables will be automatically casted to the compute dtype to avoid - type errors. In this case, `tf.keras.layers.Layer.dtype` refers to the - variable dtype, not the compute dtype. See [the mixed precision guide]( - https://www.tensorflow.org/guide/keras/mixed_precision) for more - information on how to use mixed precision. + Typically you only need to interact with dtype policies when using mixed + precision, which is the use of float16 or bfloat16 for computations and + float32 for variables. This is why the term `mixed_precision` appears in the + API name. Mixed precision can be enabled by passing `'mixed_float16'` or + `'mixed_bfloat16'` to `tf.keras.mixed_precision.set_global_policy`. See [the + mixed precision guide](https://www.tensorflow.org/guide/keras/mixed_precision) + for more information on how to use mixed precision. - Policies are constructed by passing a string to the constructor, e.g. - `tf.keras.mixed_precision.Policy('float32')`. The string determines the - compute and variable dtypes. It can be one of the following: + >>> tf.keras.mixed_precision.set_global_policy('mixed_float16') + >>> layer1 = tf.keras.layers.Dense(10) + >>> layer1.dtype_policy # `layer1` will automatically use mixed precision + <Policy "mixed_float16"> + >>> # Can optionally override layer to use float32 instead of mixed precision. + >>> layer2 = tf.keras.layers.Dense(10, dtype='float32') + >>> layer2.dtype_policy + <Policy "float32"> + >>> # Set policy back to initial float32 for future examples. + >>> tf.keras.mixed_precision.set_global_policy('float32') - * Any dtype name, such as 'float32' or 'float64'. Both the variable and - compute dtypes will be that dtype. - * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or - bfloat16, while the variable dtype is float32. With 'mixed_float16', - `tf.keras.Model.compile` will wrap the optimizer with a - `tf.keras.mixed_precision.LossScaleOptimizer`. These policies are used for - mixed precision training. + In the example above, passing `dtype='float32'` to the layer is equivalent to + passing `dtype=tf.keras.mixed_precision.Policy('float32')`. In general, + passing a dtype to a layer is equivalent to passing the corresponding policy, + so it is never necessary to explicitly construct a `Policy` object. - ### How to use mixed precision in a Keras model - - To use mixed precision in a Keras model, the `'mixed_float16'` or - `'mixed_bfloat16'` policy can be used. - `tf.keras.mixed_precision.experimental.set_policy` can be used to set the - default policy for layers if no policy is passed to them. For example: - - >>> tf.keras.mixed_precision.experimental.set_policy('mixed_float16') - >>> model = tf.keras.models.Sequential([ - ... tf.keras.layers.Input((100,)), - ... # Dense layers use global policy of 'mixed_float16', which does - ... # computations in float16 while keeping variables in float32. - ... tf.keras.layers.Dense(10), - ... tf.keras.layers.Dense(10), - ... # Softmax should be done in float32 for numeric stability. We pass - ... # dtype='float32' to use float32 instead of the global policy. - ... tf.keras.layers.Activation('softmax', dtype='float32') - ... ]) - - Alternatively, the policy can be passed to individual layers instead of - setting the global policy with `set_policy`: - - >>> policy = tf.keras.mixed_precision.Policy('mixed_float16') - >>> model = tf.keras.models.Sequential([ - ... tf.keras.layers.Input((100,)), - ... tf.keras.layers.Dense(10, dtype=policy), - ... tf.keras.layers.Dense(10, dtype=policy), - ... # Softmax should be done in float32 for numeric stability. - ... tf.keras.layers.Activation('softmax', dtype='float32') - ... ]) - - Note the `'mixed_float16'` policy will apply loss scaling by default in - `Model.fit`, `Model.train_on_batch`, and other training methods. If no such - method is used (e.g., a custom training loop is used) and `'mixed_float16'` is - used, the loss scale must be manually applied. See - `tf.keras.mixed_precision.LossScaleOptimizer` for details. For - `'mixed_bfloat16'`, no loss scaling is done and loss scaling never needs to be - manually applied. - - See [the mixed precision guide]( - https://www.tensorflow.org/guide/keras/mixed_precision) for more - information on using mixed precision - - ### How to use float64 in a Keras model - - Using float64 is similar to mixed precision. Either the global policy can be - set to float64, or `dtype='float64'` can be passed to individual layers. For - example, to set the global policy: - - >>> tf.keras.mixed_precision.experimental.set_policy('float64') - >>> model = tf.keras.models.Sequential([ - ... tf.keras.layers.Input((100,)), - ... # All layers use global policy of 'float64', which does computations - ... # and creates variables in float64. - ... tf.keras.layers.Dense(10), - ... tf.keras.layers.Dense(10), - ... tf.keras.layers.Activation('softmax') - ... ]) - >>> # Optionally set policy back to float32 if any other models use float32 - >>> tf.keras.mixed_precision.experimental.set_policy('float32') + Note: `Model.compile` will automatically wrap an optimizer with a + `tf.keras.mixed_precision.LossScaleOptimizer` if you use the `'mixed_float16'` + policy. If you use a custom training loop instead of calling `Model.compile`, + you should explicitly use a `tf.keras.mixed_precision.LossScaleOptimizer` to + avoid numeric underflow with float16. ### How a layer uses its policy's compute dtype - A layer will cast its inputs to its compute dtype in TensorFlow 2. For - example: + A layer casts its inputs to its compute dtype. This causes the layer's + computations and output to also be in the compute dtype. For example: >>> x = tf.ones((4, 4, 4, 4), dtype='float64') >>> # `layer`'s policy defaults to float32. >>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2) - >>> # `layer` casts it's inputs to its compute dtype, which is float32, and - >>> # does computations in float32. + >>> layer.compute_dtype # Equivalent to layer.dtype_policy.compute_dtype + 'float32' + >>> # `layer` casts its inputs to its compute dtype and does computations in + >>> # that dtype. >>> y = layer(x) >>> y.dtype tf.float32 @@ -147,7 +99,8 @@ class Policy(object): subclassing your own layer, you do not have to insert any casts. Currently, only tensors in the first argument to the layer's `call` method are - casted. For example: + casted (although this will likely be changed in a future minor release). For + example: >>> class MyLayer(tf.keras.layers.Layer): ... # Bug! `b` will not be casted. @@ -162,45 +115,13 @@ class Policy(object): >>> y.dtype tf.float32 - If writing your own layer, it is recommended to accept tensors only in the - first argument. This way, all tensors are casted to the layer's compute dtype. - `MyLayer` should therefore be written as: + If writing your own layer with multiple inputs, you should either explicitly + cast other tensors to `self.compute_dtype` in `call` or accept all tensors in + the first argument as a list. - >>> class MyLayer(tf.keras.layers.Layer): - ... # Now, all tensor inputs will be casted. - ... def call(self, inputs): - ... a, b = inputs - ... return a + 1., b + 1. - >>> a = tf.constant(1., dtype="float32") - >>> b = tf.constant(1., dtype="float32") - >>> layer = MyLayer(dtype="float64") - >>> x, y = layer((a, b)) - >>> x.dtype - tf.float64 - >>> y.dtype - tf.float64 - - Other arguments are not automatically casted for technical reasons, but this - may change in a future minor release. - - The casting only occurs in TensorFlow 2, but can be enabled if - `tf.compat.v1.disable_v2_behavior()` has been called with - `tf.compat.v1.keras.layers.enable_v2_dtype_behavior()`. - - A layer subclass can prevent its inputs from being autocasted by passing - `autocast=False` to the layer constructor. For example: - - >>> class NonAutoCastingLayer(tf.keras.layers.Layer): - ... def __init__(self, **kwargs): - ... kwargs['autocast'] = False - ... super(NonAutoCastingLayer, self).__init__(**kwargs) - ... def call(self, inp): - ... return inp - >>> x = tf.ones((4, 4, 4, 4), dtype='float32') - >>> layer = NonAutoCastingLayer(dtype='float64') - >>> y = layer(x) # Will not cast inputs to it's compute dtype of float64 - >>> y.dtype - tf.float32 + The casting only occurs in TensorFlow 2. If + `tf.compat.v1.disable_v2_behavior()` has been called, you can enable the + casting behavior with `tf.compat.v1.keras.layers.enable_v2_dtype_behavior()`. ### How a layer uses its policy's variable dtype @@ -209,30 +130,33 @@ class Policy(object): If a layer's compute and variable dtypes differ, `add_weight` will wrap floating-point variables with a special wrapper called an `AutoCastVariable`. - This wrapper is identical to the original variable except it casts itself to - the layer's compute dtype when used within `Layer.call`. Outside `Layer.call`, - the variable is not casted. + `AutoCastVariable` is identical to the original variable except it casts + itself to the layer's compute dtype when used within `Layer.call`. This means + if you are writing a layer, you do not have to explicitly cast the variables + to the layer's compute dtype. For example: + + >>> class SimpleDense(tf.keras.layers.Layer): + ... + ... def build(self, input_shape): + ... # With mixed precision, self.kernel is a float32 AutoCastVariable + ... self.kernel = self.add_weight('kernel', (input_shape[-1], 10)) + ... + ... def call(self, inputs): + ... # With mixed precision, self.kernel will be casted to float16 + ... return tf.linalg.matmul(inputs, self.kernel) + ... + >>> dtype_policy = tf.keras.mixed_precision.Policy('mixed_float16') + >>> layer = SimpleDense(dtype=dtype_policy) + >>> y = layer(tf.ones((10, 10))) + >>> y.dtype + tf.float16 + >>> layer.kernel.dtype + tf.float32 A layer author can prevent a variable from being wrapped with an - `AutoCastVariable` by passing `experimental_autocast=False` to `add_weight`: - - >>> class MyLayer(tf.keras.layers.Layer): - ... def build(self, input_shape): - ... self.x = self.add_weight('x') - ... self.y = self.add_weight('y', experimental_autocast=False) - >>> policy = tf.keras.mixed_precision.Policy('mixed_float16') - >>> layer = MyLayer(dtype=policy) - >>> layer.build((2, 2)) - >>> layer.x - <AutoCastVariable 'x:0' shape=() dtype=float32 dtype_to_cast_to=float32, - numpy=...> - >>> layer.y - <tf.Variable 'y:0' shape=() dtype=float32, numpy=...> - - Passing `experimental_autocast=False` is useful for layers which may - internally do some math in the variable dtype instead of the compute dtype. - For example, you may wish to compute variable statistics, such as mean and - variance, in the variable dtype. + `AutoCastVariable` by passing `experimental_autocast=False` to `add_weight`, + which is useful if the float32 value of the variable must be accessed within + the layer. ### How to write a layer that supports mixed precision and float64. @@ -241,69 +165,33 @@ class Policy(object): automatically casts inputs, creates variables of the correct type, and in the case of mixed precision, wraps variables with `AutoCastVariables`. - For example, this simple dense layer does not require any additional work to - support mixed precision or float64. Keras automatically casts the inputs and - variable to the appropriate dtype. - - >>> class MyDense(tf.keras.layers.Layer): - ... def build(self, input_shape): - ... self.kernel = self.add_weight('kernel', (input_shape[-1], 10)) - ... def call(self, inputs): - ... return tf.matmul(inputs, self.kernel) - - >>> policy = tf.keras.mixed_precision.Policy('mixed_float16') - >>> layer = MyDense(dtype=policy) - >>> x = np.random.rand(10, 10) - >>> y = layer(x) - >>> y.dtype - tf.float16 - The primary case where you need extra work to support mixed precision or float64 is when you create a new tensor, such as with `tf.ones` or - `tf.constant`. In such cases, you must create the tensor of the correct dtype. - For example, suppose you modify the `MyDense` layer to add a random number to - the output using `tf.random.normal`. You must pass the input dtype to - `tf.random.normal` to ensure the dtypes match. + `tf.random.normal`, In such cases, you must create the tensor of the correct + dtype. For example, if you call `tf.random.normal`, you must pass the compute + dtype, which is the dtype the inputs have been casted to: - >>> class MyDense(tf.keras.layers.Layer): - ... def build(self, input_shape): - ... self.kernel = self.add_weight('kernel', (input_shape[-1], 10)) + >>> class AddRandom(tf.keras.layers.Layer): + ... ... def call(self, inputs): + ... # We must pass `dtype=inputs.dtype`, otherwise a TypeError may + ... # occur when adding `inputs` to `rand`. ... rand = tf.random.normal(shape=inputs.shape, dtype=inputs.dtype) - ... return tf.matmul(inputs, self.kernel) + rand - >>> - >>> layer = MyDense(dtype=policy) + ... return inputs + rand + + >>> dtype_policy = tf.keras.mixed_precision.Policy('mixed_float16') + >>> layer = AddRandom(dtype=dtype_policy) >>> y = layer(x) >>> y.dtype tf.float16 - If you did not pass `dtype=inputs.dtype` to `tf.random.normal`, a `TypeError` - would have occurred. This is because the dtype defaults to `"float32"`, so the - layer would only work if the inputs were float32. + If you did not pass `dtype=inputs.dtype` to `tf.random.normal`, a + `TypeError` would have occurred. This is because the `tf.random.normal`'s + dtype defaults to `"float32"`, but the input dtype is float16. You cannot add + a float32 tensor with a float16 tensor. """ def __init__(self, name): - """Constructs the policy. - - The `name` argument determines the compute and variable dtype. The compute - and variable dtypes can only be specified through `name`, and cannot be - specified directly. - - `name` is also used by `tf.keras.Model.compile`. If `name` is - `"mixed_float16"`, `tf.keras.Model.compile` will automatically wrap the - optimizer with a LossScaleOptimizer if it is not already a - LossScaleOptimizer. - - Args: - name: A string. Can be one of the following values: - * Any dtype name, such as 'float32' or 'float64'. Both the variable and - compute dtypes will be that dtype. - * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or - bfloat16, while the variable dtype is float32. With 'mixed_float16', - `tf.keras.Model.compile` will wrap the optimizer with a - `tf.keras.mixed_precision.LossScaleOptimizer. These policies are used - for mixed precision training. - """ if isinstance(name, dtypes.DType): raise TypeError("'name' must be a string, not a DType. " "Instead, pass DType.name. Got: %s" % (name.name,)) @@ -372,8 +260,10 @@ class Policy(object): `Policy.compute_dtype`, Layers will cast variables to the compute dtype to avoid type errors. + Variable regularizers are run in the variable dtype, not the compute dtype. + Returns: - The variable dtype of this policy. + The variable dtype of this policy, as a string. """ return self._variable_dtype @@ -381,26 +271,27 @@ class Policy(object): def compute_dtype(self): """The compute dtype of this policy. - This is the dtype layers will do their computations in. + This is the dtype layers will do their computations in. Typically layers + output tensors with the compute dtype as well. Note that even if the compute dtype is float16 or bfloat16, hardware devices may not do individual adds, multiplies, and other fundamental operations in - [b]float16, but instead may do some of them in float32 for numeric + float16 or bfloat16, but instead may do some of them in float32 for numeric stability. The compute dtype is the dtype of the inputs and outputs of the TensorFlow ops that the layer executes. Internally, many TensorFlow ops will - do certain internal calculations in float32, or some other device-internal - intermediate format with higher precision than [b]float16, to increase + do certain internal calculations in float32 or some other device-internal + intermediate format with higher precision than float16/bfloat16, to increase numeric stability. For example, a `tf.keras.layers.Dense` layer, when run on a GPU with a - float16 compute dtype, will pass float16 inputs to tf.matmul. But, tf.matmul - will do use float32 intermediate math. The performance benefit of float16 is - still apparent, due to increased memory bandwidth and the fact modern GPUs - have specialized hardware for computing matmuls on float16 while still - keeping intermediate computations in float32. + float16 compute dtype, will pass float16 inputs to `tf.linalg.matmul`. But, + `tf.linalg.matmul` will do use float32 intermediate math. The performance + benefit of float16 is still apparent, due to increased memory bandwidth and + the fact modern GPUs have specialized hardware for computing matmuls on + float16 inputs while still keeping intermediate computations in float32. Returns: - The compute dtype of this policy. + The compute dtype of this policy, as a string. """ return self._compute_dtype @@ -528,13 +419,18 @@ _global_policy = None @keras_export('keras.mixed_precision.global_policy', 'keras.mixed_precision.experimental.global_policy', v1=[]) def global_policy(): - """Returns the global Policy. + """Returns the global dtype policy. - The global policy is the default policy used for layers, if no policy is - passed to the layer constructor. If no policy has been set with - `keras.mixed_precision.experimental.set_policy`, this will return a policy + The global policy is the default `tf.keras.mixed_precision.Policy` used for + layers, if no policy is passed to the layer constructor. If no policy has been + set with `keras.mixed_precision.set_global_policy`, this will return a policy constructed from `tf.keras.backend.floatx()` (floatx defaults to float32). + >>> tf.keras.mixed_precision.global_policy() + <Policy "float32"> + >>> tf.keras.layers.Dense(10).dtype_policy # Defaults to the global policy + <Policy "float32"> + If TensorFlow 2 behavior has been disabled with `tf.compat.v1.disable_v2_behavior()`, this will instead return a special "_infer" policy which infers the dtype from the dtype of the first input the @@ -573,11 +469,27 @@ def _check_if_mixed_precision_graph_rewrite_is_enabled(policy): @keras_export('keras.mixed_precision.set_global_policy', 'keras.mixed_precision.experimental.set_policy', v1=[]) def set_policy(policy): - """Sets the global Policy. + """Sets the global dtype policy. - The global policy is the default policy used for layers, if no policy is - passed to the layer constructor. If no global policy is set, layers will - instead default to a Policy constructed from `tf.keras.backend.floatx()`. + The global policy is the default `tf.keras.mixed_precision.Policy` used for + layers, if no policy is passed to the layer constructor. + + >>> tf.keras.mixed_precision.set_global_policy('mixed_float16') + >>> tf.keras.mixed_precision.global_policy() + <Policy "mixed_float16"> + >>> tf.keras.layers.Dense(10).dtype_policy + <Policy "mixed_float16"> + >>> # Global policy is not used if a policy is directly passed to constructor + >>> tf.keras.layers.Dense(10, dtype='float64').dtype_policy + <Policy "float64"> + >>> tf.keras.mixed_precision.set_global_policy('float32') + + If no global policy is set, layers will instead default to a Policy + constructed from `tf.keras.backend.floatx()`. + + To use mixed precision, the global policy should be set to `'mixed_float16'` + or `'mixed_bfloat16'`, so that every layer uses a 16-bit compute dtype and + float32 variable dtype by default. Only floating point policies can be set as the global policy, such as `'float32'` and `'mixed_float16'`. Non-floating point policies such as @@ -587,7 +499,9 @@ def set_policy(policy): See `tf.keras.mixed_precision.Policy` for more information. Args: - policy: A Policy, or a string that will be converted to a Policy.. + policy: A Policy, or a string that will be converted to a Policy. Can also + be None, in which case the global policy will be constructed from + `tf.keras.backend.floatx()` """ global _global_policy if not base_layer_utils.v2_dtype_behavior_enabled(): diff --git a/tensorflow/python/keras/saving/BUILD b/tensorflow/python/keras/saving/BUILD index 2e462662e91..14b78a10dfe 100644 --- a/tensorflow/python/keras/saving/BUILD +++ b/tensorflow/python/keras/saving/BUILD @@ -49,6 +49,7 @@ py_library( deps = [ "//tensorflow/python:lib", "//tensorflow/python:math_ops", + "//tensorflow/python:platform", "//tensorflow/python:saver", "//tensorflow/python:tensor_spec", "//tensorflow/python/eager:def_function", diff --git a/tensorflow/python/keras/saving/saved_model/constants.py b/tensorflow/python/keras/saving/saved_model/constants.py index 3f1eca9c500..12265e0a3f3 100644 --- a/tensorflow/python/keras/saving/saved_model/constants.py +++ b/tensorflow/python/keras/saving/saved_model/constants.py @@ -26,3 +26,7 @@ KERAS_ATTR = 'keras_api' # Keys for the serialization cache. # Maps to the keras serialization dict {Layer --> SerializedAttributes object} KERAS_CACHE_KEY = 'keras_serialized_attributes' + + +# Name of Keras metadata file stored in the SavedModel. +SAVED_METADATA_PATH = 'keras_metadata.pb' diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py index cb6d340ea03..43c1d2bd0d4 100644 --- a/tensorflow/python/keras/saving/saved_model/load.py +++ b/tensorflow/python/keras/saving/saved_model/load.py @@ -17,9 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import re import types +from google.protobuf import message + from tensorflow.core.framework import versions_pb2 from tensorflow.python.eager import context from tensorflow.python.eager import function as defun @@ -38,6 +41,7 @@ from tensorflow.python.keras.saving.saved_model.serialized_attributes import Com from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import metrics_utils from tensorflow.python.keras.utils.generic_utils import LazyLoader +from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import load as tf_load from tensorflow.python.saved_model import loader_impl @@ -121,13 +125,26 @@ def load(path, compile=True, options=None): # pylint: disable=redefined-builtin # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics. # TODO(kathywu): Add code to load from objects that contain all endpoints - # The Keras metadata file is not yet saved, so create it from the SavedModel. + # Look for metadata file or parse the SavedModel metadata = saved_metadata_pb2.SavedMetadata() meta_graph_def = loader_impl.parse_saved_model(path).meta_graphs[0] object_graph_def = meta_graph_def.object_graph_def - # TODO(kathywu): When the keras metadata file is saved, load it directly - # instead of calling the _read_legacy_metadata function. - _read_legacy_metadata(object_graph_def, metadata) + path_to_metadata_pb = os.path.join(path, constants.SAVED_METADATA_PATH) + if gfile.Exists(path_to_metadata_pb): + try: + with gfile.GFile(path_to_metadata_pb, 'rb') as f: + file_content = f.read() + metadata.ParseFromString(file_content) + except message.DecodeError as e: + raise IOError('Cannot parse keras metadata {}: {}.' + .format(path_to_metadata_pb, str(e))) + else: + logging.warning('SavedModel saved prior to TF 2.4 detected when loading ' + 'Keras model. Please ensure that you are saving the model ' + 'with model.save() or tf.keras.models.save_model(), *NOT* ' + 'tf.saved_model.save(). To confirm, there should be a file ' + 'named "keras_metadata.pb" in the SavedModel directory.') + _read_legacy_metadata(object_graph_def, metadata) if not metadata.nodes: # When there are no Keras objects, return the results from the core loader diff --git a/tensorflow/python/keras/saving/saved_model/save.py b/tensorflow/python/keras/saving/saved_model/save.py index 16984a2221b..2ab7ebb60b1 100644 --- a/tensorflow/python/keras/saving/saved_model/save.py +++ b/tensorflow/python/keras/saving/saved_model/save.py @@ -18,15 +18,21 @@ from __future__ import division from __future__ import print_function import os + +from tensorflow.core.framework import versions_pb2 from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.keras import backend as K +from tensorflow.python.keras.protobuf import saved_metadata_pb2 from tensorflow.python.keras.saving import saving_utils +from tensorflow.python.keras.saving.saved_model import constants from tensorflow.python.keras.saving.saved_model import save_impl from tensorflow.python.keras.saving.saved_model import utils from tensorflow.python.keras.utils.generic_utils import LazyLoader from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite +from tensorflow.python.platform import gfile from tensorflow.python.saved_model import save as save_lib + # To avoid circular dependencies between keras/engine and keras/saving, # code in keras/saving must delay imports. @@ -86,7 +92,39 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None, # we use the default replica context here. with distribution_strategy_context._get_default_replica_context(): # pylint: disable=protected-access with utils.keras_option_scope(save_traces): - save_lib.save(model, filepath, signatures, options) + saved_nodes, node_paths = save_lib.save_and_return_nodes( + model, filepath, signatures, options) + + # Save all metadata to a separate file in the SavedModel directory. + metadata = generate_keras_metadata(saved_nodes, node_paths) + + with gfile.GFile( + os.path.join(filepath, constants.SAVED_METADATA_PATH), "wb") as w: + w.write(metadata.SerializeToString(deterministic=True)) if not include_optimizer: model.optimizer = orig_optimizer + + +def generate_keras_metadata(saved_nodes, node_paths): + """Constructs a KerasMetadata proto with the metadata of each keras object.""" + metadata = saved_metadata_pb2.SavedMetadata() + + for node_id, node in enumerate(saved_nodes): + if isinstance(node, base_layer.Layer): + path = node_paths[node] + if not path: + node_path = "root" + else: + node_path = "root.{}".format( + ".".join([ref.name for ref in path])) + + metadata.nodes.add( + node_id=node_id, + node_path=node_path, + version=versions_pb2.VersionDef( + producer=1, min_consumer=1, bad_consumers=[]), + identifier=node._object_identifier, # pylint: disable=protected-access + metadata=node._tracking_metadata) # pylint: disable=protected-access + + return metadata diff --git a/tensorflow/python/kernel_tests/collective_ops_test.py b/tensorflow/python/kernel_tests/collective_ops_test.py index fe558bcae64..fdaf3213759 100644 --- a/tensorflow/python/kernel_tests/collective_ops_test.py +++ b/tensorflow/python/kernel_tests/collective_ops_test.py @@ -31,10 +31,12 @@ from tensorflow.python.distribute import test_util from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops as _collective_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test @@ -68,6 +70,19 @@ device_combination = ( device='GPU', communication=['RING', 'NCCL'], required_gpus=2)) +collective_op_combinations = combinations.times( + combinations.combine( + collective_op=[ + combinations.NamedObject('all_reduce', CollectiveOpsV1.all_reduce), + combinations.NamedObject('all_reduce_v2', + CollectiveOpsV2.all_reduce), + combinations.NamedObject('all_gather', CollectiveOpsV1.all_gather), + combinations.NamedObject('all_gather_v2', + CollectiveOpsV2.all_gather), + ], + mode='eager'), device_combination) + + @combinations.generate( combinations.times( combinations.combine( @@ -283,20 +298,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): run_and_assert(group_size=3, group_key=2) -@combinations.generate( - combinations.times( - combinations.combine( - collective_op=[ - combinations.NamedObject('all_reduce', - CollectiveOpsV1.all_reduce), - combinations.NamedObject('all_reduce_v2', - CollectiveOpsV2.all_reduce), - combinations.NamedObject('all_gather', - CollectiveOpsV1.all_gather), - combinations.NamedObject('all_gather_v2', - CollectiveOpsV2.all_gather), - ], - mode='eager'), device_combination)) +@combinations.generate(collective_op_combinations) class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase): def setUp(self): @@ -647,20 +649,7 @@ class OpCancellationTest(test.TestCase, parameterized.TestCase): collective_fn() -@combinations.generate( - combinations.times( - combinations.combine( - collective_op=[ - combinations.NamedObject('all_reduce', - CollectiveOpsV1.all_reduce), - combinations.NamedObject('all_reduce_v2', - CollectiveOpsV2.all_reduce), - combinations.NamedObject('all_gather', - CollectiveOpsV1.all_gather), - combinations.NamedObject('all_gather_v2', - CollectiveOpsV2.all_gather), - ], - mode='eager'), device_combination)) +@combinations.generate(collective_op_combinations) class TimeoutTest(test.TestCase, parameterized.TestCase): def setUp(self): @@ -785,6 +774,94 @@ class TimeoutTest(test.TestCase, parameterized.TestCase): communication_hint=communication) +@combinations.generate( + combinations.times( + combinations.combine( + collective_op=[ + combinations.NamedObject('all_reduce_v2', + CollectiveOpsV2.all_reduce), + combinations.NamedObject('all_gather_v2', + CollectiveOpsV2.all_gather), + ], + mode='eager'), device_combination)) +class OrderingTest(test.TestCase, parameterized.TestCase): + + def setUp(self): + _setup_context() + super().setUp() + + def testOrdering(self, collective_op, device, communication): + dev0 = '/device:%s:0' % device + dev1 = '/device:%s:1' % device + group_size = 2 + group_key = 100 + instance_key = 100 + in_tensor = constant_op.constant([1.]) + + with ops.device(dev0): + token0 = resource_variable_ops.ResourceVariable(0.) + with ops.device(dev1): + token1 = resource_variable_ops.ResourceVariable(0.) + + @def_function.function + def f(): + # Launch the first collective with token. + with ops.device(dev0): + collective_op( + in_tensor, + group_size, + group_key, + instance_key, + ordering_token=token0.handle) + with ops.device(dev1): + collective_op( + in_tensor, + group_size, + group_key, + instance_key, + ordering_token=token1.handle) + # Launch the second collective without token. + with ops.device(dev0): + collective_op(in_tensor, group_size, group_key, instance_key) + with ops.device(dev1): + collective_op(in_tensor, group_size, group_key, instance_key) + # Launch the third collective with token. + with ops.device(dev0): + collective_op( + in_tensor, + group_size, + group_key, + instance_key, + ordering_token=token0.handle) + with ops.device(dev1): + collective_op( + in_tensor, + group_size, + group_key, + instance_key, + ordering_token=token1.handle) + + graph = f.get_concrete_function().graph + for device in [dev0, dev1]: + # Try to find the third collective, which should have the first collective + # as a control input. + third = None + for op in graph.get_operations(): + if (op.type.startswith('Collective') and op.device.endswith(device) and + op.control_inputs and + op.control_inputs[0].type.startswith('Collective')): + self.assertIsNone(third) + third = op + self.assertIsNotNone(third) + # Verify it's not the second collective by looking at the inputs. + self.assertTrue(any(v.dtype == dtypes.resource for v in third.inputs)) + first = third.control_inputs[0] + self.assertEqual(third.device, first.device) + # Verify it's not the second collective by looking at the inputs. + self.assertTrue(any(v.dtype == dtypes.resource for v in first.inputs)) + self.assertEmpty(first.control_inputs) + + def _setup_context(): context._reset_context() test_util.set_logical_devices_to_at_least('CPU', 4) diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py index 737ca777804..33e84b3ca19 100644 --- a/tensorflow/python/kernel_tests/matmul_op_test.py +++ b/tensorflow/python/kernel_tests/matmul_op_test.py @@ -197,14 +197,16 @@ class MatMulInfixOperatorTest(test_lib.TestCase): def testMismatchedShape(self): with self.assertRaisesRegex( - Exception, "(Shape must be rank 2 but is rank 1|is not a matrix)"): + Exception, (r"(In\[0\] and In\[1\] has different ndims|In\[0\] " + r"ndims must be >= 2|Shape must be rank 2 but is rank 1)")): infix_matmul( ops.convert_to_tensor([10.0, 20.0, 30.0]), ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]])) def testMismatchedDimensions(self): with self.assertRaisesRegex( - Exception, "(Dimensions must be equal|Matrix size-incompatible)"): + Exception, + r"(In\[0\] mismatch In\[1\] shape|Dimensions must be equal)"): infix_matmul( ops.convert_to_tensor([[10.0, 20.0, 30.0]]), ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]])) @@ -234,9 +236,10 @@ if __name__ == "__main__": # TF2 does not support placeholders under eager so we skip it for use_static_shape in set([True, tf2.enabled()]): for dtype in dtypes_to_test: - if not use_static_shape and (dtype == np.int32 or dtype == np.int64): - # TODO(rmlarsen): Re-enable this test when we have fixed the underlying - # bug in Windows (b/35935459). + if test_util.is_xla_enabled() and (dtype == np.int32 or + dtype == np.int64): + # TODO(b/171924639): Enable this test when XLA DOT supports + # integer types. continue for m in sizes: for n in sizes: diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py index 368a7f18f8b..268f6891d4e 100644 --- a/tensorflow/python/kernel_tests/tensordot_op_test.py +++ b/tensorflow/python/kernel_tests/tensordot_op_test.py @@ -55,8 +55,8 @@ class TensordotTest(test_lib.TestCase): if context.executing_eagerly(): return with self.cached_session() as sess: - with self.assertRaisesRegex(errors_impl.InvalidArgumentError, - "Matrix size-incompatible"): + with self.assertRaisesOpError( + r"In\[0\] mismatch In\[1\] shape: 2 vs\. 3: \[2,2\] \[3,2\]"): a_ph = array_ops.placeholder(dtypes.float32) b_ph = array_ops.placeholder(dtypes.float32) axes_ph = array_ops.placeholder(dtypes.int32) diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py index 0402e129c19..f3902fb28f3 100644 --- a/tensorflow/python/kernel_tests/while_v2_test.py +++ b/tensorflow/python/kernel_tests/while_v2_test.py @@ -42,6 +42,7 @@ from tensorflow.python.ops import control_flow_util_v2 from tensorflow.python.ops import control_flow_v2_toggles from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gradient_checker_v2 from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import list_ops from tensorflow.python.ops import map_fn @@ -171,6 +172,126 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): self.assertAllEqual(fnWithLoop(), 4.0) + def checkIteratedGradients(self, func): + with context.eager_mode(): + + def _Grad(f): + def _GradFunction(primal): + with backprop.GradientTape() as tape: + tape.watch(primal) + primal_out = f(primal) + return tape.gradient(primal_out, primal) + return _GradFunction + + f = func + one = constant_op.constant(1.) + + for _ in range(3): + theoretical, numerical = gradient_checker_v2.compute_gradient( + def_function.function(f), [one]) + self.assertAllClose(theoretical, numerical, rtol=1e-3) + f = _Grad(f) + self.assertAllClose(array_ops.reshape(numerical, []), + def_function.function(f)(one), + rtol=1e-3) + + def testIteratedGradients(self): + + def _Func(x): + _, z = while_loop_v2( + lambda i, _: i < 2, + lambda i, y: (i + 1, math_ops.cos(y)), + [0, x]) + return z + + self.checkIteratedGradients(_Func) + + def testIteratedGradientsWithList(self): + + def _Func(x): + results = list_ops.empty_tensor_list( + element_shape=[], element_dtype=dtypes.float32) + + def _LoopBody(i, y, handle): + return (i + 1, math_ops.cos(y), + list_ops.tensor_list_push_back(handle, y)) + + _, z, results = while_loop_v2( + lambda i, _, h: i < 2, _LoopBody, [0, x, results]) + return z + math_ops.reduce_sum(list_ops.tensor_list_stack( + results, dtypes.float32)) + + self.checkIteratedGradients(_Func) + + def testGradWhileGradWhileWithVariable(self): + with context.eager_mode(): + v = variables.Variable(1.) + + @def_function.function + def _Func(x): + + def _Inner(a): + with backprop.GradientTape() as tape: + tape.watch(a) + _, b = while_loop_v2( + lambda i, _: i < 2, + lambda i, y: (i + 1, math_ops.cos(v + y)), + [0, a]) + return tape.gradient(b, a) + + _, z = while_loop_v2( + lambda i, _: i < 2, + lambda i, y: (i + 1, _Inner(y)), + [0, x]) + return z + + with backprop.GradientTape(persistent=True) as tape: + x = constant_op.constant(1.) + tape.watch(x) + y = _Func(x) + dx, _ = tape.gradient(y, [x, v]) + theoretical, numerical = gradient_checker_v2.compute_gradient( + _Func, [x]) + self.assertAllClose(numerical, theoretical, rtol=1e-3) + self.assertAllClose(array_ops.reshape(numerical, []), + dx, rtol=1e-3) + + def testThreeNestWithLists(self): + with context.eager_mode(): + def _WrapInWhile(f): + def _Wrapped(x): + results = list_ops.empty_tensor_list( + element_shape=[], element_dtype=dtypes.float32) + + def _LoopBody(i, y, handle): + return (i + 1, f(math_ops.cos(y)), + list_ops.tensor_list_push_back(handle, y)) + + _, z, results = control_flow_ops.while_loop( + lambda i, _, h: i < 2, _LoopBody, [0, x, results]) + return z + math_ops.reduce_sum(list_ops.tensor_list_stack( + results, dtypes.float32)) + return _Wrapped + + f = math_ops.sin + + target_function = _WrapInWhile(_WrapInWhile(_WrapInWhile(f))) + + @def_function.function + def _TapeFromGraphMode(x): + with backprop.GradientTape(persistent=True) as tape: + tape.watch(x) + y = target_function(x) + return tape.gradient(y, x) + + x = constant_op.constant(1.) + dx = _TapeFromGraphMode(x) + theoretical, numerical = gradient_checker_v2.compute_gradient( + target_function, [x]) + self.assertAllClose(numerical, theoretical, rtol=1e-3) + self.assertAllClose(array_ops.reshape(numerical, []), + dx, rtol=1e-3) + def testDeviceLabelsInherited(self): def _LoopBody(i, y): result = math_ops.cos(y) diff --git a/tensorflow/python/ops/collective_ops.py b/tensorflow/python/ops/collective_ops.py index f2378995597..4f33c3aeecc 100644 --- a/tensorflow/python/ops/collective_ops.py +++ b/tensorflow/python/ops/collective_ops.py @@ -78,7 +78,8 @@ def all_reduce_v2(t, merge_op='Add', final_op='Id', communication_hint='auto', - timeout=0): + timeout=0, + ordering_token=None): """Reduces tensors collectively, across devices. Args: @@ -98,10 +99,15 @@ def all_reduce_v2(t, timeout: a float. If set to a non zero, set a completion timeout to detect staleness. If the timer goes off, a DeadlineExceededError is raised. The timeout value in seconds. This feature is experimental. + ordering_token: an optional resource tensor to pass to the op as inputs. + They aren't used by the kernel but allow AutoControlDependency to order + the collectives with control dependencies. Returns: An Op implementing the distributed reduction. """ + if ordering_token is not None: + ordering_token = [ordering_token] return gen_collective_ops.collective_reduce_v2( t, group_size=group_size, @@ -110,7 +116,8 @@ def all_reduce_v2(t, merge_op=merge_op, final_op=final_op, communication_hint=communication_hint.lower(), - timeout_seconds=timeout) + timeout_seconds=timeout, + ordering_token=ordering_token or []) def all_gather(t, @@ -157,7 +164,8 @@ def all_gather_v2(t, group_key, instance_key, communication_hint='auto', - timeout=0): + timeout=0, + ordering_token=None): """Accumulates tensors collectively, across devices, along first dimension. Args: @@ -173,17 +181,23 @@ def all_gather_v2(t, timeout: a float. If set to a non zero, set a completion timeout to detect staleness. If the timer goes off, a DeadlineExceededError is raised. The timeout value in seconds. This feature is experimental. + ordering_token: an optional resource tensor to pass to the op as inputs. + They aren't used by the kernel but allow AutoControlDependency to order + the collectives with control dependencies. Returns: An Op implementing the distributed operation. """ + if ordering_token is not None: + ordering_token = [ordering_token] return gen_collective_ops.collective_gather_v2( t, group_size=group_size, group_key=group_key, instance_key=instance_key, communication_hint=communication_hint.lower(), - timeout_seconds=timeout) + timeout_seconds=timeout, + ordering_token=ordering_token or []) def broadcast_send(t, diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py index 02dbf2a594c..059ace7f5ac 100644 --- a/tensorflow/python/ops/cond_v2.py +++ b/tensorflow/python/ops/cond_v2.py @@ -26,7 +26,6 @@ from __future__ import print_function import collections from tensorflow.python.eager import backprop_util -from tensorflow.python.eager import function from tensorflow.python.framework import auto_control_deps from tensorflow.python.framework import auto_control_deps_utils as acd from tensorflow.python.framework import constant_op @@ -193,37 +192,6 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name return [None] + outputs -def _run_as_function_for_tape_gradients(make_op, cond_inputs): - """Fix higher-order tape gradients by wrapping `make_op` in a function.""" - # GradientTapes created inside a function currently don't work well with - # un-wrapped control flow ops in that same function. Wrapping in an extra - # layer of intermediate function means we run extra logic in the function - # gradient code to record the correct intermediates on the tape. - # - # The function attribute inputs to cond/case ops are not hashable, so we pass - # everything as a capture to bypass defun's caching. - if (gradients_util.PossibleTapeGradientTypes(cond_inputs) - == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER - # We only need one function between the tape and the cond; if we've - # already wrapped once, we stop wrapping to avoid infinite recursion. - and not (ops.get_default_graph().building_function - and "cond_gradient_wrapper" in ops.get_default_graph().name)): - - op = None - def _run_make_and_extract_op(): - # Post-processing happens on the cond op, not the function call op. - nonlocal op - tensors = make_op() - op, tensors = _get_op_and_outputs(tensors) # pylint: disable=unused-variable - return tensors - - return op, function.defun_with_attributes( - _run_make_and_extract_op, - attributes=dict(func_name="cond_gradient_wrapper"))() - else: - return _get_op_and_outputs(make_op()) - - def _build_cond(pred, true_graph, false_graph, @@ -300,28 +268,35 @@ def _build_cond(pred, else: op_fn = gen_functional_ops.stateless_if - def make_op(): - return op_fn( + def _make_op(inputs): + if_op, tensors = util.get_op_and_outputs(op_fn( pred, - cond_inputs, [t.dtype for t in true_graph.outputs], + inputs, [t.dtype for t in true_graph.outputs], util.create_new_tf_function(true_graph), util.create_new_tf_function(false_graph), output_shapes=_get_output_shapes(true_graph.outputs, false_graph.outputs), - name=name) - if_op, tensors = _run_as_function_for_tape_gradients(make_op, cond_inputs) + name=name)) + _copy_handle_data(tensors, true_graph.outputs, false_graph.outputs) + # `if_op` is None if this is a `StatelessIf` op with no outputs. + if if_op is not None: + # The true and false graphs have already been created, and we need that + # to happen before we know which tensors will be captured and so whether + # to wrap the cond in a tf.function. Post-hoc mutation of the branch + # `outer_graph` properties seems like the only option if we want to + # conditionally wrap in a function. + true_graph.outer_graph = ops.get_default_graph() + false_graph.outer_graph = ops.get_default_graph() + if_op._true_graph = true_graph + if_op._false_graph = false_graph + util.maybe_set_lowering_attr(if_op) + util.maybe_propagate_compile_time_consts_in_xla(if_op) + _set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph]) + # Prevent fetching since the variant outputs can't be fetched directly. + if_op.graph.prevent_fetching(if_op) + return tensors + tensors = util.run_as_function_for_tape_gradients(_make_op, cond_inputs) - # `if_op` is None if this is a `StatelessIf` op with no outputs. - if if_op is not None: - if_op._true_graph = true_graph - if_op._false_graph = false_graph - util.maybe_set_lowering_attr(if_op) - util.maybe_propagate_compile_time_consts_in_xla(if_op) - _set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph]) - # Prevent fetching since the variant outputs can't be fetched directly. - if_op.graph.prevent_fetching(if_op) - - _copy_handle_data(tensors, true_graph.outputs, false_graph.outputs) # Return identities for each output of the If op, rather than the output of # the If op directly. This makes pruning work if the output of cond() is # fetched: the lowering pass converts the If outputs into IdentityN outputs, @@ -718,15 +693,6 @@ def _make_indexed_slices_indices_types_match(op_type, branch_graphs): branch_graph.structured_outputs, branch_graph.outputs) -def _get_op_and_outputs(op_or_outputs): - if isinstance(op_or_outputs, ops.Operation): - return op_or_outputs, [] - elif not op_or_outputs: # Empty list. - return None, [] - else: - return op_or_outputs[0].op, op_or_outputs - - def _pack_sequence_as(structured_outputs, op_outputs): """Packs the outputs of the gradient If/Case op. @@ -1190,24 +1156,23 @@ def _build_case(branch_index, with ops.control_dependencies( sum((list(bg.control_captures) for bg in branch_graphs), [])): - def _make_op(): - return op_fn( + def _make_op(inputs): + case_op, tensors = util.get_op_and_outputs(op_fn( branch_index, - case_inputs, [t.dtype for t in branch_graphs[0].outputs], + inputs, [t.dtype for t in branch_graphs[0].outputs], [util.create_new_tf_function(g) for g in branch_graphs], output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]), - name=name) - case_op, tensors = _run_as_function_for_tape_gradients( - _make_op, case_inputs) + name=name)) + _copy_handle_data(tensors, *[g.outputs for g in branch_graphs]) + if case_op is not None: + util.maybe_set_lowering_attr(case_op, lower_using_switch_merge) + util.maybe_propagate_compile_time_consts_in_xla(case_op) + _set_read_only_resource_inputs_attr(case_op, branch_graphs) + # Prevent fetching since the variant outputs can't be fetched directly. + case_op.graph.prevent_fetching(case_op) + return tensors + tensors = util.run_as_function_for_tape_gradients(_make_op, case_inputs) - if case_op is not None: - util.maybe_set_lowering_attr(case_op, lower_using_switch_merge) - util.maybe_propagate_compile_time_consts_in_xla(case_op) - _set_read_only_resource_inputs_attr(case_op, branch_graphs) - # Prevent fetching since the variant outputs can't be fetched directly. - case_op.graph.prevent_fetching(case_op) - - _copy_handle_data(tensors, *[g.outputs for g in branch_graphs]) # Return identities for each output of the Case op, rather than the output of # the Case op directly. This makes pruning work if the output of switch_case() # is fetched: the lowering pass converts the Case outputs into IdentityN diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 7ce896c96c4..c75b910058b 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -3616,6 +3616,7 @@ def switch_case(branch_index, return _indexed_case_helper(branch_fns, default, branch_index, name) +@tf_export("__internal__.execute_fn_for_device", v1=[]) def execute_fn_for_device(device_branch_fns, default_fn, name="execute_fn"): """Executes one of the provided callables based on the device placement. diff --git a/tensorflow/python/ops/control_flow_util_v2.py b/tensorflow/python/ops/control_flow_util_v2.py index 6f1bd352e2b..48e221c074b 100644 --- a/tensorflow/python/ops/control_flow_util_v2.py +++ b/tensorflow/python/ops/control_flow_util_v2.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework.func_graph import FuncGraph from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import control_flow_v2_func_graphs +from tensorflow.python.ops import gradients_util from tensorflow.python.util import keras_deps from tensorflow.python.util import tf_contextlib @@ -188,10 +189,10 @@ def resource_input_index(tensor_name, input_names, node_defs, functions): output_idx = int(output_idx) node_def = node_defs[op_name] - if node_def.op == "While": + if node_def.op in ("Identity", "While"): # Captured resources occur at the same index in the lists of inputs and - # outputs of a while op. So we lookup the input of `tensor.op` at the - # same index as the index of `tensor` in the `tensor.op.outputs`. + # outputs of a while or identity op. So we lookup the input of `tensor.op` + # at the same index as the index of `tensor` in the `tensor.op.outputs`. tensor_name = node_def.input[output_idx] elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"): # Functions output any captured resource tensors used by their @@ -312,3 +313,56 @@ def get_func_graph(op, input_shapes, func_name): func_graph = function_def_to_graph.function_def_to_graph( fdef, input_shapes) return func_graph + + +def get_op_and_outputs(op_or_outputs): + if isinstance(op_or_outputs, ops.Operation): + return op_or_outputs, [] + elif not op_or_outputs: # Empty list. + return None, [] + else: + return op_or_outputs[0].op, op_or_outputs + + +def graph_wrapped_for_higher_order_tape_gradients(graph): + """Check if `graph` is wrapped by `run_as_function_for_tape_gradients`.""" + while graph is not None: + if "cflow_gradient_wrapper" in getattr(graph, "name", ""): + return True + graph = getattr(graph, "outer_graph", None) + return False + + +def run_as_function_for_tape_gradients(make_op, inputs): + """Fix higher-order tape gradients by wrapping `make_op` in a function. + + Args: + make_op: A function that takes a list of inputs and returns a list of output + tensors. This function should set any handle data relevant to its outputs + before returning. + inputs: A list of tensors to check for tape gradients and pass to + `make_op`. These should include all tensors used in `make_op`. + + Returns: + Tensors corresponding to `make_op`'s output. + """ + # GradientTapes created inside a function currently don't work well with + # un-wrapped control flow ops in that same function. Wrapping in an extra + # layer of intermediate function means we run extra logic in the function + # gradient code to record the correct intermediates on the tape. + # + # The function attribute inputs to control flow ops are not hashable, so we + # pass everything as a capture to bypass defun's caching. + if (gradients_util.PossibleTapeGradientTypes(inputs) + == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER + # We only need one function between the tape and the op; if we've already + # wrapped once, we stop wrapping to avoid infinite recursion. + and not (ops.get_default_graph().building_function + and "cflow_gradient_wrapper" in ops.get_default_graph().name)): + results = function.defun_with_attributes( + make_op, + autograph=False, + attributes=dict(func_name="cflow_gradient_wrapper"))(inputs) + return results + else: + return make_op(inputs) diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 2477fa1e920..f19ca68797c 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -5143,6 +5143,19 @@ def dropout_v2(x, rate, noise_shape=None, seed=None, name=None): if not x_dtype.is_floating: raise ValueError("x has to be a floating point tensor since it's going " "to be scaled. Got a %s tensor instead." % x_dtype) + if is_rate_number and rate == 0: + # Fast-path: Return the input immediately if rate is non-tensor & is `0`. + # We trigger this after all error checking + # and after `x` has been converted to a tensor, to prevent inconsistent + # tensor conversions/error raising if rate is changed to/from 0. + # + # We also explicitly call `random_seed.get_seed` to make sure + # we don't change the random number generation behavior of + # stateful random ops by entering a fastpath, + # despite not generating a random tensor in the fastpath + random_seed.get_seed(seed) + return x + is_executing_eagerly = context.executing_eagerly() if not tensor_util.is_tensor(rate): if is_rate_number: diff --git a/tensorflow/python/ops/numpy_ops/np_array_ops.py b/tensorflow/python/ops/numpy_ops/np_array_ops.py index 64542656273..3d5c3f93d2e 100644 --- a/tensorflow/python/ops/numpy_ops/np_array_ops.py +++ b/tensorflow/python/ops/numpy_ops/np_array_ops.py @@ -572,7 +572,7 @@ def size(x, axis=None): # pylint: disable=missing-docstring return 1 x = asarray(x).data if x.shape.is_fully_defined(): - return np.prod(x.shape.as_list()) + return np.prod(x.shape.as_list(), dtype=int) else: return np_utils.tensor_to_ndarray(array_ops.size_v2(x)) diff --git a/tensorflow/python/ops/numpy_ops/np_math_ops.py b/tensorflow/python/ops/numpy_ops/np_math_ops.py index 631975c9b8a..85cfdf6c5b8 100644 --- a/tensorflow/python/ops/numpy_ops/np_math_ops.py +++ b/tensorflow/python/ops/numpy_ops/np_math_ops.py @@ -264,10 +264,11 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): # pylint: disable=mis def f(a, b): # pylint: disable=missing-docstring # We can't assign to captured variable `axisa`, so make a new variable - axis_a = axisa - axis_b = axisb - axis_c = axisc - if axis is not None: + if axis is None: + axis_a = axisa + axis_b = axisb + axis_c = axisc + else: axis_a = axis axis_b = axis axis_c = axis diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py index 46c20a85a94..8c6d9692a3d 100644 --- a/tensorflow/python/ops/while_v2.py +++ b/tensorflow/python/ops/while_v2.py @@ -276,12 +276,8 @@ def while_loop(cond, body_graph, output_shapes=output_shapes, parallel_iterations=parallel_iterations, - name=scope) - # This is needed so we do not compute derivative wrt these extra outputs. - outputs[0].op._set_attr("_num_original_outputs", - attr_value_pb2.AttrValue(i=num_original_outputs)) - outputs[0].op._cond_graph = cond_graph - outputs[0].op._body_graph = body_graph + name=scope, + num_original_outputs=num_original_outputs) if not ops.get_default_graph().building_function: # In V1 graph mode, return identities for each output of the While op, # rather than the output of the While op directly. This makes pruning work @@ -366,11 +362,18 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name cond_graph.name += "_rewritten" body_graph.name += "_rewritten" + # `body_grad_graph.extra_inputs` here is equivalent to skimming off the new + # `body_graph.external_captures` added during `_create_grad_func`. new_inputs = body_grad_graph.extra_inputs new_outputs = body_graph.outputs[orig_num_params:] while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph)) while_op._set_func_attr("body", util.create_new_tf_function(body_graph)) + if len(body_graph.output_types) != len(while_op.inputs) + len(new_inputs): + # Continuing leads to an invalid graph with disconnected inputs. + raise AssertionError( + "Inputs and outputs constructed for the forward op of a While " + "gradient don't match. This doesn't make sense, please file a bug.") while_op._set_type_list_attr("T", body_graph.output_types) while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes) while_op._add_while_inputs(new_inputs) @@ -408,7 +411,8 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name body_grad_graph, output_shapes=[t.shape for t in body_grad_graph.outputs], parallel_iterations=parallel_iterations, - name="%s_grad" % while_op.name) + name="%s_grad" % while_op.name, + num_original_outputs=len(body_grad_graph.outputs)) # See comment in while_loop. outputs = [array_ops.identity(t) for t in outputs] @@ -416,7 +420,7 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes, - parallel_iterations, name): + parallel_iterations, name, num_original_outputs): """Builds the functional StatelessWhile/While op.""" cond_stateful_ops = [ op for op in cond_graph.get_operations() if op._is_stateful @@ -429,19 +433,30 @@ def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes, else: op_fn = gen_functional_ops.stateless_while - outputs = op_fn( - loop_vars, - util.create_new_tf_function(cond_graph), - util.create_new_tf_function(body_graph), - output_shapes=output_shapes, - parallel_iterations=parallel_iterations, - name=name) - while_op = outputs[0].op - _copy_handle_data(body_graph.outputs, outputs) - util.maybe_set_lowering_attr(while_op) - util.maybe_propagate_compile_time_consts_in_xla(while_op) - _set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph]) - return outputs + def _make_op(inputs): + while_op, tensors = util.get_op_and_outputs(op_fn( + inputs, + util.create_new_tf_function(cond_graph), + util.create_new_tf_function(body_graph), + output_shapes=output_shapes, + parallel_iterations=parallel_iterations, + name=name)) + _copy_handle_data(body_graph.outputs, tensors) + util.maybe_set_lowering_attr(while_op) + util.maybe_propagate_compile_time_consts_in_xla(while_op) + _set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph]) + # This is needed so we do not compute derivative wrt these extra outputs. + while_op._set_attr("_num_original_outputs", + attr_value_pb2.AttrValue(i=num_original_outputs)) + # The while op may be created inside a tf.function, in which case ops + # needs to capture "through" it when taking gradients; outer_graph is used + # as a sanity check that capturing only happens from parent to child. + cond_graph.outer_graph = ops.get_default_graph() + body_graph.outer_graph = ops.get_default_graph() + while_op._cond_graph = cond_graph + while_op._body_graph = body_graph + return tensors + return util.run_as_function_for_tape_gradients(_make_op, loop_vars) def _get_intermediates(func_graph): @@ -815,7 +830,7 @@ def _get_accumulator(tensor): # tf.defun adds an Identity for each output, check whether that is the case. identity_op = t.consumers()[0] if (identity_op.type == "Identity" and - identity_op.outputs[0] in tensor.graph.outputs): + any(identity_op.outputs[0] is t for t in tensor.graph.outputs)): return identity_op.outputs[0] return None @@ -938,10 +953,16 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph): # and popping from a TensorList removes the constant property of an op and # breaks XLA compilation, which requires certain inputs to be compile-time # constant for certain ops. + # + # This optimization is currently also disabled when under a persistent tape, + # since it leads to an unbounded number of side outputs. With caching it may + # be possible to re-enable it. if (op_type in {"Shape", "Size", "Rank"} and all(input.graph is self._forward_graph for input in inputs) and all(_get_accumulator(input) is None for input in inputs) and - not util_v1.GraphOrParentsInXlaContext(self._forward_graph)): + not util_v1.GraphOrParentsInXlaContext(self._forward_graph) and + not util.graph_wrapped_for_higher_order_tape_gradients( + self._forward_graph)): with self._forward_graph.as_default(): # `name` was built using name_scope stack of gradient graph and may not # be unique in the forward graph. `Graph.create_op` does not uniquify diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py index 4d24f7dd009..7e3e3d3b32b 100644 --- a/tensorflow/python/saved_model/function_deserialization.py +++ b/tensorflow/python/saved_model/function_deserialization.py @@ -407,13 +407,14 @@ def _sort_function_defs(library, library_function_names): return [reverse[x] for x in output] -def fix_node_def(node_def, functions, shared_name_suffix, debug_name): +def _check_op_has_custom_gradients(node_def): + """Returns True if op has custom gradients.""" + return ("_gradient_op_type" in node_def.attr and + node_def.op not in ["StatefulPartitionedCall", "PartitionedCall"]) + + +def fix_node_def(node_def, functions, shared_name_suffix): """Replace functions calls and shared names in `node_def`.""" - if ("_gradient_op_type" in node_def.attr and - node_def.op not in ["StatefulPartitionedCall", "PartitionedCall"]): - logging.warning( - "Importing a function (%s) with ops with custom gradients. Will likely " - "fail if a gradient is requested.", debug_name) if node_def.op in functions: node_def.op = functions[node_def.op].name for _, attr_value in node_def.attr.items(): @@ -471,8 +472,16 @@ def _fix_fdef(orig_fdef, functions, shared_name_suffix): """ fdef = function_pb2.FunctionDef() fdef.CopyFrom(orig_fdef) + contains_custom_gradients = False + for node_def in fdef.node_def: - fix_node_def(node_def, functions, shared_name_suffix, fdef.signature.name) + fix_node_def(node_def, functions, shared_name_suffix) + if not contains_custom_gradients: + contains_custom_gradients = _check_op_has_custom_gradients(node_def) + if contains_custom_gradients: + logging.warning( + "Importing a function (%s) with ops with custom gradients. Will likely " + "fail if a gradient is requested.", fdef.signature.name) fdef.signature.name = _clean_function_name(fdef.signature.name) return fdef diff --git a/tensorflow/python/saved_model/load_v1_in_v2.py b/tensorflow/python/saved_model/load_v1_in_v2.py index a8627701bb8..7978e86d093 100644 --- a/tensorflow/python/saved_model/load_v1_in_v2.py +++ b/tensorflow/python/saved_model/load_v1_in_v2.py @@ -216,8 +216,7 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader): # the GraphDef itself for consistency. for node_def in meta_graph_def.graph_def.node: function_deserialization.fix_node_def(node_def, functions, - load_shared_name_suffix, - debug_name="MetaGraph import") + load_shared_name_suffix) load_graph_returns = [None] wrapped = wrap_function.wrap_function( diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py index 87a65724ab9..45d135d2e61 100644 --- a/tensorflow/python/saved_model/save.py +++ b/tensorflow/python/saved_model/save.py @@ -183,8 +183,9 @@ class _SaveableView(object): """ self.options = options self.checkpoint_view = checkpoint_view - trackable_objects, node_ids, slot_variables = ( - self.checkpoint_view.objects_ids_and_slot_variables()) + trackable_objects, path_to_root, node_ids, slot_variables = ( + self.checkpoint_view.objects_ids_and_slot_variables_and_paths()) + self.node_paths = path_to_root self.nodes = trackable_objects self.node_ids = node_ids self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary() @@ -1029,6 +1030,30 @@ def save(obj, export_dir, signatures=None, options=None): May not be called from within a function body. @end_compatibility """ + save_and_return_nodes(obj, export_dir, signatures, options, + raise_metadata_warning=True) + + +def save_and_return_nodes(obj, export_dir, signatures=None, options=None, + raise_metadata_warning=False): + """Saves a SavedModel while returning all saved nodes and their paths. + + Please see `tf.saved_model.save` for details. + + Args: + obj: A trackable object to export. + export_dir: A directory in which to write the SavedModel. + signatures: A function or dictionary of functions to save in the SavedModel + as signatures. + options: `tf.saved_model.SaveOptions` object for configuring save options. + raise_metadata_warning: Whether to raise the metadata warning. This arg will + be removed in TF 2.5. + + Returns: + A tuple of (a list of saved nodes in the order they are serialized to the + `SavedObjectGraph`, dictionary mapping nodes to one possible path from + the root node to the key node) + """ options = options or save_options.SaveOptions() # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x # compatible (no sessions) and share it with this export API rather than @@ -1036,8 +1061,9 @@ def save(obj, export_dir, signatures=None, options=None): saved_model = saved_model_pb2.SavedModel() meta_graph_def = saved_model.meta_graphs.add() - _, exported_graph, object_saver, asset_info = _build_meta_graph( - obj, signatures, options, meta_graph_def) + _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = ( + _build_meta_graph(obj, signatures, options, meta_graph_def, + raise_metadata_warning)) saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION # Write the checkpoint, copy assets into the assets directory, and write out @@ -1077,6 +1103,8 @@ def save(obj, export_dir, signatures=None, options=None): # constants in the saved graph. ops.dismantle_graph(exported_graph) + return saved_nodes, node_paths + def export_meta_graph(obj, filename, signatures=None, options=None): """Exports the MetaGraph proto of the `obj` to a file. @@ -1103,7 +1131,7 @@ def export_meta_graph(obj, filename, signatures=None, options=None): """ options = options or save_options.SaveOptions() export_dir = os.path.dirname(filename) - meta_graph_def, exported_graph, _, _ = _build_meta_graph( + meta_graph_def, exported_graph, _, _, _, _ = _build_meta_graph( obj, signatures, options) file_io.atomic_write_string_to_file( @@ -1122,7 +1150,8 @@ def export_meta_graph(obj, filename, signatures=None, options=None): def _build_meta_graph_impl(obj, signatures, options, - meta_graph_def=None): + meta_graph_def=None, + raise_metadata_warning=True): """Creates a MetaGraph containing the resources and functions of an object.""" if ops.inside_function(): raise AssertionError( @@ -1170,7 +1199,7 @@ def _build_meta_graph_impl(obj, saveable_view, asset_info.asset_index) meta_graph_def.object_graph_def.CopyFrom(object_graph_proto) - if saved_object_metadata: + if saved_object_metadata and raise_metadata_warning: tf_logging.warn( 'FOR KERAS USERS: The object that you are saving contains one or more ' 'Keras models or layers. If you are loading the SavedModel with ' @@ -1186,13 +1215,15 @@ def _build_meta_graph_impl(obj, 'metadta field will be deprecated soon, so please move the metadata to ' 'a different file.') - return (meta_graph_def, exported_graph, object_saver, asset_info) + return (meta_graph_def, exported_graph, object_saver, asset_info, + saveable_view.nodes, saveable_view.node_paths) def _build_meta_graph(obj, signatures, options, - meta_graph_def=None): + meta_graph_def=None, + raise_metadata_warning=True): """Creates a MetaGraph under a save context. Args: @@ -1205,6 +1236,8 @@ def _build_meta_graph(obj, options: `tf.saved_model.SaveOptions` object that specifies options for saving. meta_graph_def: Optional, the MetaGraphDef proto fill. + raise_metadata_warning: Whether to raise a warning when user objects contain + non-empty metadata. Raises: AssertionError: If `export_meta_graph` is executing inside a `tf.function`. @@ -1218,4 +1251,5 @@ def _build_meta_graph(obj, """ with save_context.save_context(options): - return _build_meta_graph_impl(obj, signatures, options, meta_graph_def) + return _build_meta_graph_impl(obj, signatures, options, meta_graph_def, + raise_metadata_warning) diff --git a/tensorflow/python/tools/print_selective_registration_header_test.py b/tensorflow/python/tools/print_selective_registration_header_test.py index b06f6123ddd..637d73c5434 100644 --- a/tensorflow/python/tools/print_selective_registration_header_test.py +++ b/tensorflow/python/tools/print_selective_registration_header_test.py @@ -108,15 +108,14 @@ class PrintOpFilegroupTest(test.TestCase): ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels( 'rawproto', self.WriteGraphFiles(graphs), default_ops) - matmul_prefix = '' + matmul_prefix = 'Batch' self.assertListEqual( [ ('AccumulateNV2', None), # ('BiasAdd', 'BiasOp<CPUDevice, float>'), # - ('MatMul', - matmul_prefix + 'MatMulOp<CPUDevice, double, false >'), # - ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, float, false >'), # + ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, double, true>'), # + ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, float, true>'), # ('NoOp', 'NoOp'), # ('Reshape', 'ReshapeOp'), # ('_Recv', 'RecvOp'), # @@ -132,9 +131,8 @@ class PrintOpFilegroupTest(test.TestCase): [ ('AccumulateNV2', None), # ('BiasAdd', 'BiasOp<CPUDevice, float>'), # - ('MatMul', - matmul_prefix + 'MatMulOp<CPUDevice, double, false >'), # - ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, float, false >'), # + ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, double, true>'), # + ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, float, true>'), # ('NoOp', 'NoOp'), # ('Reshape', 'ReshapeOp'), # ('_Recv', 'RecvOp'), # diff --git a/tensorflow/python/training/tracking/graph_view.py b/tensorflow/python/training/tracking/graph_view.py index f64db993de6..967bb47c6c5 100644 --- a/tensorflow/python/training/tracking/graph_view.py +++ b/tensorflow/python/training/tracking/graph_view.py @@ -430,7 +430,7 @@ class ObjectGraphView(object): name=base.OBJECT_GRAPH_PROTO_KEY)) return named_saveable_objects - def objects_ids_and_slot_variables(self): + def objects_ids_and_slot_variables_and_paths(self): """Traverse the object graph and list all accessible objects. Looks for `Trackable` objects which are dependencies of @@ -439,7 +439,8 @@ class ObjectGraphView(object): (i.e. if they would be saved with a checkpoint). Returns: - A tuple of (trackable objects, object -> node id, slot variables) + A tuple of (trackable objects, paths from root for each object, + object -> node id, slot variables) """ trackable_objects, path_to_root = self._breadth_first_traversal() object_names = object_identity.ObjectIdentityDictionary() @@ -452,6 +453,11 @@ class ObjectGraphView(object): trackable_objects=trackable_objects, node_ids=node_ids, object_names=object_names) + return trackable_objects, path_to_root, node_ids, slot_variables + + def objects_ids_and_slot_variables(self): + trackable_objects, _, node_ids, slot_variables = ( + self.objects_ids_and_slot_variables_and_paths()) return trackable_objects, node_ids, slot_variables def list_objects(self): diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 7fadb298a2a..72351028fd8 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1710,7 +1710,7 @@ def tf_custom_op_library_additional_deps(): clean_dep("//tensorflow/core:framework_headers_lib"), ] + if_windows([clean_dep("//tensorflow/python:pywrap_tensorflow_import_lib")]) -# A list of targets that contains the implemenation of +# A list of targets that contains the implementation of # tf_custom_op_library_additional_deps. It's used to generate a DEF file for # exporting symbols from _pywrap_tensorflow.dll on Windows. def tf_custom_op_library_additional_deps_impl(): @@ -2704,7 +2704,7 @@ def if_cuda_or_rocm(if_true, if_false = []): If the same additional dependency is needed for both CUDA and ROCm (for eg. `reduction_ops` dependency for the `bias_op` target above), - then specifying that dependency in both both `if_cuda` and `if_rocm` will + then specifying that dependency in both `if_cuda` and `if_rocm` will result in both those functions returning a select statement, which contains the same dependency, which then leads to a duplicate dependency bazel error. diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 96be23b9e50..8c1ff69422e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -762,7 +762,7 @@ tf_module { } member_method { name: "CollectiveGatherV2" - argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], " + argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'ordering_token\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], " } member_method { name: "CollectivePermute" @@ -774,7 +774,7 @@ tf_module { } member_method { name: "CollectiveReduceV2" - argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], " + argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'ordering_token\', \'merge_op\', \'final_op\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], " } member_method { name: "CombinedNonMaxSuppression" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt index c71fd575719..f02cf9001c8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt @@ -24,4 +24,8 @@ tf_module { name: "tracking" mtype: "<type \'module\'>" } + member_method { + name: "execute_fn_for_device" + argspec: "args=[\'device_branch_fns\', \'default_fn\', \'name\'], varargs=None, keywords=None, defaults=[\'execute_fn\'], " + } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 96be23b9e50..8c1ff69422e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -762,7 +762,7 @@ tf_module { } member_method { name: "CollectiveGatherV2" - argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], " + argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'ordering_token\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], " } member_method { name: "CollectivePermute" @@ -774,7 +774,7 @@ tf_module { } member_method { name: "CollectiveReduceV2" - argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], " + argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'ordering_token\', \'merge_op\', \'final_op\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], " } member_method { name: "CombinedNonMaxSuppression" diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_main.py b/tensorflow/tools/compatibility/tf_upgrade_v2_main.py index 7c7461c19da..28c44261a24 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2_main.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_main.py @@ -107,7 +107,7 @@ Simple usage: "--no_upgrade_compat_v1_import", dest="no_upgrade_compat_v1_import", help=("If specified, don't upgrade explicit imports of " - "`tensorflow.compat.v1 as tf` to the v2 apis. Otherwise, " + "`tensorflow.compat.v1 as tf` to the v2 APIs. Otherwise, " "explicit imports of the form `tensorflow.compat.v1 as tf` will " "be upgraded."), action="store_true") @@ -158,8 +158,7 @@ Simple usage: "--outfile=<output file> argument is required when converting a " "single file.") if args.in_place and args.output_file: - raise ValueError( - "--outfile argument is invalid when when converting in place") + raise ValueError("--outfile argument is invalid when converting in place") output_file = args.input_file if args.in_place else args.output_file files_processed, report_text, errors = process_file( args.input_file, output_file, upgrade) @@ -171,8 +170,7 @@ Simple usage: "--outtree=<output directory> argument is required when converting a " "file tree.") if args.in_place and args.output_tree: - raise ValueError( - "--outtree argument is invalid when when converting in place") + raise ValueError("--outtree argument is invalid when converting in place") output_tree = args.input_tree if args.in_place else args.output_tree files_processed, report_text, errors = upgrade.process_tree( args.input_tree, output_tree, args.copy_other_files) diff --git a/tensorflow/tools/graph_transforms/sparsify_gather.cc b/tensorflow/tools/graph_transforms/sparsify_gather.cc index f990a165f21..0f95ce50aea 100644 --- a/tensorflow/tools/graph_transforms/sparsify_gather.cc +++ b/tensorflow/tools/graph_transforms/sparsify_gather.cc @@ -340,7 +340,7 @@ Status SparsifyGatherInternal( weights_node.name(), ckpt_reader, (*shapes_and_slices)[weights_node.name()], &weight)); } - // Add both both weight and identity node names. + // Add both weight and identity node names. removed_node_names.push_back(weights_node.name()); removed_node_names.push_back(match.inputs[0].node.name()); for (auto input_node : match.inputs[0].node.input()) { diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index c5fd05d1d0e..1df1fcc352d 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -680,8 +680,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ) # Check out LLVM and MLIR from llvm-project. - LLVM_COMMIT = "4d11daa659a1833a85863ea920174da4d052a8ba" - LLVM_SHA256 = "80a5d618cbf813a4c455d9045d37dcfca2b6cfac596dbc7ef3a4689a67ab7002" + LLVM_COMMIT = "9bb9b737c5573cf3850230bc4db8dac7be0e1e85" + LLVM_SHA256 = "4ca6c8bd7dbb62746bdb28352395593d6c4a052625cd43ffa8bad3e70aa8ce8c" LLVM_URLS = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), diff --git a/third_party/gpus/check_cuda_libs.py b/third_party/gpus/check_cuda_libs.py index 686d36f5c77..afd6380b0ac 100644 --- a/third_party/gpus/check_cuda_libs.py +++ b/third_party/gpus/check_cuda_libs.py @@ -14,7 +14,7 @@ # ============================================================================== """Verifies that a list of libraries is installed on the system. -Takes a a list of arguments with every two subsequent arguments being a logical +Takes a list of arguments with every two subsequent arguments being a logical tuple of (path, check_soname). The path to the library and either True or False to indicate whether to check the soname field on the shared library.