Merge branch 'master'

Merge branch 'master' of https://github.com/tensorflow/tensorflow
into feature-micro-add-op-depth-to-space-pr1
Ryan Kuester 2021-01-05 15:08:34 -06:00
commit c20ac67cb1
127 changed files with 2354 additions and 1370 deletions

View File

@ -114,6 +114,143 @@ This release contains contributions from many people at Google, as well as:
<INSERT>, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
# Release 2.3.2
## Bug Fixes and Other Changes
* Fixes an access to uninitialized memory in Eigen code
([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266))
* Fixes a security vulnerability caused by lack of validation in
`tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap`
([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
* Fixes a vulnerability caused by attempting to write to an immutable memory region in
`tf.raw_ops.ImmutableConst`
([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
* Fixes a `CHECK`-fail in LSTM with zero-length input
([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270))
* Fixes a security vulnerability caused by accessing heap data outside of bounds
when loading a specially crafted `SavedModel`
([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271))
* Solves an OOM issue on TPUs when XLA contexts use fused average updates
* Updates `libjpeg-turbo` to `2.0.5` to handle
[CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790).
* Updates `junit` to `4.13.1` to handle
[CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250).
* Updates `PCRE` to `8.44` to handle
[CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838)
and
[CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155).
* Updates `sqlite3` to `3.44.0` to keep in sync with the master branch.
# Release 2.2.2
## Bug Fixes and Other Changes
* Fixes an access to uninitialized memory in Eigen code
([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266))
* Fixes a security vulnerability caused by lack of validation in
`tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap`
([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
* Fixes a vulnerability caused by attempting to write to an immutable memory region in
`tf.raw_ops.ImmutableConst`
([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
* Fixes a `CHECK`-fail in LSTM with zero-length input
([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270))
* Fixes a security vulnerability caused by accessing heap data outside of bounds
when loading a specially crafted `SavedModel`
([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271))
* Prevents memory leaks in loading `SavedModel`s that import functions
* Updates `libjpeg-turbo` to `2.0.5` to handle
[CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790).
* Updates `junit` to `4.13.1` to handle
[CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250).
* Updates `PCRE` to `8.44` to handle
[CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838)
and
[CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155).
* Updates `sqlite3` to `3.44.0` to keep in sync with the master branch.
# Release 2.1.3
## Bug Fixes and Other Changes
* Fixes an access to uninitialized memory in Eigen code
([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266))
* Fixes a security vulnerability caused by lack of validation in
`tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap`
([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
* Fixes a vulnerability caused by attempting to write to an immutable memory region in
`tf.raw_ops.ImmutableConst`
([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
* Fixes a `CHECK`-fail in LSTM with zero-length input
([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270))
* Fixes a security vulnerability caused by accessing heap data outside of bounds
when loading a specially crafted `SavedModel`
([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271))
* Updates `libjpeg-turbo` to `2.0.5` to handle
[CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790).
* Updates `junit` to `4.13.1` to handle
[CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250).
* Updates `PCRE` to `8.44` to handle
[CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838)
and
[CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155).
* Updates `sqlite3` to `3.44.0` to keep in sync with the master branch.
* Newer ROCm versions are supported on the 2.1 branch.
# Release 2.0.4
Note that this is the last patch release for the TensorFlow 2.0.x series.
## Bug Fixes and Other Changes
* Fixes an access to uninitialized memory in Eigen code
([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266))
* Fixes a security vulnerability caused by lack of validation in
`tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap`
([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
* Fixes a vulnerability caused by attempting to write to an immutable memory region in
`tf.raw_ops.ImmutableConst`
([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
* Fixes a `CHECK`-fail in LSTM with zero-length input
([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270))
* Fixes a security vulnerability caused by accessing heap data outside of bounds
when loading a specially crafted `SavedModel`
([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271))
* Updates `libjpeg-turbo` to `2.0.5` to handle
[CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790).
* Updates `junit` to `4.13.1` to handle
[CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250).
* Updates `PCRE` to `8.44` to handle
[CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838)
and
[CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155).
* Updates `sqlite3` to `3.44.0` to keep in sync with the master branch.
# Release 1.15.5
Note that this is the last patch release for the TensorFlow 1.x series.
## Bug Fixes and Other Changes
* Fixes an access to uninitialized memory in Eigen code
([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266))
* Fixes a security vulnerability caused by lack of validation in
`tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap`
([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
* Fixes a vulnerability caused by attempting to write to an immutable memory region in
`tf.raw_ops.ImmutableConst`
([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
* Fixes a `CHECK`-fail in LSTM with zero-length input
([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270))
* Fixes a security vulnerability caused by accessing heap data outside of bounds
when loading a specially crafted `SavedModel`
([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271))
* Updates `libjpeg-turbo` to `2.0.5` to handle
[CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790).
* Updates `junit` to `4.13.1` to handle
[CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250).
* Updates `PCRE` to `8.44` to handle
[CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838)
and
[CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155).
* Updates `sqlite3` to `3.44.0` to keep in sync with the master branch.
# Release 2.4.0
## Major Features and Improvements
@ -163,7 +300,7 @@ This release contains contributions from many people at Google, as well as:
## Breaking Changes
* TF Core:
* Certain float32 ops run in lower precsion on Ampere based GPUs, including
* Certain float32 ops run in lower precision on Ampere based GPUs, including
matmuls and convolutions, due to the use of [TensorFloat-32]
(https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/).
Specifically, inputs to such ops are rounded from 23 bits of precision to 10
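For reference, the snippet below is a minimal sketch of how the TensorFloat-32 behavior described above can be queried and opted out of at runtime; it assumes the `tf.config.experimental` toggles available in TF 2.4 and is not part of this change.

```python
import tensorflow as tf

# TensorFloat-32 execution defaults to enabled in TF 2.4; the reduced
# 10-bit mantissa only takes effect on Ampere (and later) GPUs.
print(tf.config.experimental.tensor_float_32_execution_enabled())

# Opt out to keep float32 matmuls and convolutions at full precision.
tf.config.experimental.enable_tensor_float_32_execution(False)
```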

View File

@ -1282,7 +1282,8 @@ class DynamicReshapeOpNotActuallyDynamic
void DynamicReshapeOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<DynamicReshapeOpNotActuallyDynamic,
RemoveRedundantDynamicReshape, ShapeOfDynamicReshape>(context);
RemoveRedundantDynamicBroadcast, RemoveRedundantDynamicReshape,
ShapeOfDynamicReshape>(context);
}
//===----------------------------------------------------------------------===//

View File

@ -33,3 +33,10 @@ def UnaryEinsumToEinsum : Pat<
def RemoveRedundantDynamicReshape : Pat<
(HLO_DynamicReshapeOp (HLO_DynamicReshapeOp $operand, $shape1), $shape2),
(HLO_DynamicReshapeOp $operand, $shape2)>;
// A dynamic broadcast of a dynamic reshape with the same shape operand
// is a dynamic reshape.
def RemoveRedundantDynamicBroadcast : Pat<
(HLO_DynamicBroadcastInDimOp
(HLO_DynamicReshapeOp $operand, $shape), $shape, $dims),
(HLO_DynamicReshapeOp $operand, $shape)>;

View File

@ -1540,3 +1540,14 @@ func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf3
return %1 : tensor<128xf32>
// CHECK: return %arg0 : tensor<128xf32>
}
// CHECK-LABEL: @broadcast_of_reshape
func @broadcast_of_reshape(%arg: tensor<?xf32>, %shape: tensor<2xindex>) -> tensor<?x?xf32> {
%0 = "mhlo.dynamic_reshape"(%arg, %shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
%1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) {
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK: [[RESHAPE:%.*]] = "mhlo.dynamic_reshape"
// CHECK: return [[RESHAPE]]

View File

@ -133,7 +133,7 @@ class TFL_OperandsHaveSameShapesOrBroadcastableShape<
TFL_RuntimePredOpTrait<"operands do not have the same shape or "
"broadcastable shapes within the rank " # max_bcast_rank,
CPred<"TFL::VerifyOperandsHaveSameShapesOrBroadcastableShape("
"$_op, llvm::ArrayRef<unsigned>({" # StrJoinInt<indices>.result #
"$_op, llvm::ArrayRef<unsigned>({" # !interleave(indices, ", ") #
"}), " # max_bcast_rank # ")">>;
// These additional types/type constraints here are used to decouple the ops
@ -3453,10 +3453,10 @@ def TFL_CastOp : TFL_Op<"cast", [
}];
let arguments = (ins
TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$input
TFL_TensorOf<[F32, I1, I16, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$input
);
let results = (outs TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$output);
let results = (outs TFL_TensorOf<[F32, I1, I16, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$output);
// TFLite's cast op does not utilize CastOptions, instead derives types
// from the TfLiteTensors.
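The ODS change above adds `I16` to the allowed input and output element types of `tfl.cast` (exercised by the `castFloat32ToI16`/`castI16ToFloat32` legalization tests further down). As a rough illustration, and not part of this change, the sketch below assumes TF 2.4's stock `TFLiteConverter` flow to round-trip a float32 tensor through int16; whether the converted model executes still depends on the kernels linked into the interpreter.

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([1, 2, 2, 5], tf.float32)])
def cast_roundtrip(x):
  # float32 -> int16 -> float32; with I16 allowed on tfl.cast, both
  # directions can be expressed directly in the converted model.
  return tf.cast(tf.cast(x, tf.int16), tf.float32)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [cast_roundtrip.get_concrete_function()])
tflite_model = converter.convert()
```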

View File

@ -34,7 +34,7 @@ class QuantizedType<string n, list<int> params, bit signed>
"Q" # !if (signed, "I", "UI") # !head(params) # " type"> {
string name = n;
string asTraitArgsStr =
StrJoinInt<params>.result # !if(signed, ", true", ", false");
!interleave(params, ", ") # !if(signed, ", true", ", false");
}
// Uniform quantized types. Two integers "smantissa" and "sexp" are used to
@ -134,7 +134,7 @@ class FixedResultScale<QuantizedType qt> : NativeOpTrait<!strconcat(
// needs a scale based on the scales of op1 and op2.
class AccumulatorUniformScale<int bias, int op1, int op2> : NativeOpTrait<
!strconcat("quant::AccumulatorUniformScale<",
StrJoinInt<[bias, op1, op2]>.result,
!interleave([bias, op1, op2], ", "),
">::Impl")>;
// Specify the operand index of the coefficient operand for an affine op
@ -142,7 +142,7 @@ class AccumulatorUniformScale<int bias, int op1, int op2> : NativeOpTrait<
// If the quantization dimension is -1, per-axis quantization isn't supported.
class AffineOpCoefficient<int dim, int index> : NativeOpTrait<
!strconcat("quant::AffineOpCoefficient<",
StrJoinInt<[dim, index]>.result,
!interleave([dim, index], ", "),
">::Impl")>;
// Specify this trait if the op doesn't have quantizable output. We shouldn't

View File

@ -1320,6 +1320,22 @@ func @cast(%arg0: tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32> {
// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32>
}
func @castFloat32ToI16(%arg0: tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xi16> {
%0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xi16>
return %0 : tensor<1x2x2x5xi16>
// CHECK-LABEL: castFloat32ToI16
// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xi16>
}
func @castI16ToFloat32(%arg0: tensor<1x2x2x5xi16>) -> tensor<1x2x2x5xf32> {
%0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xi16>) -> tensor<1x2x2x5xf32>
return %0 : tensor<1x2x2x5xf32>
// CHECK-LABEL: castI16ToFloat32
// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi16>) -> tensor<1x2x2x5xf32>
}
func @castComplex(%arg0: tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xcomplex<f32>> {
%0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xcomplex<f32>>
return %0 : tensor<1x2x2x5xcomplex<f32>>

View File

@ -98,11 +98,13 @@ class TF_OpIsBroadcastableToRes<int opId, int resId> : And<[
class TF_AllTypesMatchPred<list<string> values> :
CPred<"TF::AreCastCompatible(llvm::makeArrayRef({"# StrJoin<values>.result #"}))">;
CPred<"TF::AreCastCompatible(llvm::makeArrayRef({" #
!interleave(values, ", ") # "}))">;
class TF_AllTypesMatch<list<string> names> :
PredOpTrait<
"all of {" # StrJoin<names>.result # "} have dynamically equal types ",
"all of {" # !interleave(names, ", ") #
"} have dynamically equal types ",
TF_AllTypesMatchPred<
!foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;

View File

@ -147,6 +147,7 @@ cc_library(
"//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/mlir/hlo:lhlo_gpu",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:backend",

View File

@ -384,7 +384,7 @@ ENTRY main {
HloModule BatchNormForwardInference
// CHECK: func @main
// CHECK: lmhlo_gpu.batch_norm_inference"
// CHECK: "lmhlo_gpu.batch_norm_inference"
// CHECK-SAME: epsilon = 1.000000e-03 : f32
// CHECK-SAME: feature_index = 0 : i64
// CHECK-SAME: (memref<2x2x2x2xf32>, memref<2xf32>, memref<2xf32>, memref<2xf32>, memref<2xf32>, memref<2x2x2x2xf32>) -> ()
@ -400,3 +400,15 @@ ENTRY main {
custom-call(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[] %constant, s64[] %constant_1),
custom_call_target="__cudnn$batchNormalizationForwardInference"
}
// -----
HloModule Infeed
// CHECK: func @main
// CHECK: "lmhlo.infeed"
// CHECK-SAME: (memref<3xf32>) -> ()
ENTRY main {
%tok = token[] parameter(0)
ROOT %infeed = (f32[3]{0}, token[]) infeed(token[] %tok)
}

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "mlir/IR/AffineMap.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
@ -60,6 +61,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
@ -67,6 +69,7 @@ using xla::BufferAllocation;
using xla::BufferAssignment;
using xla::HloComputation;
using xla::HloCustomCallInstruction;
using xla::HloInfeedInstruction;
using xla::HloInstruction;
using xla::HloModule;
using xla::HloModuleProto;
@ -199,14 +202,16 @@ class XlaHloToLhloPass
} // namespace
// Creates MLIR operands corresponding to operands and results of the XLA HLO
// instruction. If `num_operands` is not -1, then only the first `num_operands`
// instruction. If `num_operands` is valid, then only the first `num_operands`
// operands of the HLO instruction will be considered.
Status LhloDialectEmitter::CreateOperands(
HloInstruction* instr, llvm::SmallVectorImpl<Value>& operands,
size_t& num_arguments, size_t& num_results,
absl::optional<xla::int64> num_operands) {
HloInstruction* instr, absl::optional<xla::int64> num_operands,
llvm::SmallVectorImpl<Value>& operands, size_t& num_arguments,
size_t& num_results) {
if (num_operands.value_or(0) > instr->operand_count())
return xla::InvalidArgument("num_operands must be <= operand count");
for (xla::int64 i = 0; i < num_operands.value_or(instr->operand_count());
i++) {
++i) {
TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(i), &operands));
}
num_arguments = operands.size();
@ -215,19 +220,23 @@ Status LhloDialectEmitter::CreateOperands(
return Status::OK();
}
template <typename OpType>
OpType LhloDialectEmitter::CreateOpWithoutAttrs(HloInstruction* instr,
ValueRange operands) {
Location loc = getLocation(instr);
NamedAttribute attrs[] = {{Identifier::get("name", builder_.getContext()),
builder_.getStringAttr(instr->name())}};
return builder_.create<OpType>(loc, llvm::None, operands, attrs);
}
template <typename OpType>
StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(
HloInstruction* instr, size_t& num_arguments, size_t& num_results,
absl::optional<xla::int64> num_operands) {
Location loc = getLocation(instr);
std::pair<Identifier, Attribute> attrs[] = {
{Identifier::get("name", builder_.getContext()),
builder_.getStringAttr(instr->name())},
};
llvm::SmallVector<Value, 4> operands;
TF_RETURN_IF_ERROR(CreateOperands(instr, operands, num_arguments, num_results,
num_operands));
return builder_.create<OpType>(loc, llvm::None, operands, attrs);
TF_RETURN_IF_ERROR(CreateOperands(instr, num_operands, operands,
num_arguments, num_results));
return CreateOpWithoutAttrs<OpType>(instr, operands);
}
StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
@ -273,6 +282,8 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
return CreateOpWithoutAttrs<lmhlo::FloorOp>(instr);
case HloOpcode::kImag:
return CreateOpWithoutAttrs<lmhlo::ImagOp>(instr);
case HloOpcode::kInfeed:
return EmitInfeedOp(instr);
case HloOpcode::kIsFinite:
return CreateOpWithoutAttrs<lmhlo::IsFiniteOp>(instr);
case HloOpcode::kLog:
@ -387,7 +398,7 @@ StatusOr<Value> LhloDialectEmitter::RewriteFusionOperand(
::xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) {
if (shape.IsTuple()) {
llvm::SmallVector<Value, 4> values;
for (int i = 0; i < shape.tuple_shapes_size(); i++) {
for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
shape_index->push_back(i);
TF_ASSIGN_OR_RETURN(
auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index,
@ -423,7 +434,7 @@ StatusOr<lmhlo::FusionOp> LhloDialectEmitter::EmitFusionOp(
auto region_builder = OpBuilder::atBlockBegin(&fusion.region().front());
llvm::SmallVector<Value, 8> arguments;
for (int i = 0; i < instr->operands().size(); i++) {
for (int i = 0; i < instr->operands().size(); ++i) {
const HloInstruction* operand = instr->operand(i);
xla::ShapeIndex shape_index;
TF_ASSIGN_OR_RETURN(
@ -982,6 +993,19 @@ StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
return all_reduce_op;
}
StatusOr<lmhlo::InfeedOp> LhloDialectEmitter::EmitInfeedOp(
HloInstruction* instr) {
HloInfeedInstruction* infeed = ::xla::Cast<HloInfeedInstruction>(instr);
// HLO Infeed instruction has a single operand of token type and a tuple
// with buffers and a token as its output. LMHLO Infeed operation does not
// need the token operand or result, so drop it.
SmallVector<Value, 2> operands;
TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{0}));
auto infeed_op = CreateOpWithoutAttrs<lmhlo::InfeedOp>(instr, operands);
infeed_op.configAttr(builder_.getStringAttr(infeed->infeed_config()));
return infeed_op;
}
StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
const ::xla::ShapeIndex& shape_index) {
@ -1055,7 +1079,7 @@ Status LhloDialectEmitter::GetOrCreateViewImpl(
const HloInstruction* instr, const Shape& current_shape,
::xla::ShapeIndex* current_shape_index, SmallVectorImpl<Value>* values) {
if (current_shape.IsTuple()) {
for (int i = 0; i < current_shape.tuple_shapes().size(); i++) {
for (int i = 0; i < current_shape.tuple_shapes().size(); ++i) {
current_shape_index->push_back(i);
TF_RETURN_IF_ERROR(GetOrCreateViewImpl(
instr, current_shape.tuple_shapes(i), current_shape_index, values));
@ -1063,19 +1087,26 @@ Status LhloDialectEmitter::GetOrCreateViewImpl(
}
return Status::OK();
}
TF_ASSIGN_OR_RETURN(
auto v, GetOrCreateArrayView(instr, current_shape, *current_shape_index));
if (current_shape.IsArray()) {
TF_ASSIGN_OR_RETURN(auto v, GetOrCreateArrayView(instr, current_shape,
*current_shape_index));
values->push_back(v);
return Status::OK();
}
return xla::InternalError("Unexpected shape kind for %s and shape index %s",
instr->ToString(), current_shape_index->ToString());
}
// Returns a view for the result of an instruction.
// We first get a view for the slice in the allocation, and then may need to
// create another view to adjust the slice for the shape of the instruction.
Status LhloDialectEmitter::GetOrCreateView(const HloInstruction* instr,
SmallVectorImpl<Value>* values) {
::xla::ShapeIndex shape_index;
return GetOrCreateViewImpl(instr, instr->shape(), &shape_index, values);
Status LhloDialectEmitter::GetOrCreateView(
const HloInstruction* instr, SmallVectorImpl<Value>* values,
const xla::ShapeIndex& result_subset) {
::xla::ShapeIndex shape_index = result_subset;
const Shape& sub_shape =
::xla::ShapeUtil::GetSubshape(instr->shape(), shape_index);
return GetOrCreateViewImpl(instr, sub_shape, &shape_index, values);
}
Status LhloDialectEmitter::Initialize() {

View File

@ -27,6 +27,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace mlir {
@ -79,6 +81,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
::xla::StatusOr<lmhlo::CompareOp> EmitCompareOp(::xla::HloInstruction* instr);
::xla::StatusOr<lmhlo::InfeedOp> EmitInfeedOp(::xla::HloInstruction* instr);
::xla::StatusOr<lmhlo::MapOp> EmitMapOp(::xla::HloInstruction* instr);
::xla::StatusOr<lmhlo::ReducePrecisionOp> EmitReducePrecisionOp(
@ -87,10 +90,16 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
::xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
::xla::HloInstruction* instr);
::xla::Status CreateOperands(
::xla::HloInstruction* instr, SmallVectorImpl<Value>& operands,
size_t& num_arguments, size_t& num_results,
absl::optional<xla::int64> num_operands = absl::nullopt);
// Create LHLO operation operands given an XLA HLO instruction. By default,
// all XLA HLO operands and results are converted to MLIR and appended to
// `operands`. If `num_operands` is specified, only the first `num_operands`
// operands of the instruction are converted to MLIR. The function returns the
// actual number of operands and results generated for MLIR in `num_arguments`
// and `num_results`.
::xla::Status CreateOperands(::xla::HloInstruction* instr,
absl::optional<xla::int64> num_operands,
SmallVectorImpl<Value>& operands,
size_t& num_arguments, size_t& num_results);
template <typename OpType>
::xla::StatusOr<OpType> CreateOpWithoutAttrs(
@ -105,6 +114,10 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
::xla::HloInstruction* instr, size_t& num_arguments, size_t& num_results,
absl::optional<xla::int64> num_operands = absl::nullopt);
template <typename OpType>
OpType CreateOpWithoutAttrs(::xla::HloInstruction* instr,
ValueRange operands);
template <typename T>
DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) {
return builder_.getI64TensorAttr(
@ -140,9 +153,14 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
SmallVectorImpl<Value>* values);
// Helper function to create view/tuple of views to a buffer for a given
// instruction result.
// instruction result. `result_subset` can be used for instructions that
// have a tuple result when the MLIR conversion needs to convert only one of
// the tuple elements. Note that, if needed, this can be extended to take a
// list of ShapeIndex values for finer control over which elements of the
// output tuple are converted to MLIR.
tensorflow::Status GetOrCreateView(const ::xla::HloInstruction* instr,
SmallVectorImpl<Value>* values);
SmallVectorImpl<Value>* values,
const xla::ShapeIndex& result_subset = {});
::xla::StatusOr<Value> GetOrCreateArrayView(
const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,

View File

@ -51,10 +51,10 @@
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/xla/tutorials/compile\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/jit_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/jit_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]

View File

@ -89,7 +89,7 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
}
StatusOr<std::unique_ptr<PjRtBuffer>> PyClient::PjRtBufferFromPyval(
const pybind11::object& argument, PjRtDevice* device, bool force_copy,
pybind11::handle argument, PjRtDevice* device, bool force_copy,
PjRtClient::HostBufferSemantics host_buffer_semantics) {
if (device == nullptr) {
TF_RET_CHECK(!pjrt_client_->local_devices().empty());
@ -123,7 +123,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PyClient::PjRtBufferFromPyval(
return buffer;
}
StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
const pybind11::object& argument, PjRtDevice* device, bool force_copy,
pybind11::handle argument, PjRtDevice* device, bool force_copy,
PjRtClient::HostBufferSemantics host_buffer_semantics) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtBuffer> buffer,

View File

@ -124,10 +124,10 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
}
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBufferFromPyval(
const pybind11::object& argument, PjRtDevice* device, bool force_copy,
pybind11::handle argument, PjRtDevice* device, bool force_copy,
PjRtClient::HostBufferSemantics host_buffer_semantics);
StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval(
const pybind11::object& argument, PjRtDevice* device, bool force_copy,
pybind11::handle argument, PjRtDevice* device, bool force_copy,
PjRtClient::HostBufferSemantics host_buffer_semantics);
StatusOr<std::shared_ptr<PyExecutable>> Compile(

View File

@ -71,6 +71,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
#include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
@ -3168,7 +3169,22 @@ Status IrEmitterUnnested::HandleAllToAll(HloInstruction* hlo) {
}
Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) {
return ThunkEmitter(this).HandleInfeed(xla_infeed);
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(xla_infeed));
auto infeed_op = mlir::dyn_cast<mlir::lmhlo::InfeedOp>(input.op);
std::vector<InfeedThunk::ShapedSlice> dest_slices;
dest_slices.reserve(infeed_op.outputs().size());
for (mlir::Value output : infeed_op.outputs()) {
TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(output));
const Shape& shape = TypeToShape(output.getType());
dest_slices.push_back(InfeedThunk::ShapedSlice{slice, shape});
}
AddThunkToThunkSequence(
absl::make_unique<InfeedThunk>(input.thunk_info, std::move(dest_slices)));
return Status::OK();
}
Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) {

View File

@ -800,8 +800,16 @@ Status AMDGPUTargetModuleLinker(llvm::Module* module, GpuVersion gpu_version,
std::unique_ptr<llvm::TargetMachine> AMDGPUGetTargetMachine(
llvm::Triple target_triple, int amdgpu_version,
const HloModuleConfig& hlo_module_config) {
string feature_str = "+code-object-v3";
#if TF_ROCM_VERSION >= 30900
// code-object-v3 is the default, so there is no need to explicitly specify
// it in the feature string. Also, starting with ROCm 4.0, this feature
// string is deprecated, and we get a warning to that effect, so remove the
// feature string.
feature_str = "";
#endif
return GetTargetMachine(target_triple, absl::StrCat("gfx", amdgpu_version),
hlo_module_config, "+code-object-v3");
hlo_module_config, feature_str);
}
void AMDGPUBackendInit(const HloModuleConfig& hlo_module_config) {

View File

@ -115,34 +115,6 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
/*implements_whole_instruction=*/true);
}
std::unique_ptr<Thunk> ThunkEmitter::BuildInfeedThunk(
const HloInstruction* inst) {
CHECK_EQ(HloOpcode::kInfeed, inst->opcode());
std::vector<ShapeUtil::IndexedShape> leaf_shapes =
ShapeUtil::GetLeafShapes(inst->shape());
// For an infeed HLO, the output is a 2 element tuple where the first element
// of the tuple is all the infeed buffers and the second element is a token.
// The infeed thunk does not need to handle this token output, so just drop
// it.
leaf_shapes.pop_back();
std::vector<InfeedThunk::ShapedSlice> dest_slices;
dest_slices.reserve(leaf_shapes.size());
for (ShapeUtil::IndexedShape& indexed_shape : leaf_shapes) {
BufferAllocation::Slice slice =
GetAllocationSlice(*inst, indexed_shape.index);
const Shape& shape =
ShapeUtil::GetSubshape(inst->shape(), indexed_shape.index);
dest_slices.emplace_back(InfeedThunk::ShapedSlice{slice, shape});
}
return absl::make_unique<InfeedThunk>(context_->GetThunkInfo(inst),
std::move(dest_slices));
}
std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
const HloInstruction* inst) {
CHECK_EQ(HloOpcode::kOutfeed, inst->opcode());
@ -258,11 +230,6 @@ Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) {
return Status::OK();
}
Status ThunkEmitter::HandleInfeed(HloInstruction* infeed) {
AddThunkToThunkSequence(BuildInfeedThunk(infeed));
return Status::OK();
}
Status ThunkEmitter::HandleOutfeed(HloInstruction* outfeed) {
AddThunkToThunkSequence(BuildOutfeedThunk(outfeed));
return Status::OK();

View File

@ -46,7 +46,6 @@ class ThunkEmitter {
Status HandleCustomCall(HloInstruction* custom_call);
Status HandleFft(HloInstruction* fft);
Status HandleTriangularSolve(HloInstruction* hlo);
Status HandleInfeed(HloInstruction* xla_infeed);
Status HandleOutfeed(HloInstruction* outfeed);
private:

View File

@ -367,7 +367,6 @@ xla_test(
"conv_depthwise_test.cc",
],
shard_count = 50,
tags = ["no_rocm"], # ROCm 3.9 regression
deps = [
":conv_depthwise_common",
":test_macros_header",
@ -389,7 +388,6 @@ xla_test(
timeout = "long",
srcs = ["conv_depthwise_backprop_filter_test.cc"],
shard_count = 40,
tags = ["no_rocm"], # ROCm 3.9 regression
deps = [
":test_macros_header",
"//tensorflow/compiler/xla:execution_options_util",
@ -414,7 +412,6 @@ xla_test(
"cpu",
],
shard_count = 50,
tags = ["no_rocm"], # ROCm 3.9 regression
deps = [
":client_library_test_base",
":hlo_test_base",
@ -924,7 +921,6 @@ xla_test(
srcs = ["dot_operation_test.cc"],
shard_count = 20,
tags = [
"no_rocm", # ROCm 3.9 regression
"optonly",
],
deps = [
@ -958,7 +954,6 @@ xla_test(
backends = ["gpu"],
shard_count = 20,
tags = [
"no_rocm", # ROCm 3.9 regression
"optonly",
# TODO(b/151340488): Timed out on 2020-03-12.
"nozapfhahn",
@ -1025,7 +1020,6 @@ xla_test(
},
shard_count = 20,
tags = [
"no_rocm", # ROCm 3.9 regression
"optonly",
],
deps = [
@ -1253,7 +1247,6 @@ xla_test(
"cpu": ["nomsan"],
},
shard_count = 30,
tags = ["no_rocm"], # ROCm 3.9 regression
deps = [
":test_macros_header",
"//tensorflow/compiler/xla:array3d",
@ -1278,7 +1271,6 @@ xla_test(
timeout = "long",
srcs = ["convolution_dimension_numbers_test.cc"],
shard_count = 20,
tags = ["no_rocm"], # ROCm 3.9 regression
deps = [
":test_macros_header",
"//tensorflow/compiler/xla:array4d",
@ -2322,7 +2314,6 @@ xla_test(
name = "multioutput_fusion_test",
srcs = ["multioutput_fusion_test.cc"],
backends = ["gpu"],
tags = ["no_rocm"], # ROCm 3.9 regression
deps = [
":test_macros_header",
"//tensorflow/compiler/xla:literal",

View File

@ -480,13 +480,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
{csinfo_.fused_batch_norm_grad_v3,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
#ifdef ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.fused_batch_norm_ex,
native_fmt ? csinfo_.mkl_native_fused_batch_norm_ex
: csinfo_.mkl_fused_batch_norm_ex,
CopyAttrsAll, FusedBatchNormExRewrite,
GetRewriteCause()});
#endif
rinfo_.push_back({csinfo_.fused_conv2d,
native_fmt ? csinfo_.mkl_native_fused_conv2d
: csinfo_.mkl_fused_conv2d,
@ -672,14 +670,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.requantize,
mkl_op_registry::GetMklOpName(csinfo_.requantize),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
#ifdef ENABLE_MKLDNN_V1
// Optimized TanhGrad support exists only in DNNL 1.x.
rinfo_.push_back({csinfo_.tanh, mkl_op_registry::GetMklOpName(csinfo_.tanh),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.tanh_grad,
mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
#endif // ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.reshape,
mkl_op_registry::GetMklOpName(csinfo_.reshape),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});

View File

@ -53,7 +53,6 @@ static void InitGraph(const string& s, Graph* graph,
GraphDef graph_def;
auto parser = protobuf::TextFormat::Parser();
// parser.AllowRelaxedWhitespace(true);
CHECK(parser.MergeFromString(s, &graph_def)) << s;
GraphConstructorOptions opts;
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
@ -66,7 +65,6 @@ static void InitGraph(const string& s, Graph* graph,
class MklLayoutPassTest : public ::testing::Test {
public:
MklLayoutPassTest() : graph_(OpRegistry::Global()) {}
// Ashraf added
Node* FindNode(const string& name) {
for (Node* node : graph_.nodes()) {
if (node->name() == name) return node;
@ -3087,8 +3085,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_LeakyReluGrad_Negative);
REGISTER_TEST_ALL_TYPES(NodeRewrite_LeakyReluLeakyReluGrad_Positive);
#undef REGISTER_TEST
#ifdef ENABLE_MKLDNN_V1
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
@ -3146,7 +3142,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_TanhGrad_Positive);
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_TanhTanhGrad_Positive);
#undef REGISTER_TEST
#endif // ENABLE_MKLDNN_V1
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
@ -3513,7 +3508,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormGradV3_5D_Negative_2);
#undef DATA_FORMAT
#undef REGISTER_TEST
#ifdef ENABLE_MKLDNN_V1
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
@ -3603,7 +3597,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative1);
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative2);
#undef REGISTER_TEST
#endif // ENABLE_MKLDNN_V1
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizedDepthwiseConv2D_Positive) {
InitGraph(
@ -5184,8 +5177,8 @@ static void BM_MklLayoutRewritePass(int iters, int op_nodes) {
bool first = true;
while (iters > 0) {
Graph* graph = new Graph(OpRegistry::Global());
InitGraph(s, graph);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
InitGraph(s, graph.get());
int N = graph->num_node_ids();
if (first) {
testing::SetLabel(strings::StrCat("Per graph node. Nodes: ", N));
@ -5193,13 +5186,12 @@ static void BM_MklLayoutRewritePass(int iters, int op_nodes) {
}
{
testing::StartTiming();
std::unique_ptr<Graph> ug(graph);
std::unique_ptr<Graph> ug(graph.get());
RunMklLayoutRewritePass(&ug);
testing::StopTiming();
}
iters -= N; // Our benchmark units are individual graph nodes,
// not whole graphs
// delete graph;
}
}
BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);

View File

@ -37,6 +37,7 @@ load(
package(
default_visibility = [
"//tensorflow/core:__subpackages__",
"//tensorflow/security/fuzzing:__subpackages__",
],
licenses = ["notice"], # Apache 2.0
)
@ -622,7 +623,10 @@ cc_library(
name = "bfloat16",
srcs = ["bfloat16.cc"],
hdrs = ["bfloat16.h"],
visibility = ["//tensorflow/core:__subpackages__"],
visibility = [
"//tensorflow/core:__subpackages__",
"//tensorflow/security/fuzzing:__subpackages__",
],
deps = [
":numeric_types",
"//tensorflow/core/platform:byte_order",

View File

@ -269,6 +269,9 @@ Status MetaOptimizer::InitializeOptimizers(
if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
optimizers->push_back(MakeUnique<PinToHostOptimizer>());
}
if (cfg_.remapping() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
}
if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
optimizers->push_back(
MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
@ -278,9 +281,6 @@ Status MetaOptimizer::InitializeOptimizers(
/*optimization level*/ cfg_.layout_optimizer(),
/*CPU layout conversion*/ cfg_.cpu_layout_conversion()));
}
if (cfg_.remapping() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
}
if (cfg_.loop_optimization() != RewriterConfig::OFF) {
optimizers->push_back(
MakeUnique<LoopOptimizer>(cfg_.loop_optimization(), cpu_device_));

View File

@ -29,8 +29,14 @@ REGISTER5(BinaryOp, CPU, "DivNoNan", functor::div_no_nan, Eigen::half, float,
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// ROCM TODO: re-enable complex64 / complex128 after compiler fix
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
uint16, int16, int64, complex64, complex128);
#else
REGISTER4(BinaryOp, GPU, "Div", functor::div, uint8, uint16, complex64,
complex128);
#endif
REGISTER4(BinaryOp, GPU, "TruncateDiv", functor::div, uint8, uint16, int16,
int64);
REGISTER5(BinaryOp, GPU, "RealDiv", functor::div, float, Eigen::half, double,

View File

@ -30,8 +30,13 @@ REGISTER(BinaryOp, CPU, "Sub", functor::sub, int32);
#endif // __ANDROID_TYPES_SLIM__
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
REGISTER7(BinaryOp, GPU, "Sub", functor::sub, float, Eigen::half, double, int64,
complex64, complex128, uint32);
#else
REGISTER3(BinaryOp, GPU, "Sub", functor::sub, complex64, complex128, uint32);
#endif
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel

View File

@ -85,7 +85,7 @@ void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
// clang-format off
absl::flat_hash_map<string, uint64> live_experiments = {
{"enable_gradient_descent", 0},
{"map_parallelization", 0}
{"map_parallelization", 1}
};
// clang-format on
auto hash_func = [](const string& str) { return Hash64(str); };
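The hunk above changes the rollout value recorded for the `map_parallelization` entry in tf.data's live-experiment table. Independent of that rollout, users can pin the behavior explicitly; the sketch below assumes TF 2.4's `tf.data.Options` API and an illustrative dataset, and is not part of this change.

```python
import tensorflow as tf

ds = tf.data.Dataset.range(1000).map(lambda x: x * 2)  # stateless map fn

options = tf.data.Options()
# Pin map parallelization on (or set False to opt out) instead of relying
# on the experiment rollout.
options.experimental_optimization.map_parallelization = True
ds = ds.with_options(options)
```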

View File

@ -43,7 +43,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/mkl_types.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
@ -65,7 +64,7 @@ struct MklConvFwdParams {
memory::dims dilations;
memory::dims padding_left;
memory::dims padding_right;
MKL_TENSOR_FORMAT tf_fmt;
MklTensorFormat tf_fmt;
bool native_format;
string dtypes = string("");
struct PostOpParam {
@ -80,7 +79,7 @@ struct MklConvFwdParams {
memory::dims bias_dims, memory::dims dst_dims,
memory::dims strides, memory::dims dilations,
memory::dims padding_left, memory::dims padding_right,
MKL_TENSOR_FORMAT tf_fmt, bool native_format)
MklTensorFormat tf_fmt, bool native_format)
: src_dims(src_dims),
filter_dims(filter_dims),
bias_dims(bias_dims),
@ -99,7 +98,7 @@ template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput>
class MklConvFwdPrimitive : public MklPrimitive {
public:
explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
: MklPrimitive(engine(ENGINE_CPU, 0)) {
: MklPrimitive(engine(engine::kind::cpu, 0)) {
// Create convolution primitive
if (context_.conv_fwd == nullptr) {
Setup(convFwdDims);
@ -115,8 +114,8 @@ class MklConvFwdPrimitive : public MklPrimitive {
void Execute(const Tinput* src_data, const Tfilter* filter_data,
const Tbias* bias_data, const Toutput* dst_data,
std::shared_ptr<stream> fwd_stream) {
// TODO: Create a common function and avoid the duplicate code
#ifdef ENABLE_MKLDNN_THREADPOOL
// TODO: Create a common function and avoid the duplicate code
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
context_.filter_mem->set_data_handle(
@ -139,16 +138,13 @@ class MklConvFwdPrimitive : public MklPrimitive {
context_.dst_mem->set_data_handle(
static_cast<void*>(const_cast<Toutput*>(dst_data)));
#endif // ENABLE_MKLDNN_THREADPOOL
#ifdef ENABLE_MKLDNN_V1
DCHECK_EQ(context_.fwd_primitives.size(),
context_.fwd_primitives_args.size());
for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
context_.fwd_primitives.at(i).execute(*fwd_stream,
context_.fwd_primitives_args.at(i));
}
#else
fwd_stream->submit(context_.fwd_primitives);
#endif // ENABLE_MKLDNN_V1
// After execution, set data handle back
context_.src_mem->set_data_handle(DummyData);
@ -168,13 +164,6 @@ class MklConvFwdPrimitive : public MklPrimitive {
Execute(src_data, filter_data, nullptr, dst_data, fwd_stream);
}
#ifndef ENABLE_MKLDNN_V1
// In MKL-DNN v1.x, memory format tags only provide a partial description
// of the memory layout. Hence, these functions are disabled for v1.x.
memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; }
#endif // !ENABLE_MKLDNN_V1
std::shared_ptr<ConvFwdPd> GetPrimitiveDesc() const {
return context_.fwd_pd;
}
@ -182,12 +171,6 @@ class MklConvFwdPrimitive : public MklPrimitive {
private:
// Primitive reuse context for Conv2D Fwd op
struct ConvFwdContext {
#ifndef ENABLE_MKLDNN_V1
// Expected memory format for this primitive instance
memory::format src_fmt;
memory::format filter_fmt;
#endif // !ENABLE_MKLDNN_V1
// MKL-DNN memory
std::shared_ptr<mkldnn::memory> src_mem;
std::shared_ptr<mkldnn::memory> filter_mem;
@ -208,18 +191,10 @@ class MklConvFwdPrimitive : public MklPrimitive {
std::shared_ptr<mkldnn::primitive> conv_fwd;
std::vector<mkldnn::primitive> fwd_primitives;
#ifdef ENABLE_MKLDNN_V1
std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
#endif // ENABLE_MKLDNN_V1
ConvFwdContext()
:
#ifndef ENABLE_MKLDNN_V1
src_fmt(memory::format::any),
filter_fmt(memory::format::any),
#endif // !ENABLE_MKLDNN_V1
src_mem(nullptr),
: src_mem(nullptr),
filter_mem(nullptr),
bias_mem(nullptr),
dst_mem(nullptr),
@ -228,52 +203,45 @@ class MklConvFwdPrimitive : public MklPrimitive {
filter_md(nullptr),
bias_md(nullptr),
fwd_pd(nullptr),
conv_fwd(nullptr) {
}
conv_fwd(nullptr) {}
};
void Setup(const MklConvFwdParams& convFwdDims) {
MEMORY_FORMAT user_data_fmt;
memory::format_tag user_data_fmt;
if (convFwdDims.native_format) {
user_data_fmt = MklTensorFormatToMklDnnDataFormat(convFwdDims.tf_fmt);
} else {
// Create memory descriptors for convolution data w/ no specified format
user_data_fmt = MEMORY_FORMAT::any;
user_data_fmt = memory::format_tag::any;
}
context_.src_md.reset(new memory::desc(
{convFwdDims.src_dims}, MklDnnType<Tinput>(), user_data_fmt));
context_.filter_md.reset(new memory::desc(
{convFwdDims.filter_dims}, MklDnnType<Tfilter>(), MEMORY_FORMAT::any));
context_.filter_md.reset(new memory::desc({convFwdDims.filter_dims},
MklDnnType<Tfilter>(),
memory::format_tag::any));
context_.dst_md.reset(new memory::desc(
{convFwdDims.dst_dims}, MklDnnType<Toutput>(), user_data_fmt));
if (!convFwdDims.bias_dims.empty())
context_.bias_md.reset(new memory::desc(
{convFwdDims.bias_dims}, MklDnnType<Tbias>(), MEMORY_FORMAT::any));
context_.bias_md.reset(new memory::desc({convFwdDims.bias_dims},
MklDnnType<Tbias>(),
memory::format_tag::any));
// Create a convolution descriptor
if (!convFwdDims.bias_dims.empty()) {
context_.fwd_desc.reset(new convolution_forward::desc(
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
*context_.filter_md, *context_.bias_md, *context_.dst_md,
convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
#ifndef ENABLE_MKLDNN_V1
convFwdDims.padding_right, padding_kind::zero));
#else
convFwdDims.padding_right));
#endif // !ENABLE_MKLDNN_V1
prop_kind::forward, mkldnn::algorithm::convolution_direct,
*context_.src_md, *context_.filter_md, *context_.bias_md,
*context_.dst_md, convFwdDims.strides, convFwdDims.dilations,
convFwdDims.padding_left, convFwdDims.padding_right));
} else {
context_.fwd_desc.reset(new convolution_forward::desc(
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
*context_.filter_md, *context_.dst_md, convFwdDims.strides,
convFwdDims.dilations, convFwdDims.padding_left,
#ifndef ENABLE_MKLDNN_V1
convFwdDims.padding_right, padding_kind::zero));
#else
prop_kind::forward, mkldnn::algorithm::convolution_direct,
*context_.src_md, *context_.filter_md, *context_.dst_md,
convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
convFwdDims.padding_right));
#endif // !ENABLE_MKLDNN_V1
}
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
@ -314,54 +282,32 @@ class MklConvFwdPrimitive : public MklPrimitive {
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
}
#ifndef ENABLE_MKLDNN_V1
// Store the expected memory format
context_.src_fmt = static_cast<mkldnn::memory::format>(
context_.fwd_pd.get()->src_primitive_desc().desc().data.format);
context_.filter_fmt = static_cast<mkldnn::memory::format>(
context_.fwd_pd.get()->weights_primitive_desc().desc().data.format);
#endif // !ENABLE_MKLDNN_V1
// Create memory primitive based on dummy data
context_.src_mem.reset(new MEMORY_CONSTRUCTOR(
context_.fwd_pd.get()->PRIMITIVE_DESC_SRC, cpu_engine_, DummyData));
context_.filter_mem.reset(new MEMORY_CONSTRUCTOR(
context_.fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_, DummyData));
context_.dst_mem.reset(new MEMORY_CONSTRUCTOR(
context_.fwd_pd.get()->PRIMITIVE_DESC_DST, cpu_engine_, DummyData));
context_.src_mem.reset(
new memory(context_.fwd_pd.get()->src_desc(), cpu_engine_, DummyData));
context_.filter_mem.reset(new memory(context_.fwd_pd.get()->weights_desc(),
cpu_engine_, DummyData));
context_.dst_mem.reset(
new memory(context_.fwd_pd.get()->dst_desc(), cpu_engine_, DummyData));
// Create convolution primitive and add it to net
if (!convFwdDims.bias_dims.empty()) {
context_.bias_mem.reset(new MEMORY_CONSTRUCTOR_USING_MEM_PD(
convFwdDims.bias_dims, Tbias, MEMORY_FORMAT::x, cpu_engine_,
DummyData));
#ifdef ENABLE_MKLDNN_V1
context_.bias_mem.reset(new memory(
{{convFwdDims.bias_dims}, MklDnnType<Tbias>(), memory::format_tag::x},
cpu_engine_, DummyData));
context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
context_.fwd_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
{MKLDNN_ARG_BIAS, *context_.bias_mem},
{ MKLDNN_ARG_DST,
*context_.dst_mem }});
{MKLDNN_ARG_DST, *context_.dst_mem}});
} else {
context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
context_.fwd_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
{ MKLDNN_ARG_DST,
*context_.dst_mem }});
{MKLDNN_ARG_DST, *context_.dst_mem}});
}
#else
context_.conv_fwd.reset(new convolution_forward(
*context_.fwd_pd, *context_.src_mem, *context_.filter_mem,
*context_.bias_mem, *context_.dst_mem));
} else {
context_.conv_fwd.reset(
new convolution_forward(*context_.fwd_pd, *context_.src_mem,
*context_.filter_mem, *context_.dst_mem));
}
#endif // ENABLE_MKLDNN_V1
context_.fwd_primitives.push_back(*context_.conv_fwd);
}
@ -650,12 +596,10 @@ class MklConvOp : public OpKernel {
auto tf_fmt = is_conv2d ? TFDataFormatToMklDnnDataFormat(data_format_)
: TFDataFormatToMklDnn3DDataFormat(data_format_);
#ifdef ENABLE_MKLDNN_V1
auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt);
// NOTE: `mkl_fmt_tag` will be `format_tag::undef` for ReLU
OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef,
errors::InvalidArgument("Invalid data format"));
#endif // ENABLE_MKLDNN_V1
// If input is in MKL layout, then simply grab the layout; otherwise,
// construct TF layout for input.
@ -667,19 +611,15 @@ class MklConvOp : public OpKernel {
auto src_md =
src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
#ifdef ENABLE_MKLDNN_V1
: memory::desc(src_dims, MklDnnType<Tinput>(), mkl_fmt_tag);
#else
: memory::desc(src_dims, MklDnnType<Tinput>(), tf_fmt);
#endif // ENABLE_MKLDNN_V1
src.SetUsrMem(src_md, &src_tensor);
// Although filter shape (filter_dims) required is in MKL-DNN order,
// the layout is Tensorflow's layout (HWIO) and (HWIGO) for
// depthwise/group convolutions.
auto filter_format = is_conv2d ? (is_depthwise ? MEMORY_FORMAT::hwigo
: MEMORY_FORMAT::hwio)
: MEMORY_FORMAT::dhwio;
auto filter_format = is_conv2d ? (is_depthwise ? memory::format_tag::hwigo
: memory::format_tag::hwio)
: memory::format_tag::dhwio;
DCHECK(!filter_mkl_shape.IsMklTensor());
auto filter_md =
@ -738,12 +678,9 @@ class MklConvOp : public OpKernel {
// Check whether src and filter need to be reordered.
Tinput* src_data = nullptr;
if (IS_SRC_REORDER_NEEDED(src_md, conv_fwd_pd, conv_fwd)) {
if (src_md != conv_fwd_pd->src_desc()) {
src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(GET_SRC_DESC_FROM_OP_PD(conv_fwd_pd),
cpu_engine_),
context);
src.CheckReorderToOpMem(conv_fwd_pd->src_desc(), cpu_engine_, context);
src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
} else {
src_data = static_cast<Tinput*>(
@ -751,7 +688,7 @@ class MklConvOp : public OpKernel {
}
Tfilter* filter_data = nullptr;
if (IS_FILTER_REORDER_NEEDED(filter_md, conv_fwd_pd, conv_fwd)) {
if (filter_md != conv_fwd_pd->weights_desc()) {
bool is_filter_cached = false;
// If filter is a constant, we can avoid the conversion of filter from
// Tensorflow format to MKL format by caching the filter when it is
@ -761,28 +698,20 @@ class MklConvOp : public OpKernel {
if (IsFilterCacheEmpty(context)) {
// Cache filter if it is not already cached.
CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor,
#ifdef ENABLE_MKLDNN_V1
filter, filter_md, filter_mkl_shape);
#else
filter, filter_md);
#endif // ENABLE_MKLDNN_V1
}
filter_data = GetCachedFilter(
context, GET_WEIGHTS_FORMAT_FROM_OP_PD(conv_fwd_pd, conv_fwd));
filter_data = GetCachedFilter(context, conv_fwd_pd->weights_desc());
is_filter_cached = (filter_data != nullptr);
}
if (!is_filter_cached) {
filter.SetUsrMem(filter_md, &filter_tensor);
if (filter_out_tensor == nullptr) {
filter.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd),
cpu_engine_),
filter.CheckReorderToOpMem(conv_fwd_pd->weights_desc(), cpu_engine_,
context);
} else {
filter.CheckReorderToOpMem(
GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd),
DATA_WITH_ENGINE(filter.GetTensorBuffer(filter_out_tensor),
cpu_engine_),
conv_fwd_pd->weights_desc(),
filter.GetTensorBuffer(filter_out_tensor), cpu_engine_,
context);
}
filter_data =
@ -897,7 +826,8 @@ class MklConvOp : public OpKernel {
// NOTE: Fusion of BiasAdd is handled directly inside MklConvOp by
// checking `fuse_biasadd_` flag.
if (fuse_add_) {
params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}, ""});
params.post_op_params.push_back(
{"sum", mkldnn::algorithm::undef, {1.0}, ""});
}
if (fuse_activation_) {
params.post_op_params.push_back(
@ -918,35 +848,27 @@ class MklConvOp : public OpKernel {
virtual void AllocateOutputTensor(OpKernelContext* context,
const ConvFwdPd& conv_prim_desc,
const memory::dims& output_dims_mkl_order,
MKL_TENSOR_FORMAT output_tf_format,
MklTensorFormat output_tf_format,
MklDnnShape* output_mkl_shape,
Tensor** output_tensor) {
DCHECK(output_tensor);
#ifdef ENABLE_MKLDNN_V1
auto dst_md = conv_prim_desc.dst_desc();
#else
auto dst_pd = conv_prim_desc.dst_primitive_desc();
auto dst_md = dst_pd.desc();
#endif // ENABLE_MKLDNN_V1
if (!std::is_same<Ttemp_output, Toutput>::value) {
dst_md.data.data_type =
static_cast<mkldnn_data_type_t>(MklDnnType<Toutput>());
#ifndef ENABLE_MKLDNN_V1
dst_pd = memory::primitive_desc(dst_md, cpu_engine_);
#endif // !ENABLE_MKLDNN_V1
}
// Allocate shape of MKL tensor
output_mkl_shape->SetMklTensor(true);
output_mkl_shape->SetMklLayout(&DST_MD);
output_mkl_shape->SetMklLayout(&dst_md);
output_mkl_shape->SetElemType(MklDnnType<Toutput>());
output_mkl_shape->SetTfLayout(output_dims_mkl_order.size(),
output_dims_mkl_order, output_tf_format);
// Allocate shape of TF tensor
TensorShape output_tf_shape;
output_tf_shape.AddDim((DST_MD.get_size() / sizeof(Toutput)));
output_tf_shape.AddDim((dst_md.get_size() / sizeof(Toutput)));
if (native_format) {
output_tf_shape = output_mkl_shape->GetTfShape();
}
@ -972,23 +894,16 @@ class MklConvOp : public OpKernel {
AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
output_tf_shape, *output_mkl_shape,
native_format);
#ifdef ENABLE_MKLDNN_V1
auto output_format_tag = MklTensorFormatToMklDnnDataFormat(
output_mkl_shape->GetTfDataFormat());
OP_REQUIRES(context, output_format_tag != memory::format_tag::undef,
errors::InvalidArgument(
"MklConvOp: AddN fusion: Invalid data format"));
#endif // ENABLE_MKLDNN_V1
auto add_md =
add_mkl_shape.IsMklTensor()
? add_mkl_shape.GetMklLayout()
: memory::desc(output_dims_mkl_order, MklDnnType<Toutput>(),
#ifdef ENABLE_MKLDNN_V1
output_format_tag);
#else
output_mkl_shape->GetTfDataFormat());
auto add_pd = memory::primitive_desc(add_md, this->cpu_engine_);
#endif // ENABLE_MKLDNN_V1
void* add_buf = static_cast<void*>(
const_cast<Toutput*>(add_tensor.flat<Toutput>().data()));
void* dst_buf =
@ -996,16 +911,14 @@ class MklConvOp : public OpKernel {
if (native_format) {
// We are simply deep copying the add_tensor to output_tensor without
// changing memory layout, hence using same memory descriptor.
ADD_MD = DST_MD =
add_md = dst_md =
memory::desc({add_tensor.NumElements()}, MklDnnType<Toutput>(),
mkldnn::memory::format_tag::x);
}
fuse_add_src_.reset(
new MEMORY_CONSTRUCTOR(ADD_MD, this->cpu_engine_, add_buf));
fuse_add_dst_.reset(
new MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf));
fuse_add_src_.reset(new memory(add_md, this->cpu_engine_, add_buf));
fuse_add_dst_.reset(new memory(dst_md, this->cpu_engine_, dst_buf));
auto reorder_desc =
REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
ReorderPd(this->cpu_engine_, add_md, this->cpu_engine_, dst_md);
CreateAndExecuteReorder(reorder_desc, *fuse_add_src_, *fuse_add_dst_,
this->cpu_engine_, context);
@ -1017,7 +930,7 @@ class MklConvOp : public OpKernel {
}
}
engine cpu_engine_ = engine(ENGINE_CPU, 0);
engine cpu_engine_ = engine(engine::kind::cpu, 0);
private:
std::shared_ptr<mkldnn::memory> fuse_add_src_;
@ -1041,7 +954,7 @@ class MklConvOp : public OpKernel {
// This variable is used for alpha in leakyrelu or upper bound in relu6
// depending on the context
float alpha_or_upbound_ = 0.0;
mkldnn::algorithm activation_alg_ = ALGORITHM_UNDEF;
mkldnn::algorithm activation_alg_ = mkldnn::algorithm::undef;
int input_index_pad_ = 2;
@ -1050,15 +963,10 @@ class MklConvOp : public OpKernel {
const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1;
const int kDilationH = 0, kDilationW = 1;
MKL_TENSOR_FORMAT_IN_C GetFilterTfDataFormat(
const MklDnnShape* filter_mkl_shape,
MklTensorFormat GetFilterTfDataFormat(const MklDnnShape* filter_mkl_shape,
const ConvFwdPd& conv_prim_desc) const {
#ifdef ENABLE_MKLDNN_V1
DCHECK(filter_mkl_shape);
return filter_mkl_shape->GetTfDataFormat();
#else
return conv_prim_desc.weights_primitive_desc().desc().data.format;
#endif // ENABLE_MKLDNN_V1
}
// Allocate persistent tensors for cached filter data and
@ -1070,23 +978,13 @@ class MklConvOp : public OpKernel {
DCHECK(filter_tensor);
TensorShape filter_tf_shape;
filter_tf_shape.AddDim(
(conv_prim_desc.PRIMITIVE_DESC_WEIGHTS.get_size() / sizeof(Tfilter)));
(conv_prim_desc.weights_desc().get_size() / sizeof(Tfilter)));
OP_REQUIRES_OK(context, context->allocate_persistent(
DataTypeToEnum<Tfilter>::value, filter_tf_shape,
&cached_filter_data_ptensor_, filter_tensor));
Tensor* second_tensor = nullptr;
#ifndef ENABLE_MKLDNN_V1
TensorShape filter_mkl_format;
filter_mkl_format.AddDim(
sizeof(GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc)) /
sizeof(DT_INT32));
OP_REQUIRES_OK(context, context->allocate_persistent(
DT_INT32, filter_mkl_format,
&cached_filter_md_ptensor_, &second_tensor));
second_tensor->scalar<int32>()() = static_cast<int32>(
GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc));
#else
// There is no tensor format in DNNL 1.x, so we cache the complete filter
// descriptor as a flat byte array.
TensorShape cached_filter_md_shape;
@ -1100,7 +998,6 @@ class MklConvOp : public OpKernel {
&cached_filter_md_ptensor_, &second_tensor));
*reinterpret_cast<memory::desc*>(second_tensor->flat<uint8>().data()) =
weights_desc;
#endif // !ENABLE_MKLDNN_V1
}
void AllocatePersistentTensor(OpKernelContext* context,
@ -1114,7 +1011,7 @@ class MklConvOp : public OpKernel {
const memory::dims& filter_dims_tf_order,
Tensor** filter_tensor) {
DCHECK(filter_tensor);
auto filter_md = conv_prim_desc.PRIMITIVE_DESC_WEIGHTS;
auto filter_md = conv_prim_desc.weights_desc();
// Allocate shape of MKL tensor
MklDnnShape filter_mkl_shape;
@ -1127,7 +1024,7 @@ class MklConvOp : public OpKernel {
// is stored in the MKL data.
filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(),
filter_dims_tf_order,
MKL_TENSOR_FORMAT_BLOCKED);
MklTensorFormat::FORMAT_BLOCKED);
// Allocate the data space for the filter to propagate as TF tensor.
TensorShape filter_tf_shape;
@ -1150,17 +1047,15 @@ class MklConvOp : public OpKernel {
// Create reorders between the user layout and the MKL layout if needed, and
// add them to the net before the convolution. No need to check for an output
// reorder as we propagate the output layout to the next layer.
src->CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(conv_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_));
src->CheckReorderToOpMem(conv_prim_desc.src_desc(), cpu_engine_);
// Rather than re-ordering to a temp buffer, reorder directly to the
// filter output tensor
filter->CheckReorderToOpMem(conv_prim_desc.PRIMITIVE_DESC_WEIGHTS,
filter->CheckReorderToOpMem(conv_prim_desc.weights_desc(),
filter->GetTensorBuffer(filter_out_tensor));
// Create convolution primitive and add it to net.
std::vector<primitive> net;
#ifdef ENABLE_MKLDNN_V1
std::vector<std::unordered_map<int, memory>> net_args;
if (bias) {
DCHECK(fuse_biasadd_);
@ -1168,31 +1063,15 @@ class MklConvOp : public OpKernel {
net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()},
{MKLDNN_ARG_WEIGHTS, filter->GetOpMem()},
{MKLDNN_ARG_BIAS, bias->GetOpMem()},
{ MKLDNN_ARG_DST,
output->GetOpMem() }});
{MKLDNN_ARG_DST, output->GetOpMem()}});
} else {
DCHECK(!fuse_biasadd_);
net.push_back(convolution_forward(conv_prim_desc));
net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()},
{MKLDNN_ARG_WEIGHTS, filter->GetOpMem()},
{ MKLDNN_ARG_DST,
output->GetOpMem() }});
{MKLDNN_ARG_DST, output->GetOpMem()}});
}
ExecutePrimitive(net, &net_args, cpu_engine_);
#else
if (bias) {
DCHECK(fuse_biasadd_);
net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
filter->GetOpMem(), bias->GetOpMem(),
output->GetOpMem()));
} else {
DCHECK(!fuse_biasadd_);
net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
filter->GetOpMem(),
output->GetOpMem()));
}
ExecutePrimitive(net, nullptr, cpu_engine_);
#endif // ENABLE_MKLDNN_V1
}
// TF_LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot
@ -1208,7 +1087,6 @@ class MklConvOp : public OpKernel {
// Cache the converted filter in a persistent tensor.
// Only one thread can execute this method at any given time.
#ifdef ENABLE_MKLDNN_V1
void CacheFilter(OpKernelContext* context,
const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
Tfilter* filter_data, const Tensor& filter_tensor,
@ -1254,37 +1132,8 @@ class MklConvOp : public OpKernel {
return true;
}
#else
void CacheFilter(OpKernelContext* context,
const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
Tfilter* filter_data, const Tensor& filter_tensor,
MklDnnData<Tfilter>& filter, const memory::desc& filter_md)
TF_LOCKS_EXCLUDED(mu_) {
mutex_lock lock(mu_);
const Tensor& cached_filter_data_tensor =
*cached_filter_data_ptensor_.AccessTensor(context);
// If filter is already cached, there's nothing to do.
if (cached_filter_data_tensor.NumElements() > 0) {
return;
}
// Otherwise, cache filter
filter.SetUsrMem(filter_md, &filter_tensor);
filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_primitive_desc());
filter_data = static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
Tensor* filter_tensor_ptr = nullptr;
AllocatePersistentTensor(context, *conv_fwd_pd, &filter_tensor_ptr);
void* cached_filter_data = filter.GetTensorBuffer(filter_tensor_ptr);
size_t cached_filter_data_size =
filter.GetOpMem().get_primitive_desc().get_size();
memcpy(cached_filter_data, filter_data, cached_filter_data_size);
}
#endif // ENABLE_MKLDNN_V1
Tfilter* GetCachedFilter(OpKernelContext* context,
const MEMORY_DESC& filter_md)
const memory::desc& filter_md)
TF_LOCKS_EXCLUDED(mu_) {
tf_shared_lock lock(mu_);
const Tensor& cached_filter_data =
@ -1292,15 +1141,10 @@ class MklConvOp : public OpKernel {
const Tensor& cached_filter_md =
*cached_filter_md_ptensor_.AccessTensor(context);
// Check if the memory descriptor of the cached weights is same as
// Check if the memory descriptor of the cached weights is the same as
// filter_md. If so, we can use the cached weights; otherwise
// return nullptr.
#ifdef ENABLE_MKLDNN_V1
if (filter_md == *static_cast<memory::desc*>(cached_filter_md.data())) {
#else
if (cached_filter_md.scalar<int32>().size() &&
cached_filter_md.scalar<int32>()() == filter_md) {
#endif // ENABLE_MKLDNN_V1
return static_cast<Tfilter*>(
const_cast<Tfilter*>(cached_filter_data.flat<Tfilter>().data()));
}
@ -1336,31 +1180,34 @@ class MklFusedConvOp
errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias."));
} else if (fused_ops == std::vector<string>{"Relu"}) {
this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
} else if (fused_ops == std::vector<string>{"Relu6"}) {
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
6.0);
} else if (fused_ops == std::vector<string>{"Elu"}) {
this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
} else if (fused_ops == std::vector<string>{"LeakyRelu"}) {
float leakyrelu_alpha;
OP_REQUIRES_OK(context,
context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
this->set_fuse_activation(true, ALGORITHM::eltwise_relu, leakyrelu_alpha);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
leakyrelu_alpha);
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
OP_REQUIRES(context, num_args == 1,
errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias."));
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
6.0);
OP_REQUIRES(context, num_args == 1,
errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias."));
} else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
OP_REQUIRES(context, num_args == 1,
errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias."));
@ -1369,7 +1216,8 @@ class MklFusedConvOp
float leakyrelu_alpha;
OP_REQUIRES_OK(context,
context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
this->set_fuse_activation(true, ALGORITHM::eltwise_relu, leakyrelu_alpha);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
leakyrelu_alpha);
OP_REQUIRES(context, num_args == 1,
errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias."));
@ -1383,7 +1231,7 @@ class MklFusedConvOp
} else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_add(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
OP_REQUIRES(
context, num_args == 2,
errors::InvalidArgument(
@ -1391,7 +1239,8 @@ class MklFusedConvOp
} else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"}) {
this->set_fuse_biasadd(true);
this->set_fuse_add(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
6.0);
OP_REQUIRES(
context, num_args == 2,
errors::InvalidArgument(
@ -1399,7 +1248,7 @@ class MklFusedConvOp
} else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_add(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
OP_REQUIRES(
context, num_args == 2,
errors::InvalidArgument(
@ -1411,7 +1260,8 @@ class MklFusedConvOp
float leakyrelu_alpha;
OP_REQUIRES_OK(context,
context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
this->set_fuse_activation(true, ALGORITHM::eltwise_relu, leakyrelu_alpha);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
leakyrelu_alpha);
OP_REQUIRES(
context, num_args == 2,
errors::InvalidArgument(
@ -1459,13 +1309,14 @@ class MklFusedDepthwiseConvOp
this->set_fuse_biasadd(true);
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
6.0);
} else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
} else {
OP_REQUIRES(context, false,
errors::Unimplemented("Fusion is not implemented: [",
@ -1642,8 +1493,8 @@ class MklQuantizedConv2DOp
param_key.AddAsKey<float>(max_freezed_output);
param_key.AddAsKey<const float*>(min_filter);
param_key.AddAsKey<const float*>(max_filter);
params.post_op_params.push_back(
{"output_scale", ALGORITHM_UNDEF, scales, param_key.GetKey()});
params.post_op_params.push_back({"output_scale", mkldnn::algorithm::undef,
scales, param_key.GetKey()});
}
}
@ -1696,31 +1547,27 @@ class MklQuantizedConv2DOp
bias_attr.set_output_scales(1, scales_);
}
auto bias_md =
MEMORY_PD_CONSTRUCTOR(static_cast<int>(bias_tensor.NumElements()),
Tbias, MEMORY_FORMAT::x, this->cpu_engine_);
auto bias_md = memory::desc({static_cast<int>(bias_tensor.NumElements())},
MklDnnType<Tbias>(), memory::format_tag::x);
void* bias_buf = static_cast<void*>(
const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
if (!input_bias_) {
input_bias_ =
new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_, bias_buf);
input_bias_ = new memory(bias_md, this->cpu_engine_, bias_buf);
} else {
input_bias_->set_data_handle(bias_buf);
}
if (!scaled_bias_buf_)
AllocTmpBuffer<Tbias>(context, &scaled_bias_tensor_,
GET_BIAS_DESC_FROM_OP_PD(conv_fwd_pd),
&scaled_bias_buf_);
conv_fwd_pd->bias_desc(), &scaled_bias_buf_);
if (!scaled_bias_) {
scaled_bias_ = new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_,
scaled_bias_buf_);
scaled_bias_ = new memory(bias_md, this->cpu_engine_, scaled_bias_buf_);
} else {
scaled_bias_->set_data_handle(scaled_bias_buf_);
}
auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
input_bias_->GET_DESC, scaled_bias_->GET_DESC, this->cpu_engine_,
bias_attr);
auto reorder_desc =
ReorderPd(this->cpu_engine_, input_bias_->get_desc(),
this->cpu_engine_, scaled_bias_->get_desc(), bias_attr);
CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_,
this->cpu_engine_, context);
@ -1754,7 +1601,7 @@ class MklQuantizedConv2DOp
DCHECK(bias_tensor);
TensorShape bias_tf_shape;
bias_tf_shape.AddDim(
(conv_prim_desc.PRIMITIVE_DESC_BIAS.get_size() / sizeof(Tbias)));
(conv_prim_desc.bias_desc().get_size() / sizeof(Tbias)));
OP_REQUIRES_OK(context, context->allocate_persistent(
DataTypeToEnum<Tbias>::value, bias_tf_shape,
&cached_bias_data_ptensor_, bias_tensor));
@ -1787,7 +1634,7 @@ class MklQuantizedConv2DOp
AllocatePersistentTensor(context, *conv_fwd_pd, &bias_tensor_ptr);
void* cached_bias_data = const_cast<void*>(
static_cast<const void*>(bias_tensor_ptr->flat<Tbias>().data()));
size_t cached_bias_data_size = scaled_bias->GET_DESC.get_size();
size_t cached_bias_data_size = scaled_bias->get_desc().get_size();
memcpy(cached_bias_data, bias_data, cached_bias_data_size);
}
@ -1822,7 +1669,7 @@ class MklQuantizedConv2DReluOp
is_depthwise>::ExtendConvFwdParams(context, params);
params.post_op_params.push_back(
{"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}, ""});
{"activation", mkldnn::algorithm::eltwise_relu, {1.0, 0.0, 0.0}, ""});
}
};
@ -1868,26 +1715,30 @@ class MklQuantizedConv2DSumReluOp
// If summand_type is also DT_QUINT8, the same as the output type,
// the scaling factors of 255.0f cancel each other and are thus avoided.
// If it is not, then it is DT_INT8 and is scaled appropriately.
if (summand_type == DT_QUINT8)
params.post_op_params.push_back(
{"sum", ALGORITHM_UNDEF, {scale_summand / scale_output}, ""});
else
params.post_op_params.push_back(
{"sum",
ALGORITHM_UNDEF,
{255.0f * scale_summand / (scale_output * 127.0f)},
if (summand_type == DT_QUINT8) {
params.post_op_params.push_back({"sum",
mkldnn::algorithm::undef,
{scale_summand / scale_output},
""});
} else {
params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}, ""});
params.post_op_params.push_back(
{"sum",
mkldnn::algorithm::undef,
{255.0f * scale_summand / (scale_output * 127.0f)},
""});
}
} else {
params.post_op_params.push_back(
{"sum", mkldnn::algorithm::undef, {1.0}, ""});
}
params.post_op_params.push_back(
{"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}, ""});
{"activation", mkldnn::algorithm::eltwise_relu, {1.0, 0.0, 0.0}, ""});
}
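// A sketch of the reasoning behind the scale arithmetic above, assuming the
// usual quantization ranges (255 levels for quint8, 127 for qint8); this note
// is an editorial interpretation, not a comment from the original source:
// when the summand and the output are both quint8 the 255 factors cancel,
// leaving scale_summand / scale_output; when the summand is qint8 the
// leftover ratio gives 255.0f * scale_summand / (scale_output * 127.0f).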
void AllocateOutputTensor(OpKernelContext* context,
const ConvFwdPd& conv_prim_desc,
const memory::dims& output_dims_mkl_order,
MKL_TENSOR_FORMAT output_tf_format,
MklTensorFormat output_tf_format,
MklDnnShape* output_mkl_shape,
Tensor** output_tensor) override {
int summand_idx = context->num_inputs() / 2 - 1;
@ -1966,21 +1817,17 @@ class MklQuantizedConv2DSumReluOp
summand_mkl_shape.IsMklTensor()
? summand_mkl_shape.GetMklLayout()
: memory::desc(output_dims_mkl_order, MklDnnType<Tbias>(),
MEMORY_FORMAT::nhwc);
#ifndef ENABLE_MKLDNN_V1
auto summand_pd = memory::primitive_desc(summand_md, this->cpu_engine_);
#endif // !ENABLE_MKLDNN_V1
memory::format_tag::nhwc);
void* summand_buf =
static_cast<void*>(const_cast<Tbias*>(summand.flat<Tbias>().data()));
void* dst_buf =
static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
summand_.reset(
new MEMORY_CONSTRUCTOR(SUMMAND_MD, this->cpu_engine_, summand_buf));
dst_.reset(new MEMORY_CONSTRUCTOR(conv_prim_desc.PRIMITIVE_DESC_DST,
this->cpu_engine_, dst_buf));
auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
SUMMAND_MD, conv_prim_desc.PRIMITIVE_DESC_DST, this->cpu_engine_,
reorder_attr);
summand_.reset(new memory(summand_md, this->cpu_engine_, summand_buf));
dst_.reset(
new memory(conv_prim_desc.dst_desc(), this->cpu_engine_, dst_buf));
auto reorder_desc =
ReorderPd(this->cpu_engine_, summand_md, this->cpu_engine_,
conv_prim_desc.dst_desc(), reorder_attr);
CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_,
context);
}
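The reorder-with-attributes pattern used above via ReorderPd and CreateAndExecuteReorder can be illustrated by a minimal standalone sketch against the plain mkldnn 1.x API; the function name, buffers, and single output scale below are illustrative assumptions, not part of the TensorFlow source:
#include "mkldnn.hpp"
// Reorder src_buf (described by src_md) into dst_buf (described by dst_md),
// applying one output scale, e.g. to requantize a summand before fusion.
void ReorderSketch(const mkldnn::memory::desc& src_md,
                   const mkldnn::memory::desc& dst_md, void* src_buf,
                   void* dst_buf, float output_scale) {
  mkldnn::engine cpu_engine(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream cpu_stream(cpu_engine);
  mkldnn::memory src_mem(src_md, cpu_engine, src_buf);
  mkldnn::memory dst_mem(dst_md, cpu_engine, dst_buf);
  mkldnn::primitive_attr reorder_attr;
  reorder_attr.set_output_scales(/*mask=*/0, {output_scale});
  mkldnn::reorder::primitive_desc reorder_pd(cpu_engine, src_md, cpu_engine,
                                             dst_md, reorder_attr);
  mkldnn::reorder(reorder_pd).execute(cpu_stream, src_mem, dst_mem);
  cpu_stream.wait();
}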


@ -42,20 +42,13 @@ limitations under the License.
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#ifndef ENABLE_MKLDNN_V1
using mkldnn::convolution_direct;
#endif // !ENABLE_MKLDNN_V1
using mkldnn::convolution_forward;
using mkldnn::prop_kind;
using mkldnn::stream;
namespace tensorflow {
#ifdef ENABLE_MKLDNN_V1
#define MKLDNN_SIZE_DTYPE memory::dim
#else
#define MKLDNN_SIZE_DTYPE int
#endif // ENABLE_MKLDNN_V1
using ConvFwdDesc = mkldnn::convolution_forward::desc;
using ConvFwdPd = mkldnn::convolution_forward::primitive_desc;
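As a hedged illustration of how these aliases are typically used (a minimal sketch against the mkldnn 1.x API; the function, buffers, and shapes are assumptions and the buffers are assumed to already match the chosen descriptors, unlike the real op, which reorders them first):
#include <unordered_map>
#include "mkldnn.hpp"
// Build a forward convolution primitive descriptor from memory descriptors,
// then execute it with the MKLDNN_ARG_* argument map, mirroring the
// net/net_args pattern used in MklConvOp above.
void RunConvSketch(const mkldnn::memory::desc& src_md,
                   const mkldnn::memory::desc& weights_md,
                   const mkldnn::memory::desc& dst_md,
                   const mkldnn::memory::dims& strides,
                   const mkldnn::memory::dims& padding, void* src_buf,
                   void* weights_buf, void* dst_buf) {
  mkldnn::engine cpu_engine(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream cpu_stream(cpu_engine);
  // These two types are what ConvFwdDesc and ConvFwdPd alias above.
  mkldnn::convolution_forward::desc fwd_desc(
      mkldnn::prop_kind::forward_inference,
      mkldnn::algorithm::convolution_direct, src_md, weights_md, dst_md,
      strides, padding, padding);
  mkldnn::convolution_forward::primitive_desc fwd_pd(fwd_desc, cpu_engine);
  mkldnn::memory src_mem(fwd_pd.src_desc(), cpu_engine, src_buf);
  mkldnn::memory weights_mem(fwd_pd.weights_desc(), cpu_engine, weights_buf);
  mkldnn::memory dst_mem(fwd_pd.dst_desc(), cpu_engine, dst_buf);
  mkldnn::convolution_forward conv(fwd_pd);
  conv.execute(cpu_stream, {{MKLDNN_ARG_SRC, src_mem},
                            {MKLDNN_ARG_WEIGHTS, weights_mem},
                            {MKLDNN_ARG_DST, dst_mem}});
  cpu_stream.wait();
}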


@ -134,6 +134,7 @@ tf_kernel_library(
"gpu_op_bitwise_and.cc",
"gpu_op_bitwise_or.cc",
"gpu_op_bitwise_xor.cc",
"gpu_op_div.cc",
"gpu_op_equal.cc",
"gpu_op_floor_div.cc",
"gpu_op_greater.cc",
@ -146,6 +147,7 @@ tf_kernel_library(
"gpu_op_mul.cc",
"gpu_op_not_equal.cc",
"gpu_op_right_shift.cc",
"gpu_op_sub.cc",
],
tags = [
"manual",
@ -155,6 +157,7 @@ tf_kernel_library(
":bitwise_and_kernels",
":bitwise_or_kernels",
":bitwise_xor_kernels",
":div_kernels",
":equal_kernels",
":floor_div_kernels",
":gpu_ops_base",
@ -170,6 +173,7 @@ tf_kernel_library(
":mul_kernels",
":not_equal_kernels",
":right_shift_kernels",
":sub_kernels",
"//third_party/eigen3",
],
)
@ -366,8 +370,9 @@ gen_kernel_library(
unroll_factors = "4",
)
[
gen_kernel_library(
name = "add_v2",
name = name,
tile_size = "256,1,1",
types = [
"f16",
@ -378,6 +383,11 @@ gen_kernel_library(
# TODO(b/174543802): Enable once fusion heuristics is better.
# unroll_factors = "4",
)
for name in [
"add_v2",
"sub",
]
]
gen_kernel_library(
name = "complex",
@ -390,6 +400,20 @@ gen_kernel_library(
# unroll_factors = "2",
)
gen_kernel_library(
name = "div",
tile_size = "256,1,1",
types = [
"f16",
"f32",
"f64",
"i16",
"i64",
],
# TODO(b/174543802): Enable once fusion heuristics is better.
# unroll_factors = "4",
)
gen_kernel_library(
name = "mul",
tile_size = "256,1,1",


@ -48,14 +48,17 @@ class GpuBinaryOpTest : public OpsTestBase {
void SetOpKernel(const std::string& op_name, const TensorShape& lhs_shape,
const absl::InlinedVector<T, 10>& lhs_input,
const TensorShape& rhs_shape,
const absl::InlinedVector<T, 10>& rhs_input,
bool use_constraint) {
const absl::InlinedVector<T, 10>& rhs_input, bool add_t,
bool add_tout) {
auto builder = NodeDefBuilder("some_name", op_name)
.Input(FakeInput(DataTypeToEnum<T>::v()))
.Input(FakeInput(DataTypeToEnum<T>::v()));
if (use_constraint) {
if (add_t) {
builder.Attr("T", DataTypeToEnum<T>::v());
}
if (add_tout) {
builder.Attr("Tout", DataTypeToEnum<OutT>::v());
}
TF_ASSERT_OK(builder.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
@ -73,16 +76,20 @@ class GpuBinaryOpTest : public OpsTestBase {
const absl::InlinedVector<T, 10>& rhs_input,
const TensorShape& expected_shape,
const absl::InlinedVector<OutT, 10>& expected_output,
bool use_constraint = true) {
const test::GpuOpsTestConfig& config) {
SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input,
use_constraint);
config.add_t, config.add_tout);
TF_ASSERT_OK(RunOpKernel());
// Compare output to expectation.
Tensor expected_tensor(allocator(), DataTypeToEnum<OutT>::value,
expected_shape);
test::FillValues<OutT>(&expected_tensor, expected_output);
if (config.expect_strictly_equal) {
test::ExpectEqual(expected_tensor, *GetOutput(0));
} else {
test::ExpectClose(expected_tensor, *GetOutput(0));
}
}
template <typename T, typename OutT>
@ -91,9 +98,9 @@ class GpuBinaryOpTest : public OpsTestBase {
const absl::InlinedVector<T, 10>& lhs_input,
const TensorShape& rhs_shape,
const absl::InlinedVector<T, 10>& rhs_input,
bool use_constraint = true) {
const test::GpuOpsTestConfig& config) {
SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input,
use_constraint);
config.add_t, config.add_tout);
auto status = RunOpKernel();
EXPECT_FALSE(status.ok());
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@ -105,7 +112,7 @@ class GpuBinaryOpTest : public OpsTestBase {
void TestIncompatibleShapes(const std::string& op_name,
const absl::InlinedVector<T, 10>& lhs_input,
const absl::InlinedVector<T, 10>& rhs_input,
bool use_constraint = true) {
const test::GpuOpsTestConfig& config) {
// Prepare incompatibly shaped inputs.
TensorShape lhs_shape{3};
TensorShape rhs_shape{2};
@ -115,8 +122,7 @@ class GpuBinaryOpTest : public OpsTestBase {
test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements());
RunAndExpectInvalidArgument<T, OutT>(op_name, lhs_shape, repeated_lhs_input,
rhs_shape, repeated_rhs_input,
use_constraint);
rhs_shape, repeated_rhs_input, config);
}
template <typename T, typename BaselineT, typename OutT,
@ -125,7 +131,7 @@ class GpuBinaryOpTest : public OpsTestBase {
const absl::InlinedVector<T, 10>& lhs_input,
const absl::InlinedVector<T, 10>& rhs_input,
BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
bool use_constraint = true) {
const test::GpuOpsTestConfig& config) {
// Prepare inputs.
int input_size = shape.num_elements();
auto repeated_lhs_input =
@ -147,7 +153,7 @@ class GpuBinaryOpTest : public OpsTestBase {
RunAndExpectResult<T, OutT>(op_name, shape, repeated_lhs_input, shape,
repeated_rhs_input, shape, expected_output,
use_constraint);
config);
}
template <typename T, typename BaselineT, typename OutT,
@ -156,7 +162,7 @@ class GpuBinaryOpTest : public OpsTestBase {
const TensorShape& other_shape,
const absl::InlinedVector<T, 10>& other_input,
BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
bool use_constraint = true) {
const test::GpuOpsTestConfig& config) {
// Prepare inputs.
TensorShape scalar_shape{};
auto repeated_other_input =
@ -177,7 +183,7 @@ class GpuBinaryOpTest : public OpsTestBase {
RunAndExpectResult<T, OutT>(op_name, scalar_shape, scalar_input_vector,
other_shape, repeated_other_input,
/*expected_shape=*/other_shape, expected_output,
use_constraint);
config);
}
template <typename T, typename BaselineT, typename OutT,
@ -187,7 +193,7 @@ class GpuBinaryOpTest : public OpsTestBase {
const absl::InlinedVector<T, 10>& rhs_input,
BaselineOutT (*baseline_callback)(BaselineT,
BaselineT),
bool use_constraint = true) {
const test::GpuOpsTestConfig& config) {
// Prepare inputs.
TensorShape lhs_shape{1};
TensorShape rhs_shape{6};
@ -206,7 +212,7 @@ class GpuBinaryOpTest : public OpsTestBase {
RunAndExpectResult<T, OutT>(
op_name, lhs_shape, repeated_lhs_input, rhs_shape, repeated_rhs_input,
/*expected_shape=*/rhs_shape, expected_output, use_constraint);
/*expected_shape=*/rhs_shape, expected_output, config);
}
template <typename T, typename BaselineT, typename OutT,
@ -216,7 +222,7 @@ class GpuBinaryOpTest : public OpsTestBase {
const absl::InlinedVector<T, 10>& rhs_input,
BaselineOutT (*baseline_callback)(BaselineT,
BaselineT),
bool use_constraint = true) {
const test::GpuOpsTestConfig& config) {
// Prepare inputs.
TensorShape lhs_shape{3};
TensorShape rhs_shape{2, 3};
@ -235,7 +241,7 @@ class GpuBinaryOpTest : public OpsTestBase {
RunAndExpectResult<T, OutT>(
op_name, lhs_shape, repeated_lhs_input, rhs_shape, repeated_rhs_input,
/*expected_shape=*/rhs_shape, expected_output, use_constraint);
/*expected_shape=*/rhs_shape, expected_output, config);
}
template <typename T, typename BaselineT, typename OutT,
@ -244,7 +250,7 @@ class GpuBinaryOpTest : public OpsTestBase {
const absl::InlinedVector<T, 10>& lhs_input,
const absl::InlinedVector<T, 10>& rhs_input,
BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
bool use_constraint = true) {
const test::GpuOpsTestConfig& config) {
// Prepare inputs.
TensorShape lhs_shape{2, 1};
TensorShape rhs_shape{3};
@ -264,7 +270,7 @@ class GpuBinaryOpTest : public OpsTestBase {
RunAndExpectResult<T, OutT>(op_name, lhs_shape, repeated_lhs_input,
rhs_shape, repeated_rhs_input, expected_shape,
expected_output, use_constraint);
expected_output, config);
}
template <typename T, typename BaselineT, typename OutT,
@ -272,7 +278,7 @@ class GpuBinaryOpTest : public OpsTestBase {
void TestEmptyShapeBroadcasting(const std::string& op_name,
const absl::InlinedVector<T, 10>& lhs_input,
const absl::InlinedVector<T, 10>& rhs_input,
bool use_constraint = true) {
const test::GpuOpsTestConfig& config) {
// Prepare inputs.
TensorShape lhs_shape{2, 0, 1};
TensorShape rhs_shape{2, 0, 5};
@ -284,7 +290,7 @@ class GpuBinaryOpTest : public OpsTestBase {
RunAndExpectResult<T, OutT>(op_name, lhs_shape, empty_input, rhs_shape,
empty_input, expected_shape, expected_output,
use_constraint);
config);
}
private:
@ -309,60 +315,60 @@ class GpuBinaryOpTest : public OpsTestBase {
// define your own test fixtures.
#define GENERATE_DEFAULT_TESTS_2(op_name, test_name, T, BaselineT, OutT, \
BaselineOutT, baseline_callback, \
use_constraint) \
BaselineOutT, baseline_callback, config) \
TEST_F(GpuBinaryOpTest, op_name##EqShapes##test_name) { \
TestEqualShapes<T, BaselineT, OutT, BaselineOutT>( \
#op_name, /*shape=*/test::DefaultInputShape(), \
/*lhs_input=*/test::DefaultInput<T>(#op_name), \
/*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
use_constraint); \
config); \
} \
\
TEST_F(GpuBinaryOpTest, op_name##OneScalar##test_name) { \
TestOneScalar<T, BaselineT, OutT, BaselineOutT>( \
#op_name, /*scalar_input=*/test::DefaultScalarInput<T>(), \
#op_name, /*scalar_input=*/test::DefaultInput<T>(#op_name).front(), \
/*other_shape=*/test::DefaultInputShape(), \
/*other_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
use_constraint); \
config); \
} \
\
TEST_F(GpuBinaryOpTest, op_name##IncompatibleShapes##test_name) { \
TestIncompatibleShapes<T, OutT>( \
#op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \
/*rhs_input=*/test::DefaultInput<T>(#op_name), use_constraint); \
/*rhs_input=*/test::DefaultInput<T>(#op_name), config); \
} \
\
TEST_F(GpuBinaryOpTest, op_name##BroadcastingExpand##test_name) { \
TestBroadcastingExpand<T, BaselineT, OutT, BaselineOutT>( \
#op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \
/*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
use_constraint); \
config); \
} \
\
TEST_F(GpuBinaryOpTest, op_name##BroadcastingInDim##test_name) { \
TestBroadcastingInDim<T, BaselineT, OutT, BaselineOutT>( \
#op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \
/*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
use_constraint); \
config); \
} \
\
TEST_F(GpuBinaryOpTest, op_name##Broadcasting##test_name) { \
TestBroadcasting<T, BaselineT, OutT, BaselineOutT>( \
#op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \
/*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
use_constraint); \
config); \
} \
\
TEST_F(GpuBinaryOpTest, op_name##EmptyShapeBroadcasting##test_name) { \
TestEmptyShapeBroadcasting<T, BaselineT, OutT, BaselineOutT>( \
#op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \
/*rhs_input=*/test::DefaultInput<T>(#op_name), use_constraint); \
/*rhs_input=*/test::DefaultInput<T>(#op_name), config); \
}
#define GENERATE_DEFAULT_TESTS(op_name, test_name, T, OutT, baseline_callback) \
GENERATE_DEFAULT_TESTS_2(op_name, test_name, T, T, OutT, OutT, \
baseline_callback, /*use_constraint=*/false)
baseline_callback, \
test::GpuOpsTestConfig().ExpectStrictlyEqual())
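For instance, under these definitions an instantiation such as GENERATE_DEFAULT_TESTS(Div, /*test_name=*/Float, float, float, baseline_div) produces one TEST_F per helper above; the EqShapes case, written out by hand, would look roughly like this (a sketch of the macro expansion, not code that appears in the file):
TEST_F(GpuBinaryOpTest, DivEqShapesFloat) {
  TestEqualShapes<float, float, float, float>(
      "Div", /*shape=*/test::DefaultInputShape(),
      /*lhs_input=*/test::DefaultInput<float>("Div"),
      /*rhs_input=*/test::DefaultInput<float>("Div"), baseline_div,
      test::GpuOpsTestConfig().ExpectStrictlyEqual());
}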
#define GENERATE_DEFAULT_TESTS_SAME_INPUT_AND_OUTPUT_TYPE( \
op_name, test_name, T, baseline_callback) \
@ -433,37 +439,23 @@ GENERATE_DEFAULT_TESTS(BitwiseXor,
GENERATE_DEFAULT_TESTS(BitwiseXor,
/*test_name=*/Int64, int64, int64, baseline_bitwise_xor)
/// Test `tf.LeftShift`.
/// Test `tf.Div`.
template <typename T>
T baseline_left_shift(T lhs, T rhs) {
return lhs << rhs;
T baseline_div(T lhs, T rhs) {
return lhs / rhs;
}
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int8, int8, int8,
baseline_left_shift)
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int16, int16, int16,
baseline_left_shift)
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int32, int32, int32,
baseline_left_shift)
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int64, int64, int64,
baseline_left_shift)
/// Test `tf.RightShift`.
template <typename T>
T baseline_right_shift(T lhs, T rhs) {
return lhs >> rhs;
}
GENERATE_DEFAULT_TESTS(RightShift,
/*test_name=*/Int8, int8, int8, baseline_right_shift)
GENERATE_DEFAULT_TESTS(RightShift,
/*test_name=*/Int16, int16, int16, baseline_right_shift)
GENERATE_DEFAULT_TESTS(RightShift,
/*test_name=*/Int32, int32, int32, baseline_right_shift)
GENERATE_DEFAULT_TESTS(RightShift,
/*test_name=*/Int64, int64, int64, baseline_right_shift)
GENERATE_DEFAULT_TESTS(Div,
/*test_name=*/Half, Eigen::half, Eigen::half,
baseline_div);
GENERATE_DEFAULT_TESTS(Div,
/*test_name=*/Float, float, float, baseline_div);
GENERATE_DEFAULT_TESTS(Div,
/*test_name=*/Double, double, double, baseline_div);
GENERATE_DEFAULT_TESTS(Div,
/*test_name=*/Int16, int16, int16, baseline_div);
GENERATE_DEFAULT_TESTS(Div,
/*test_name=*/Int64, int64, int64, baseline_div);
/// Test `tf.Equal`.
@ -482,27 +474,25 @@ GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int8, int8, bool, baseline_equal)
GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int16, int16, bool, baseline_equal)
GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int64, int64, bool, baseline_equal)
/// Test `tf.NotEqual`.
/// Test `tf.FloorDiv`.
template <typename T>
bool baseline_not_equal(T lhs, T rhs) {
return lhs != rhs;
T baseline_floor_div(T lhs, T rhs) {
return std::floor(lhs / rhs);
}
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Half, Eigen::half, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Float, float, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Double, double, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Bool, bool, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int8, int8, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int16, int16, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int64, int64, bool,
baseline_not_equal)
template <>
Eigen::half baseline_floor_div(Eigen::half lhs, Eigen::half rhs) {
return static_cast<Eigen::half>(std::floor(static_cast<float>(lhs / rhs)));
}
GENERATE_DEFAULT_TESTS(FloorDiv,
/*test_name=*/Half, Eigen::half, Eigen::half,
baseline_floor_div)
GENERATE_DEFAULT_TESTS(FloorDiv,
/*test_name=*/Float, float, float, baseline_floor_div)
GENERATE_DEFAULT_TESTS(FloorDiv,
/*test_name=*/Double, double, double, baseline_floor_div)
/// Test `tf.Greater`.
@ -544,6 +534,22 @@ GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Int16, int16, bool,
GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Int64, int64, bool,
baseline_greater_equal)
/// Test `tf.LeftShift`.
template <typename T>
T baseline_left_shift(T lhs, T rhs) {
return lhs << rhs;
}
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int8, int8, int8,
baseline_left_shift)
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int16, int16, int16,
baseline_left_shift)
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int32, int32, int32,
baseline_left_shift)
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int64, int64, int64,
baseline_left_shift)
/// Test `tf.Less`.
template <typename T>
@ -586,7 +592,7 @@ bool baseline_logical_and(bool lhs, bool rhs) { return lhs && rhs; }
GENERATE_DEFAULT_TESTS_2(LogicalAnd, /*test_name=*/Bool, /*T=*/bool,
/*BaselineT=*/bool, /*OutT=*/bool,
/*BaselineOutT=*/bool, baseline_logical_and,
/*use_constraint=*/false)
test::GpuOpsTestConfig().ExpectStrictlyEqual().NoT())
/// Test `tf.LogicalOr`.
@ -595,27 +601,77 @@ bool baseline_logical_or(bool lhs, bool rhs) { return lhs || rhs; }
GENERATE_DEFAULT_TESTS_2(LogicalOr, /*test_name=*/Bool, /*T=*/bool,
/*BaselineT=*/bool, /*OutT=*/bool,
/*BaselineOutT=*/bool, baseline_logical_or,
/*use_constraint=*/false)
test::GpuOpsTestConfig().ExpectStrictlyEqual().NoT())
/// Test `tf.Mul`.
/// Test `tf.FloorDiv`.
template <typename T>
T baseline_floor_div(T lhs, T rhs) {
return std::floor(lhs / rhs);
T baseline_mul(T lhs, T rhs) {
return lhs * rhs;
}
template <>
Eigen::half baseline_floor_div(Eigen::half lhs, Eigen::half rhs) {
return static_cast<Eigen::half>(std::floor(static_cast<float>(lhs / rhs)));
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Half, Eigen::half, Eigen::half,
baseline_mul)
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Float, float, float, baseline_mul)
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Double, double, double, baseline_mul)
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Int8, int8, int8, baseline_mul)
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Int16, int16, int16, baseline_mul)
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Int64, int64, int64, baseline_mul)
/// Test `tf.NotEqual`.
template <typename T>
bool baseline_not_equal(T lhs, T rhs) {
return lhs != rhs;
}
GENERATE_DEFAULT_TESTS(FloorDiv,
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Half, Eigen::half, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Float, float, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Double, double, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Bool, bool, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int8, int8, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int16, int16, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int64, int64, bool,
baseline_not_equal)
/// Test `tf.RightShift`.
template <typename T>
T baseline_right_shift(T lhs, T rhs) {
return lhs >> rhs;
}
GENERATE_DEFAULT_TESTS(RightShift,
/*test_name=*/Int8, int8, int8, baseline_right_shift)
GENERATE_DEFAULT_TESTS(RightShift,
/*test_name=*/Int16, int16, int16, baseline_right_shift)
GENERATE_DEFAULT_TESTS(RightShift,
/*test_name=*/Int32, int32, int32, baseline_right_shift)
GENERATE_DEFAULT_TESTS(RightShift,
/*test_name=*/Int64, int64, int64, baseline_right_shift)
/// Test `tf.Sub`.
template <typename T>
T baseline_sub(T lhs, T rhs) {
return lhs - rhs;
}
GENERATE_DEFAULT_TESTS(Sub,
/*test_name=*/Half, Eigen::half, Eigen::half,
baseline_floor_div);
GENERATE_DEFAULT_TESTS(FloorDiv,
/*test_name=*/Float, float, float, baseline_floor_div);
GENERATE_DEFAULT_TESTS(FloorDiv,
/*test_name=*/Double, double, double,
baseline_floor_div);
baseline_sub)
GENERATE_DEFAULT_TESTS(Sub,
/*test_name=*/Float, float, float, baseline_sub)
GENERATE_DEFAULT_TESTS(Sub,
/*test_name=*/Double, double, double, baseline_sub)
GENERATE_DEFAULT_TESTS(Sub,
/*test_name=*/Int64, int64, int64, baseline_sub)
} // namespace
} // end namespace tensorflow


@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, i64, DT_INT64, int64);
} // namespace tensorflow


@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, i64, DT_INT64, int64);
} // namespace tensorflow


@ -57,6 +57,7 @@ TensorShape DefaultInputShape();
struct GpuOpsTestConfig {
bool add_t = true;
bool add_tout = false;
// Only used for gpu_unary_ops_test.
bool expect_buffer_reuse = true;
bool expect_strictly_equal = false;
GpuOpsTestConfig ExpectStrictlyEqual() {
@ -119,33 +120,10 @@ absl::InlinedVector<T, 10> DefaultInputGreaterOrEqualToZero() {
/// Helper functions to get default input data.
template <typename T,
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
bool> = true>
T DefaultScalarInput() {
return static_cast<T>(3);
}
template <typename T, std::enable_if_t<
llvm::is_one_of<T, Eigen::half, float, double>::value,
bool> = true>
T DefaultScalarInput() {
return static_cast<T>(2.0);
}
template <typename T,
std::enable_if_t<llvm::is_one_of<T, bool>::value, bool> = true>
T DefaultScalarInput() {
return static_cast<T>(true);
}
template <typename T,
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
bool> = true>
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
if (op_name == "Abs") {
return NearZeroAndExtremeInput<T>();
}
// Only generate values less than the bitwidth of the data type.
if (op_name == "LeftShift" || op_name == "RightShift") {
auto max_shift = sizeof(T) * 8 - 1;
@ -153,6 +131,9 @@ absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
for (auto i = 0; i < max_shift; ++i) v.push_back(i);
return v;
}
if (op_name == "Div") {
return InputAsVector<T, int>({-18, -9, 9, 18});
}
return InputAsVector<T, int>({-18, -9, -1, 0, 0, 1, 1, 2, 3, 5, 7, 9, 9, 18});
}
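To make the bitwidth guard above concrete (the values follow directly from the loop as written; this note is illustrative, not part of the source):
// For T = int8, max_shift = sizeof(int8) * 8 - 1 = 7, so
// test::DefaultInput<int8>("LeftShift") yields {0, 1, 2, 3, 4, 5, 6},
// while test::DefaultInput<int8>("Div") yields {-18, -9, 9, 18}.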
@ -160,16 +141,7 @@ template <typename T, std::enable_if_t<
llvm::is_one_of<T, Eigen::half, float, double>::value,
bool> = true>
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
if (op_name == "Abs") {
return NearZeroAndExtremeInput<T>();
}
if (op_name == "Log" || op_name == "Rsqrt") {
return DefaultInputGreaterThanZero<T>();
}
if (op_name == "Sqrt") {
return DefaultInputGreaterOrEqualToZero<T>();
}
if (op_name == "FloorDiv") {
if (op_name == "Div" || op_name == "FloorDiv") {
return InputAsVector<T, double>({-18.0, -9.0, -1e-6, -0.1, 0.1, 1e-6, 0.1,
0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
}


@ -125,10 +125,22 @@ class GpuUnaryOpTest : public OpsTestBase {
// define your own test fixtures.
#define GENERATE_DEFAULT_TEST(op_name, InT, OutT, baseline_callback, config) \
GENERATE_DEFAULT_TEST2(op_name, InT, InT, OutT, OutT, baseline_callback, \
GENERATE_DEFAULT_TEST_2(op_name, InT, InT, OutT, OutT, baseline_callback, \
config)
#define GENERATE_DEFAULT_TEST2(op_name, InT, BaselineT, OutT, BaselineOutT, \
#define GENERATE_DEFAULT_TEST_2(op_name, InT, BaselineT, OutT, BaselineOutT, \
baseline_callback, config) \
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \
op_name, InT, BaselineT, OutT, BaselineOutT, \
test::DefaultInput<NativeT>(#op_name), baseline_callback, config)
#define GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES( \
op_name, InT, OutT, input_values, baseline_callback, config) \
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \
op_name, InT, InT, OutT, OutT, input_values, baseline_callback, config)
#define GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \
op_name, InT, BaselineT, OutT, BaselineOutT, input_values, \
baseline_callback, config) \
TEST_F(GpuUnaryOpTest, op_name##InT) { \
using NativeT = EnumToDataType<InT>::Type; \
@ -136,25 +148,31 @@ class GpuUnaryOpTest : public OpsTestBase {
using NativeOutT = EnumToDataType<OutT>::Type; \
using NativeBaselineOutT = EnumToDataType<BaselineOutT>::Type; \
Test<NativeT, NativeBaselineT, NativeOutT, NativeBaselineOutT>( \
#op_name, test::DefaultInputShape(), \
test::DefaultInput<NativeT>(#op_name), baseline_callback, config); \
#op_name, test::DefaultInputShape(), input_values, baseline_callback, \
config); \
}
/// Test `tf.Abs`.
GENERATE_DEFAULT_TEST(Abs, DT_FLOAT, DT_FLOAT, std::abs,
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
Abs, DT_FLOAT, DT_FLOAT, test::NearZeroAndExtremeInput<float>(), std::abs,
test::GpuOpsTestConfig().ExpectStrictlyEqual())
GENERATE_DEFAULT_TEST(Abs, DT_DOUBLE, DT_DOUBLE, std::abs,
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
Abs, DT_DOUBLE, DT_DOUBLE, test::NearZeroAndExtremeInput<double>(),
std::abs, test::GpuOpsTestConfig().ExpectStrictlyEqual())
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
Abs, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
test::NearZeroAndExtremeInput<Eigen::half>(), std::abs,
test::GpuOpsTestConfig().ExpectStrictlyEqual())
GENERATE_DEFAULT_TEST2(Abs, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::abs,
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
Abs, DT_INT32, DT_INT32, test::NearZeroAndExtremeInput<int32>(), std::abs,
test::GpuOpsTestConfig().ExpectStrictlyEqual())
GENERATE_DEFAULT_TEST(Abs, DT_INT32, DT_INT32, std::abs,
test::GpuOpsTestConfig().ExpectStrictlyEqual())
GENERATE_DEFAULT_TEST(Abs, DT_INT64, DT_INT64, std::abs,
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
Abs, DT_INT64, DT_INT64, test::NearZeroAndExtremeInput<int64>(), std::abs,
test::GpuOpsTestConfig().ExpectStrictlyEqual())
/// Test `tf.Ceil`.
@ -165,7 +183,7 @@ GENERATE_DEFAULT_TEST(Ceil, DT_FLOAT, DT_FLOAT, std::ceil,
GENERATE_DEFAULT_TEST(Ceil, DT_DOUBLE, DT_DOUBLE, std::ceil,
test::GpuOpsTestConfig().ExpectStrictlyEqual())
GENERATE_DEFAULT_TEST2(Ceil, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::ceil,
GENERATE_DEFAULT_TEST_2(Ceil, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::ceil,
test::GpuOpsTestConfig().ExpectStrictlyEqual())
/// Test `tf.Conj`.
@ -189,7 +207,7 @@ GENERATE_DEFAULT_TEST(Cos, DT_FLOAT, DT_FLOAT, std::cos,
GENERATE_DEFAULT_TEST(Cos, DT_DOUBLE, DT_DOUBLE, std::cos,
test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST2(Cos, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::cos,
GENERATE_DEFAULT_TEST_2(Cos, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::cos,
test::GpuOpsTestConfig())
/// Test `tf.Exp`.
@ -200,7 +218,7 @@ GENERATE_DEFAULT_TEST(Exp, DT_FLOAT, DT_FLOAT, std::exp,
GENERATE_DEFAULT_TEST(Exp, DT_DOUBLE, DT_DOUBLE, std::exp,
test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST2(Exp, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::exp,
GENERATE_DEFAULT_TEST_2(Exp, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::exp,
test::GpuOpsTestConfig())
/// Test `tf.Floor`.
@ -211,7 +229,7 @@ GENERATE_DEFAULT_TEST(Floor, DT_FLOAT, DT_FLOAT, std::floor,
GENERATE_DEFAULT_TEST(Floor, DT_DOUBLE, DT_DOUBLE, std::floor,
test::GpuOpsTestConfig().ExpectStrictlyEqual())
GENERATE_DEFAULT_TEST2(Floor, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::floor,
GENERATE_DEFAULT_TEST_2(Floor, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::floor,
test::GpuOpsTestConfig().ExpectStrictlyEqual())
/// Test `tf.Imag`.
@ -260,13 +278,17 @@ TEST_F(GpuUnaryOpTest, DISABLED_IsInfHalf) {
/// Test `tf.Log`.
GENERATE_DEFAULT_TEST(Log, DT_FLOAT, DT_FLOAT, std::log,
test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
Log, DT_FLOAT, DT_FLOAT, test::DefaultInputGreaterThanZero<float>(),
std::log, test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST(Log, DT_DOUBLE, DT_DOUBLE, std::log,
test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
Log, DT_DOUBLE, DT_DOUBLE, test::DefaultInputGreaterThanZero<double>(),
std::log, test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST2(Log, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::log,
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
Log, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
test::DefaultInputGreaterThanZero<Eigen::half>(), std::log,
test::GpuOpsTestConfig())
/// Test `tf.LogicalNot`
@ -290,7 +312,7 @@ GENERATE_DEFAULT_TEST(Neg, DT_FLOAT, DT_FLOAT, baseline_neg,
GENERATE_DEFAULT_TEST(Neg, DT_DOUBLE, DT_DOUBLE, baseline_neg,
test::GpuOpsTestConfig().ExpectStrictlyEqual())
GENERATE_DEFAULT_TEST2(Neg, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, baseline_neg,
GENERATE_DEFAULT_TEST_2(Neg, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, baseline_neg,
test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST(Neg, DT_INT8, DT_INT8, baseline_neg,
@ -323,15 +345,19 @@ T baseline_rsqrt(T x) {
return 1.0 / std::sqrt(x);
}
GENERATE_DEFAULT_TEST(Rsqrt, DT_FLOAT, DT_FLOAT, baseline_rsqrt,
test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST(Rsqrt, DT_DOUBLE, DT_DOUBLE, baseline_rsqrt,
test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST2(Rsqrt, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
Rsqrt, DT_FLOAT, DT_FLOAT, test::DefaultInputGreaterThanZero<float>(),
baseline_rsqrt, test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
Rsqrt, DT_DOUBLE, DT_DOUBLE, test::DefaultInputGreaterThanZero<double>(),
baseline_rsqrt, test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
Rsqrt, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
test::DefaultInputGreaterThanZero<Eigen::half>(), baseline_rsqrt,
test::GpuOpsTestConfig())
/// Test `tf.Sign`.
// Reference implementation
@ -350,7 +376,7 @@ GENERATE_DEFAULT_TEST(Sign, DT_DOUBLE, DT_DOUBLE, baseline_sign,
// TODO(b/162577610): We should actually use ExpectStrictlyEqual()
// here. This requires returning 0.0 for input -0.0.
GENERATE_DEFAULT_TEST2(Sign, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
GENERATE_DEFAULT_TEST_2(Sign, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
baseline_sign, test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST(Sign, DT_INT64, DT_INT64, baseline_sign,
@ -364,18 +390,23 @@ GENERATE_DEFAULT_TEST(Sin, DT_FLOAT, DT_FLOAT, std::sin,
GENERATE_DEFAULT_TEST(Sin, DT_DOUBLE, DT_DOUBLE, std::sin,
test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST2(Sin, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::sin,
GENERATE_DEFAULT_TEST_2(Sin, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::sin,
test::GpuOpsTestConfig())
/// Test `tf.Sqrt`.
GENERATE_DEFAULT_TEST(Sqrt, DT_FLOAT, DT_FLOAT, std::sqrt,
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
Sqrt, DT_FLOAT, DT_FLOAT, test::DefaultInputGreaterOrEqualToZero<float>(),
std::sqrt, test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
Sqrt, DT_DOUBLE, DT_DOUBLE,
test::DefaultInputGreaterOrEqualToZero<double>(), std::sqrt,
test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST(Sqrt, DT_DOUBLE, DT_DOUBLE, std::sqrt,
test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST2(Sqrt, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::sqrt,
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
Sqrt, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
test::DefaultInputGreaterOrEqualToZero<Eigen::half>(), std::sqrt,
test::GpuOpsTestConfig())
/// Test `tf.Tanh`.
@ -386,7 +417,7 @@ GENERATE_DEFAULT_TEST(Tanh, DT_FLOAT, DT_FLOAT, std::tanh,
GENERATE_DEFAULT_TEST(Tanh, DT_DOUBLE, DT_DOUBLE, std::tanh,
test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST2(Tanh, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::tanh,
GENERATE_DEFAULT_TEST_2(Tanh, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::tanh,
test::GpuOpsTestConfig())
} // namespace


@ -0,0 +1,6 @@
func @Div_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Div"(%arg0, %arg1) {T = elem_type, device = ""}
: (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>
}


@ -0,0 +1,6 @@
func @Sub_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Sub"(%arg0, %arg1) {T = elem_type, device = ""}
: (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>
}


@ -137,4 +137,14 @@ REGISTER_KERNEL_BUILDER(
Name("QuantizedMaxPool").Device(DEVICE_CPU).TypeConstraint<quint8>("T"),
QuantizedMaxPoolingOp<CPUDevice, quint8>);
#ifdef INTEL_MKL
REGISTER_KERNEL_BUILDER(
Name("QuantizedAvgPool").Device(DEVICE_CPU).TypeConstraint<qint8>("T"),
QuantizedAvgPoolingOp<CPUDevice, qint8>);
REGISTER_KERNEL_BUILDER(
Name("QuantizedMaxPool").Device(DEVICE_CPU).TypeConstraint<qint8>("T"),
QuantizedMaxPoolingOp<CPUDevice, qint8>);
#endif
} // namespace tensorflow


@ -232,11 +232,7 @@ REGISTER_OP("_FusedBatchNormEx")
.Output("reserve_space_1: U")
.Output("reserve_space_2: U")
.Output("reserve_space_3: U")
#ifdef ENABLE_MKLDNN_V1
.Attr("T: {half, float, bfloat16}")
#else
.Attr("T: {half, float}")
#endif
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
.Attr("exponential_avg_factor: float = 1.0")


@ -981,6 +981,17 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "enable_tf2_utils",
srcs = ["enable_tf2_utils.cc"],
hdrs = ["enable_tf2_utils.h"],
copts = tf_copts(),
deps = [
"//tensorflow/core/util:env_var",
],
alwayslink = 1,
)
alias(
name = "profile_utils_cpu_utils",
actual = "//tensorflow/core/platform/profile_utils:profile_utils_cpu_utils",
@ -992,6 +1003,12 @@ filegroup(
compatible_with = get_compatible_with_portable(),
)
filegroup(
name = "enable_tf2_hdr",
srcs = ["enable_tf2_utils.h"],
compatible_with = get_compatible_with_portable(),
)
tf_cc_tests(
name = "low_level_library_tests",
size = "small",
@ -1047,6 +1064,20 @@ tf_cc_test(
],
)
tf_cc_test(
name = "enable_tf2_utils_test",
size = "small",
srcs = [
"enable_tf2_utils_test.cc",
],
deps = [
":enable_tf2_utils",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/util:env_var",
],
)
tf_cc_tests(
name = "stacktrace_handler_test",
size = "small",


@ -149,9 +149,7 @@ class PosixEnv : public Env {
return true;
}
}
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
return false;
#else
#if defined(__GLIBC__) || defined(__FreeBSD__)
char buf[100];
#ifdef __FreeBSD__
int res = 0;
@ -164,6 +162,8 @@ class PosixEnv : public Env {
}
*name = buf;
return true;
#else
return false;
#endif
}


@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/enable_tf2_utils.h"
#include <atomic>
#include "tensorflow/core/util/env_var.h"
namespace tensorflow {
enum Enablement : uint8 { kFalse = 0, kTrue = 1, undefined = 2 };
// If this flag is set, we will use it as a signal to decide whether to
// use the MLIR-based TF-XLA bridge.
static std::atomic<Enablement> tf2_enabled{undefined};
// Determine whether the user has explicitly asked for tf2 execution.
// Will be used to decide whether to use the MLIR-based bridge.
void set_tf2_execution(bool enabled) {
tf2_enabled = (enabled) ? Enablement::kTrue : Enablement::kFalse;
}
bool tf2_execution_enabled() {
if (tf2_enabled == Enablement::undefined) {
static bool tf2_behavior_env_enabled = [] {
string tf2_env;
TF_CHECK_OK(ReadStringFromEnvVar("TF2_BEHAVIOR", "0", &tf2_env));
return tf2_env != "0";
}();
tf2_enabled =
(tf2_behavior_env_enabled) ? Enablement::kTrue : Enablement::kFalse;
}
return tf2_enabled;
}
} // namespace tensorflow
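A hedged usage sketch of the two entry points defined above (the call sites below are hypothetical and only illustrate the intended call order):
#include "tensorflow/core/platform/enable_tf2_utils.h"
// Record the user's choice once at startup ...
void ConfigureForTf2(bool user_requested_tf2) {
  tensorflow::set_tf2_execution(user_requested_tf2);
}
// ... and query it later, e.g. when deciding whether to take the MLIR-based
// bridge path. If set_tf2_execution() was never called, the query falls back
// to the TF2_BEHAVIOR environment variable.
bool ShouldPreferTf2Path() {
  return tensorflow::tf2_execution_enabled();
}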


@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TF_CORE_PLATFORM_TF2_UTILS_H_
#define TF_CORE_PLATFORM_TF2_UTILS_H_
namespace tensorflow {
// Sets the tf2 execution state. This can be used to indicate whether the user
// has explicitly asked for tf2 execution.
void set_tf2_execution(bool enabled);
// Returns true or false depending on whether the user flag for tf2 execution
// has been set. The default is false.
bool tf2_execution_enabled();
} // namespace tensorflow
#endif // TF_CORE_PLATFORM_TF2_UTILS_H_


@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Testing TF2 enablement.
#include "tensorflow/core/platform/enable_tf2_utils.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/env_var.h"
namespace tensorflow {
TEST(TF2EnabledTest, enabled_behavior) {
string tf2_env;
TF_CHECK_OK(ReadStringFromEnvVar("TF2_BEHAVIOR", "0", &tf2_env));
bool expected = (tf2_env != "0");
EXPECT_EQ(tensorflow::tf2_execution_enabled(), expected);
tensorflow::set_tf2_execution(true);
EXPECT_TRUE(tensorflow::tf2_execution_enabled());
tensorflow::set_tf2_execution(false);
EXPECT_FALSE(tensorflow::tf2_execution_enabled());
}
} // namespace tensorflow


@ -108,7 +108,7 @@ limitations under the License.
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
#define TF_GRAPH_DEF_VERSION 634 // Updated: 2021/1/2
#define TF_GRAPH_DEF_VERSION 637 // Updated: 2021/1/5
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
//


@ -386,6 +386,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
return ParseSoftmax(op, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_SPACE_TO_DEPTH: {
return ParseSpaceToDepth(op, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_SPLIT: {
return ParseSplit(op, error_reporter, allocator, builtin_data);
}
@ -596,16 +600,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_SPACE_TO_DEPTH: {
auto params = safe_allocator.Allocate<TfLiteSpaceToDepthParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* schema_params =
op->builtin_options_as_SpaceToDepthOptions()) {
params->block_size = schema_params->block_size();
}
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_GATHER: {
auto params = safe_allocator.Allocate<TfLiteGatherParams>();
@ -1684,6 +1678,31 @@ TfLiteStatus ParseSoftmax(const Operator* op, ErrorReporter* error_reporter,
return kTfLiteOk;
}
TfLiteStatus ParseSpaceToDepth(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data) {
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
SafeBuiltinDataAllocator safe_allocator(allocator);
std::unique_ptr<TfLiteSpaceToDepthParams,
SafeBuiltinDataAllocator::BuiltinDataDeleter>
params = safe_allocator.Allocate<TfLiteSpaceToDepthParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
const auto* schema_params = op->builtin_options_as_SpaceToDepthOptions();
if (schema_params != nullptr) {
params->block_size = schema_params->block_size();
} else {
// TODO(b/157480169): We should either return kTfLiteError or fill in some
// reasonable defaults in the params struct. We are not doing so until we
    // better understand the ramifications of changing the legacy behavior.
}
*builtin_data = params.release();
return kTfLiteOk;
}
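
The new parser follows the same allocate/copy/release contract as its neighbors. A standalone sketch of that contract with stand-in types (none of these names exist in TFLite; they only illustrate the ownership handoff through `builtin_data`):

```
#include <memory>

// Stand-ins for TfLiteSpaceToDepthParams and SpaceToDepthOptions.
struct FakeParams { int block_size = 0; };
struct FakeOptions { int block_size; };

// On success the caller takes ownership of the released params pointer.
bool ParseFakeOp(const FakeOptions* schema_params, void** builtin_data) {
  auto params = std::make_unique<FakeParams>();
  if (schema_params != nullptr) {
    params->block_size = schema_params->block_size;
  }  // else: keep defaults, mirroring the legacy behavior noted in the TODO.
  *builtin_data = params.release();
  return true;
}
```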
TfLiteStatus ParseSplit(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data) {
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);

View File

@ -254,6 +254,11 @@ TfLiteStatus ParseSin(const Operator* op, ErrorReporter* error_reporter,
TfLiteStatus ParseSoftmax(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseSpaceToDepth(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseSplit(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);

View File

@ -111,11 +111,15 @@ class Subgraph {
inline TfLiteStatus SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name,
const std::vector<int>& dims, TfLiteQuantization quantization,
bool is_variable = false, const size_t rank_dims_signature = 0,
const int* dims_signature = nullptr) {
bool is_variable = false, const std::vector<int>& dims_signature = {}) {
if (dims_signature.empty()) {
return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
dims.data(), quantization, is_variable,
rank_dims_signature, dims_signature);
dims.data(), quantization,
is_variable);
}
return SetTensorParametersReadWrite(
tensor_index, type, name, dims.size(), dims.data(), quantization,
is_variable, dims_signature.size(), dims_signature.data());
}
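
Call sites now hand the shape signature over as a `std::vector<int>` instead of a rank plus raw pointer. A hedged call-site sketch (the helper, the tensor values, and the include path are assumptions for illustration, not code from this change):

```
#include <vector>
#include "tensorflow/lite/core/subgraph.h"

// Hypothetical helper: register a float input whose batch dimension is
// dynamic; a -1 entry in the signature marks the dynamic dimension.
TfLiteStatus RegisterDynamicBatchInput(tflite::Subgraph* subgraph,
                                       int tensor_index) {
  std::vector<int> dims = {1, 224, 224, 3};
  std::vector<int> dims_signature = {-1, 224, 224, 3};
  TfLiteQuantization quantization;
  quantization.type = kTfLiteNoQuantization;
  quantization.params = nullptr;
  return subgraph->SetTensorParametersReadWrite(
      tensor_index, kTfLiteFloat32, "input", dims, quantization,
      /*is_variable=*/false, dims_signature);
}
```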
TfLiteStatus SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name, const size_t rank,

View File

@ -45,6 +45,56 @@ int3 GetWorkGroupsCount(int grid_dimension, const int3& grid_size,
}
return work_groups_count;
}
std::string GetCommonOpenCLDefines(CalculationsPrecision precision) {
std::string result;
result += "#define GLOBAL_ID_0 get_global_id(0)\n";
result += "#define GLOBAL_ID_1 get_global_id(1)\n";
result += "#define GLOBAL_ID_2 get_global_id(2)\n";
result += "#define MAIN_FUNCTION __kernel void main_function\n";
switch (precision) {
case CalculationsPrecision::F32:
result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
result += "#define ACCUM_FLT4 float4\n";
result += "#define FLT float\n";
result += "#define FLT2 float2\n";
result += "#define FLT3 float3\n";
result += "#define FLT4 float4\n";
result += "#define TO_FLT4 convert_float4\n";
result += "#define TO_ACCUM_TYPE convert_float4\n";
result += "#define TO_ACCUM_FLT convert_float\n";
result += "#define INIT_FLT4(value) (float4)(value)\n";
break;
case CalculationsPrecision::F16:
result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
result += "#define ACCUM_FLT4 half4\n";
result += "#define FLT half\n";
result += "#define FLT2 half2\n";
result += "#define FLT3 half3\n";
result += "#define FLT4 half4\n";
result += "#define TO_FLT4 convert_half4\n";
result += "#define TO_ACCUM_TYPE convert_half4\n";
result += "#define TO_ACCUM_FLT convert_half\n";
result += "#define INIT_FLT4(value) (half4)(value)\n";
break;
case CalculationsPrecision::F32_F16:
result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
result += "#define ACCUM_FLT4 float4\n";
result += "#define FLT half\n";
result += "#define FLT2 half2\n";
result += "#define FLT3 half3\n";
result += "#define FLT4 half4\n";
result += "#define TO_FLT4 convert_half4\n";
result += "#define TO_ACCUM_TYPE convert_float4\n";
result += "#define TO_ACCUM_FLT convert_float\n";
result += "#define INIT_FLT4(value) (half4)(value)\n";
break;
}
return result;
}
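
These defines are prepended to every operation in `ClOperation::Compile` below, so kernels written against the macro layer compile unchanged for any precision. A rough sketch of the idea, with an illustrative kernel body (not a real operation):

```
#include <string>

// Under F32 the defines expand FLT4 to float4 and INIT_FLT4(v) to
// (float4)(v); under F16 they expand to half4 and (half4)(v).
const char* kIllustrativeBody = R"(
MAIN_FUNCTION($0) {
  int X = GLOBAL_ID_0;
  FLT4 acc = INIT_FLT4(0.0f);
  // ... operation-specific work ...
}
)";

std::string BuildIllustrativeSource(CalculationsPrecision precision) {
  return GetCommonOpenCLDefines(precision) + kIllustrativeBody;
}
```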
} // namespace
ClOperation::ClOperation(ClOperation&& operation)
@ -94,6 +144,9 @@ absl::Status ClOperation::UpdateParams() {
absl::Status ClOperation::Compile(const CreationContext& creation_context) {
operation_->AssembleCode(creation_context.GetGpuInfo());
operation_->code_ =
GetCommonOpenCLDefines(operation_->definition_.precision) +
operation_->code_;
RETURN_IF_ERROR(cl_args_.Init(
creation_context.GetGpuInfo(),
{{operation_->dst_tensors_names_[0], operation_->elementwise_code_}},

View File

@ -336,6 +336,8 @@ absl::Status Tensor::GetGPUResources(const GPUObjectDescriptor* obj_ptr,
if (!tensor_desc) {
return absl::InvalidArgumentError("Expected TensorDescriptor on input.");
}
resources->ints.push_back(
{"slice_stride", tensor_desc->GetSliceStrideSize(shape_)});
if (descriptor_.HasAxis(Axis::WIDTH)) {
resources->ints.push_back({"width", Width()});
resources->ints.push_back({"width_div2", Width() / 2});

View File

@ -22,61 +22,18 @@ limitations under the License.
namespace tflite {
namespace gpu {
namespace {
std::string GetCommonDefines(CalculationsPrecision precision) {
std::string result;
switch (precision) {
case CalculationsPrecision::F32:
result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
result += "#define ACCUM_FLT4 float4\n";
result += "#define FLT float\n";
result += "#define FLT2 float2\n";
result += "#define FLT3 float3\n";
result += "#define FLT4 float4\n";
result += "#define TO_FLT4 convert_float4\n";
result += "#define TO_ACCUM_TYPE convert_float4\n";
result += "#define TO_ACCUM_FLT convert_float\n";
break;
case CalculationsPrecision::F16:
result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
result += "#define ACCUM_FLT4 half4\n";
result += "#define FLT half\n";
result += "#define FLT2 half2\n";
result += "#define FLT3 half3\n";
result += "#define FLT4 half4\n";
result += "#define TO_FLT4 convert_half4\n";
result += "#define TO_ACCUM_TYPE convert_half4\n";
result += "#define TO_ACCUM_FLT convert_half\n";
break;
case CalculationsPrecision::F32_F16:
result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
result += "#define ACCUM_FLT4 float4\n";
result += "#define FLT half\n";
result += "#define FLT2 half2\n";
result += "#define FLT3 half3\n";
result += "#define FLT4 half4\n";
result += "#define TO_FLT4 convert_half4\n";
result += "#define TO_ACCUM_TYPE convert_float4\n";
result += "#define TO_ACCUM_FLT convert_float\n";
break;
}
return result;
}
std::string GetElementWiseCode(const OperationDef& op_def,
bool check_src_slices) {
std::string c;
c += "__kernel void main_function(\n";
c += "MAIN_FUNCTION(\n";
c += "$0) {\n";
c += " int X = get_global_id(0);\n";
c += " int Y = get_global_id(1);\n";
c += " int Z = get_global_id(2);\n";
c += " int X = GLOBAL_ID_0;\n";
c += " int Y = GLOBAL_ID_1;\n";
c += " int Z = GLOBAL_ID_2;\n";
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
"Z >= args.dst_tensor.Slices()) return; \n";
if (check_src_slices) {
c += " FLT4 src = (FLT4)(0.0f);\n";
c += " FLT4 src = INIT_FLT4(0.0f);\n";
c += " if (Z < args.src_tensor.Slices()) {\n";
c += " src = args.src_tensor.Read(X, Y, Z);\n";
c += " }\n";
@ -240,7 +197,6 @@ void GPUOperation::AssembleCode(const GpuInfo& gpu_info) {
elementwise_code_ = "{\n" + code_ + "\n}\n" + elementwise_code_;
code_ = GetElementWiseCode(definition_, check_src_channels_size_);
}
code_ = GetCommonDefines(definition_.precision) + code_;
}
void GPUOperation::GetPossibleKernelWorkGroups(

View File

@ -94,6 +94,7 @@ TensorDescriptor& TensorDescriptor::operator=(TensorDescriptor&& desc) {
GPUResources TensorDescriptor::GetGPUResources() const {
GPUResources resources;
resources.ints.push_back("slice_stride");
if (HasAxis(Axis::WIDTH)) {
resources.ints.push_back("width");
resources.ints.push_back("width_div2");
@ -175,7 +176,7 @@ absl::Status TensorDescriptor::PerformSelector(
*result = "slices";
return absl::OkStatus();
} else if (selector == "SliceStride") {
*result = GetSliceStride();
*result = "slice_stride";
return absl::OkStatus();
} else if (selector == "Channels") {
*result = "channels";
@ -402,7 +403,7 @@ absl::Status TensorDescriptor::PerformGetPtrWithSliceOffsetSelector(
"GetPtrWithSliceOffset require one argument(slice coordinate), but ",
args.size(), " was passed"));
}
*result = absl::StrCat("buffer + ", args[0], " * ", GetSliceStride());
*result = absl::StrCat("buffer + ", args[0], " * slice_stride");
return absl::OkStatus();
}
@ -644,6 +645,35 @@ bool TensorDescriptor::HasAxis(Axis axis) const {
return false;
}
int TensorDescriptor::GetWidthSize(BHWDC shape) const {
int width = shape.w;
auto it1 = state_vars_.find("ElementsX2");
if (it1 != state_vars_.end() && it1->second == "true") {
width /= 2;
}
auto it2 = state_vars_.find("ElementsX4");
if (it2 != state_vars_.end() && it2->second == "true") {
width /= 4;
}
auto it = state_vars_.find("BatchedWidth");
if (it != state_vars_.end() && it->second == "true") {
width *= shape.b;
}
return width;
}
int TensorDescriptor::GetSliceStrideSize(BHWDC shape) const {
if (IsBatchedWidth()) {
return GetWidthSize(shape) * shape.h;
} else {
if (HasAxis(Axis::BATCH)) {
return GetWidthSize(shape) * shape.h * shape.b;
} else {
return GetWidthSize(shape) * shape.h;
}
}
}
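
A worked example of the stride arithmetic above, with the descriptor state reduced to plain booleans (hypothetical helper, for illustration only):

```
// Mirrors GetSliceStrideSize: when the batch is folded into the width
// ("BatchedWidth"), GetWidthSize already includes it; otherwise the batch
// factor applies only if the tensor actually has a BATCH axis.
int SliceStrideElements(int width_size, int height, int batch,
                        bool batched_width, bool has_batch_axis) {
  if (batched_width) {
    return width_size * height;  // width_size is already width * batch here.
  }
  return has_batch_axis ? width_size * height * batch : width_size * height;
}

// e.g. width_size = 10, height = 8, batch = 2, with a BATCH axis that is not
// folded into width  ->  10 * 8 * 2 = 160 elements per slice.
```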
void TensorDescriptor::SetAddressMode(AddressMode mode) {
if (mode == AddressMode::kZero) {
state_vars_["TextureMode"] = "ZERO";
@ -719,18 +749,6 @@ std::string TensorDescriptor::GetWidth() const {
}
}
std::string TensorDescriptor::GetSliceStride() const {
if (IsBatchedWidth()) {
return GetWidth() + " * height";
} else {
if (HasAxis(Axis::BATCH)) {
return GetWidth() + " * height * batch";
} else {
return GetWidth() + " * height";
}
}
}
AddressMode TensorDescriptor::AddressModeFromState() const {
auto it = state_vars_.find("TextureMode");
if (it != state_vars_.end()) {

View File

@ -70,6 +70,8 @@ struct TensorDescriptor : public GPUObjectDescriptor {
bool HasAxis(Axis axis) const;
void SetAddressMode(AddressMode mode);
int GetWidthSize(BHWDC shape) const;
int GetSliceStrideSize(BHWDC shape) const;
absl::Status GetLinkingContextFromWriteSelector(
const std::vector<std::string>& args, std::string* value_name,
@ -136,7 +138,6 @@ struct TensorDescriptor : public GPUObjectDescriptor {
bool IsBatchedWidth() const;
std::string GetWidth() const;
std::string GetSliceStride() const;
AddressMode AddressModeFromState() const;

View File

@ -677,6 +677,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common:shape",
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/common/task:util",
"//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
],
)

View File

@ -32,57 +32,54 @@ namespace gpu {
namespace metal {
std::string GetResizeBilinearCode(const Resize2DAttributes& attr) {
std::string code = R"(
std::string c = R"(
#include <metal_stdlib>
using namespace metal;
$0
kernel void ComputeFunction(
$1
uint3 gid[[thread_position_in_grid]]) {
if (int(gid.x) >= size.z || int(gid.y) >= size.w) {
if (int(gid.x) >= args.dst_tensor.Width() || int(gid.y) >= args.dst_tensor.Height()) {
return;
})";
if (attr.half_pixel_centers) {
code += "const float2 tex_coord = (float2(gid.xy) + 0.5f) * scale - 0.5f;";
} else {
code += "const float2 tex_coord = float2(gid.xy) * scale;";
}
code += R"(
const float2 tex_coord_floor = floor(tex_coord);
const int2 itex_coord_floor = int2(tex_coord_floor);
const int2 borders = size.xy - int2(1, 1);
)";
if (attr.half_pixel_centers) {
c += " float2 tex_coord = (float2(gid.xy) + 0.5f) * scale - 0.5f;";
} else {
c += " float2 tex_coord = float2(gid.xy) * scale;";
}
c += R"(
float2 tex_coord_floor = floor(tex_coord);
int2 itex_coord_floor = int2(tex_coord_floor);
int2 borders = int2(args.src_tensor.Width() - 1, args.src_tensor.Height() - 1);
int4 st;
st.xy = max(itex_coord_floor, int2(0, 0));
st.zw = min(itex_coord_floor + int2(1, 1), borders);
const float2 t = tex_coord - tex_coord_floor; // interpolating factors
const int src_index0 = (gid.z * size.y + st.y) * size.x + st.x;
const int src_index1 = (gid.z * size.y + st.y) * size.x + st.z;
const int src_index2 = (gid.z * size.y + st.w) * size.x + st.x;
const int src_index3 = (gid.z * size.y + st.w) * size.x + st.z;
FLT4 tex11 = src_tensor[src_index0];
FLT4 tex21 = src_tensor[src_index1];
FLT4 tex12 = src_tensor[src_index2];
FLT4 tex22 = src_tensor[src_index3];
float2 t = tex_coord - tex_coord_floor; // interpolating factors
FLT4 tex11 = args.src_tensor.Read(st.x, st.y, gid.z);
FLT4 tex21 = args.src_tensor.Read(st.z, st.y, gid.z);
FLT4 tex12 = args.src_tensor.Read(st.x, st.w, gid.z);
FLT4 tex22 = args.src_tensor.Read(st.z, st.w, gid.z);
// bilinear interpolation
FLT4 value = mix(mix(tex11, tex21, static_cast<FLT>(t.x)),
mix(tex12, tex22, static_cast<FLT>(t.x)), static_cast<FLT>(t.y));
const int linear_index = (gid.z * size.w + gid.y) * size.z + gid.x;
args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, gid.z);
$2
dst_tensor[linear_index] = value;
args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
}
)";
return code;
return c;
}
std::string GetResizeNearestCode(const Resize2DAttributes& attr) {
std::string code = R"(
std::string c = R"(
#include <metal_stdlib>
using namespace metal;
$0
kernel void ComputeFunction(
$1
uint3 gid[[thread_position_in_grid]]) {
if (int(gid.x) >= size.z || int(gid.y) >= size.w) {
if (int(gid.x) >= args.dst_tensor.Width() || int(gid.y) >= args.dst_tensor.Height()) {
return;
}
)";
@ -99,27 +96,27 @@ std::string GetResizeNearestCode(const Resize2DAttributes& attr) {
fxc += " + 0.5f";
fyc += " + 0.5f";
}
code += " int2 coord;\n";
code += " coord.x = static_cast<int>(" + fxc + ");\n";
code += " coord.y = static_cast<int>(" + fyc + ");\n";
code += " coord.x = max(0, coord.x);\n";
code += " coord.y = max(0, coord.y);\n";
code += " coord.x = min(coord.x, size.x - 1);\n";
code += " coord.y = min(coord.y, size.y - 1);\n";
code += R"(
const int src_index = (gid.z * size.y + coord.y) * size.x + coord.x;
FLT4 value = src_tensor[src_index];
const int linear_index = (gid.z * size.w + gid.y) * size.z + gid.x;
c += " int2 coord;\n";
c += " coord.x = static_cast<int>(" + fxc + ");\n";
c += " coord.y = static_cast<int>(" + fyc + ");\n";
c += " coord.x = max(0, coord.x);\n";
c += " coord.y = max(0, coord.y);\n";
c += " coord.x = min(coord.x, args.src_tensor.Width() - 1);\n";
c += " coord.y = min(coord.y, args.src_tensor.Height() - 1);\n";
c += R"(
FLT4 value = args.src_tensor.Read(coord.x, coord.y, gid.z);
args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, gid.z);
$2
dst_tensor[linear_index] = value;
args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
}
)";
return code;
return c;
}
ComputeTaskDescriptor Resize(const OperationDef& definition,
const Resize2DAttributes& attr) {
ComputeTaskDescriptor desc(definition);
desc.tensors_as_args = true;
switch (attr.type) {
case SamplingType::BILINEAR:
desc.shader_source = GetResizeBilinearCode(attr);
@ -136,17 +133,6 @@ ComputeTaskDescriptor Resize(const OperationDef& definition,
desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
desc.uniform_buffers = {
{"constant int4& size",
[](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
std::vector<int> sizes = {
src_shapes[0].w,
src_shapes[0].h,
dst_shapes[0].w,
dst_shapes[0].h,
};
return GetByteBuffer(sizes);
}},
{"constant float2& scale",
[attr](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
@ -160,10 +146,10 @@ ComputeTaskDescriptor Resize(const OperationDef& definition,
desc.resize_function = [](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
const uint3 groups_size{16, 16, 1};
const uint3 groups_size{8, 8, 1};
const int dst_layers = DivideRoundUp(dst_shapes[0].c, 4);
int groups_x = DivideRoundUp(dst_shapes[0].w, groups_size.x);
int groups_y = DivideRoundUp(dst_shapes[0].h, groups_size.y);
const int dst_layers = DivideRoundUp(dst_shapes[0].c, 4);
int groups_z = DivideRoundUp(dst_layers, groups_size.z);
return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
};

View File

@ -35,123 +35,146 @@ namespace gpu {
namespace metal {
namespace {
std::string GetSliceCode(const SliceAttributes& attr) {
std::stringstream code;
namespace {
bool Is4Aligned(const SliceAttributes& attr) {
return attr.strides.c == 1 && attr.starts.c % 4 == 0;
}
code << R"(
int4 GetOffset(const SliceAttributes& attr, int src_width, int src_height,
int src_channels, int src_batch) {
int4 offset;
if (attr.strides.w > 0) {
offset.x = attr.starts.w;
} else {
if (attr.ends.w > 0) {
offset.x = attr.ends.w;
} else {
offset.x = src_width + attr.ends.w;
}
}
if (attr.strides.h > 0) {
offset.y = attr.starts.h;
} else {
if (attr.ends.h > 0) {
offset.y = attr.ends.h;
} else {
offset.y = src_height + attr.ends.h;
}
}
if (attr.strides.c > 0) {
offset.z = attr.starts.c;
} else {
if (attr.ends.c > 0) {
offset.z = attr.ends.c;
} else {
offset.z = src_channels + attr.ends.c;
}
}
if (Is4Aligned(attr)) {
offset.z /= 4;
}
if (attr.strides.b > 0) {
offset.w = attr.starts.b;
} else {
if (attr.ends.b > 0) {
offset.w = attr.ends.b;
} else {
offset.w = src_batch + attr.ends.b;
}
}
return offset;
}
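
GetOffset resolves a negative stride by anchoring on `ends`, falling back to the far edge of the source when `ends` is also non-positive. A small worked example with plain ints standing in for the attribute fields:

```
// For a source of width 10 with strides.w = -1 and ends.w = -1:
// strides.w <= 0 and ends.w <= 0, so offset.x = 10 + (-1) = 9,
// i.e. the slice walks backwards starting from the last column.
int OffsetX(int src_width, int starts_w, int strides_w, int ends_w) {
  if (strides_w > 0) return starts_w;
  return ends_w > 0 ? ends_w : src_width + ends_w;
}
```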
} // namespace
std::string GetSliceCode(const OperationDef& op_def, bool alignedx4) {
const std::string batch_id =
op_def.dst_tensors[0].HasAxis(Axis::BATCH) ? "B" : "0";
std::string c = R"(
#include <metal_stdlib>
using namespace metal;
struct uniforms {
int4 src_size;
int4 dst_size;
int4 offset;
int4 stride;
};
constant int4 width = int4($0, $1, $2, 0);
constant int4 height = int4($3, $4, $5, 0);
constant int4 channels = int4($6, $7, $8, 0);
constant FLT4 null_vec = FLT4(0.0f, 0.0f, 0.0f, 0.0f);
$$0
kernel void ComputeFunction(
$$1
$0
kernel void ComputeFunction($1
uint3 gid[[thread_position_in_grid]]) {
if (static_cast<int>(gid.x) >= params.dst_size.x ||
static_cast<int>(gid.y) >= params.dst_size.y) {
return;
}
FLT4 value;
short2 offset;
)";
if (attr.strides.w > 0) {
code << " offset.x = width.x;" << std::endl;
if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
c += " int linear_id = static_cast<int>(gid.x);\n";
c += " int X = linear_id / args.dst_tensor.Batch();\n";
c += " int B = linear_id % args.dst_tensor.Batch();\n";
c += " args.dst_tensor.SetBatchRef(B);\n";
} else {
if (attr.ends.w > 0) {
code << " offset.x = width.z;" << std::endl;
c += " int X = static_cast<int>(gid.x);\n";
}
c += " int Y = static_cast<int>(gid.y);\n";
c += " int Z = static_cast<int>(gid.z);\n";
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
"Z >= args.dst_tensor.Slices()) { \n";
c += " return; \n";
c += " } \n";
c += " int s_x = X * params.stride.x + params.offset.x;\n";
c += " int s_y = Y * params.stride.y + params.offset.y;\n";
if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) {
c += " int s_b = " + batch_id + " * params.stride.w + params.offset.w;\n";
c += " args.src_tensor.SetBatchRef(s_b);\n";
}
if (alignedx4) {
c += " int s_z = Z + params.offset.z;\n";
c += " FLT4 result = args.src_tensor.Read(s_x, s_y, s_z);\n";
} else {
code << " offset.x = params.src_size.x + width.z;" << std::endl;
c += " FLT4 result;\n";
const std::string postfixes[] = {"x", "y", "z", "w"};
for (int i = 0; i < 4; ++i) {
c += " {\n";
const std::string ch = "(Z * 4 + " + std::to_string(i) + ")";
c += " int s_ch = " + ch + " * params.stride.z + params.offset.z;\n";
c += " int s_z = min(s_ch >> 2, args.src_tensor.Slices() - 1);\n";
c += " int s_z_rem = s_ch & 3;\n";
c += " FLT4 t = args.src_tensor.Read(s_x, s_y, s_z);\n";
c += " result." + postfixes[i] + " = t[s_ch & 3];\n";
c += " }\n";
}
}
if (attr.strides.h > 0) {
code << " offset.y = height.x;" << std::endl;
} else {
if (attr.ends.h > 0) {
code << " offset.y = height.z;" << std::endl;
} else {
code << " offset.y = params.src_size.y + height.z;" << std::endl;
}
}
code << std::endl;
code << " short2 stride = short2(width.y, height.y);" << std::endl;
code << " const short2 s_c = offset + short2(gid.xy) * stride;"
<< std::endl;
code << " bool outside = false;" << std::endl;
code << " int step = gid.z * 4;" << std::endl;
code << " FLT4 tmp;" << std::endl;
code << " int buffer_index = 0;" << std::endl;
code << " int addr = 0;" << std::endl;
code << std::endl;
for (int i = 0; i < 4; i++) {
code << " addr = step * channels.y;" << std::endl;
if (attr.strides.c > 0) {
code << " addr += channels.x;" << std::endl;
} else {
if (attr.ends.c > 0) {
code << " addr += channels.z;" << std::endl;
} else {
code << " addr += params.src_size.z + channels.z;" << std::endl;
}
}
code << " buffer_index = ((addr / 4) * params.src_size.y + s_c.y) * "
"params.src_size.x + "
"s_c.x;"
<< std::endl;
code << " outside = step >= params.dst_size.z;" << std::endl;
code << " tmp = outside ? null_vec : src_tensor[buffer_index];"
<< std::endl;
code << " value[" << i << "] = tmp[addr % 4];" << std::endl;
if (i != 3) {
code << " step++;" << std::endl;
code << std::endl;
}
}
code << R"(
int linear_index = (gid.z * params.dst_size.y + int(gid.y)) *
params.dst_size.x + int(gid.x);
$$2
dst_tensor[linear_index] = value;
})";
return absl::Substitute(
code.str(), attr.starts.w, attr.strides.w, attr.ends.w, attr.starts.h,
attr.strides.h, attr.ends.h, attr.starts.c, attr.strides.c, attr.ends.c);
c += " FLT4 value = result;\n";
c += " args.dst_tensor.GetAddress(linear_index, X, Y, Z);\n";
c += " $2\n";
c += " args.dst_tensor.Write(value, X, Y, Z);\n";
c += "}\n";
return c;
}
} // namespace
ComputeTaskDescriptor Slice(const OperationDef& definition,
const SliceAttributes& attr) {
ComputeTaskDescriptor desc(definition);
desc.shader_source = GetSliceCode(attr);
desc.tensors_as_args = true;
desc.shader_source = GetSliceCode(definition, Is4Aligned(attr));
desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
desc.uniform_buffers = {
{"constant uniforms& params",
[](const std::vector<BHWC>& src_shapes,
[attr](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
int4 offset = GetOffset(attr, src_shapes[0].w, src_shapes[0].h,
src_shapes[0].c, src_shapes[0].b);
std::vector<int> uniform_params{
// int4 src_size
src_shapes[0].w,
src_shapes[0].h,
src_shapes[0].c,
DivideRoundUp(src_shapes[0].c, 4),
// int4 dst_size
dst_shapes[0].w,
dst_shapes[0].h,
dst_shapes[0].c,
DivideRoundUp(dst_shapes[0].c, 4),
// int4 offset
offset.x,
offset.y,
offset.z,
offset.w,
// int4 stride
attr.strides.w,
attr.strides.h,
attr.strides.c,
attr.strides.b,
};
return GetByteBuffer(uniform_params);
}},
@ -159,10 +182,10 @@ ComputeTaskDescriptor Slice(const OperationDef& definition,
desc.resize_function = [attr](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
const uint3 groups_size{16, 16, 1};
const uint3 groups_size{8, 4, 1};
const int dst_layers = DivideRoundUp(dst_shapes[0].c, 4);
int groups_x = DivideRoundUp(dst_shapes[0].w, groups_size.x);
int groups_y = DivideRoundUp(dst_shapes[0].h, groups_size.y);
const int dst_layers = DivideRoundUp(dst_shapes[0].c, 4);
int groups_z = DivideRoundUp(dst_layers, groups_size.z);
return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
};

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/task/util.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
@ -40,7 +41,6 @@ std::string GetSoftmax1x1Code(const GpuInfo& gpu_info) {
using namespace metal;
struct uniforms {
int4 size;
float4 mask;
};
@ -51,11 +51,11 @@ kernel void ComputeFunction($1
uint3 ugid[[thread_position_in_grid]])
{
float4 maxx4 = float4(src_tensor[0].x);
for (int s = int(tid); s < params.size.x; s += 32) {
float4 mask_a = s == params.size.x - 1 ? params.mask : float4(1.0f);
float4 maxx4 = float4(args.src_tensor.Read(0, 0, 0).x);
for (int s = int(tid); s < args.src_tensor.Slices(); s += 32) {
float4 mask_a = s == args.src_tensor.Slices() - 1 ? params.mask : float4(1.0f);
float4 mask_b = float4(1.0f) - mask_a;
float4 src = float4(src_tensor[s]);
float4 src = float4(args.src_tensor.Read(0, 0, s));
src = src * mask_a + mask_b * src.x;
maxx4 = max(maxx4, src);
}
@ -89,9 +89,9 @@ kernel void ComputeFunction($1
maximum = tmpx1[0];
float sum = 0.0f;
for (int s = int(tid); s < params.size.x; s += 32) {
float4 mask_temp = s == params.size.x - 1 ? params.mask : float4(1.0f);
float4 src = float4(src_tensor[s]) - float4(maximum);
for (int s = int(tid); s < args.src_tensor.Slices(); s += 32) {
float4 mask_temp = s == args.src_tensor.Slices() - 1 ? params.mask : float4(1.0f);
float4 src = float4(args.src_tensor.Read(0, 0, s)) - float4(maximum);
sum += dot(mask_temp, exp(src));
}
@ -120,13 +120,13 @@ kernel void ComputeFunction($1
sum = tmpx1[0];
int dst_s = int(ugid.x);
if (dst_s < params.size.x) {
int linear_index = dst_s;
float4 src = float4(src_tensor[linear_index]) - float4(maximum);
if (dst_s < args.src_tensor.Slices()) {
float4 src = float4(args.src_tensor.Read(0, 0, dst_s)) - float4(maximum);
FLT4 value = FLT4(exp(src) * sum);
uint3 gid = uint3(0, 0, linear_index);
uint3 gid = uint3(0, 0, dst_s);
args.dst_tensor.GetAddress(linear_index, 0, 0, dst_s);
$2
dst_tensor[linear_index] = value;
args.dst_tensor.Write(value, 0, 0, dst_s);
}
})";
return code;
@ -135,28 +135,27 @@ kernel void ComputeFunction($1
ComputeTaskDescriptor Softmax(const OperationDef& definition) {
ComputeTaskDescriptor desc(definition);
desc.tensors_as_args = true;
desc.shader_source = R"(
#include <metal_stdlib>
using namespace metal;
struct uniforms {
int4 size;
float4 mask;
};
$0
kernel void ComputeFunction(
$1
uint3 gid[[thread_position_in_grid]]) {
if (int(gid.x) >= params.size.x || int(gid.y) >= params.size.y) {
if (int(gid.x) >= args.dst_tensor.Width() || int(gid.y) >= args.dst_tensor.Height()) {
return;
}
float maximum = src_tensor[gid.y * params.size.x + gid.x].x;
for (int d = 0; d < params.size.z; ++d) {
int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
float4 mask_a = d == params.size.z - 1 ? params.mask : float4(1.0f);
float maximum = args.src_tensor.Read(gid.x, gid.y, 0).x;
for (int d = 0; d < args.dst_tensor.Slices(); ++d) {
float4 mask_a = d == args.dst_tensor.Slices() - 1 ? params.mask : float4(1.0f);
float4 mask_b = float4(1.0f) - mask_a;
float4 src = float4(src_tensor[buffer_index]);
float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d));
src = src * mask_a + mask_b * src.x;
maximum = max(maximum, src.x);
maximum = max(maximum, src.y);
@ -165,19 +164,18 @@ kernel void ComputeFunction(
}
float sum = 0.0f;
for (int d = 0; d < params.size.z; ++d) {
int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
float4 mask_temp = d == params.size.z - 1 ? params.mask : float4(1.0f);
float4 src = float4(src_tensor[buffer_index]) - float4(maximum);
for (int d = 0; d < args.dst_tensor.Slices(); ++d) {
float4 mask_temp = d == args.dst_tensor.Slices() - 1 ? params.mask : float4(1.0f);
float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d)) - float4(maximum);
sum += dot(mask_temp, exp(src));
}
for (int d = 0; d < params.size.z; ++d) {
const int linear_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
float4 src = float4(src_tensor[linear_index]) - float4(maximum);
for (int d = 0; d < args.dst_tensor.Slices(); ++d) {
float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d)) - float4(maximum);
FLT4 value = FLT4(exp(src) / sum);
args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, d);
$2
dst_tensor[linear_index] = value;
args.dst_tensor.Write(value, gid.x, gid.y, d);
}
}
)";
@ -189,20 +187,9 @@ kernel void ComputeFunction(
{"constant uniforms& params",
[](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
const int dst_depth = DivideRoundUp(dst_shapes[0].c, 4);
struct uniforms {
int4 size;
float4 mask;
};
uniforms params;
params.size = {dst_shapes[0].w, dst_shapes[0].h, dst_depth, 1};
params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
int reminder = dst_shapes[0].c % 4 == 0 ? 4 : dst_shapes[0].c % 4;
for (int i = 0; i < reminder; ++i) {
params.mask[i] = 1.0f;
}
const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&params);
return std::vector<uint8_t>(ptr, ptr + sizeof(uniforms));
float4 mask = GetMaskForLastPlane(dst_shapes[0].c);
const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&mask);
return std::vector<uint8_t>(ptr, ptr + sizeof(float4));
}},
};
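
Both softmax variants now build their uniform with `GetMaskForLastPlane`, which keeps only the valid lanes of the last channel slice. An illustrative reimplementation of that logic, matching the per-component loop this change removes (hypothetical helper, not the actual task/util.h function):

```
#include <array>

// For channels == 6 the last FLT4 slice has two valid lanes, so the mask is
// {1, 1, 0, 0}; for channels divisible by 4 it is all ones.
std::array<float, 4> MaskForLastPlane(int channels) {
  std::array<float, 4> mask = {0.0f, 0.0f, 0.0f, 0.0f};
  const int remainder = channels % 4 == 0 ? 4 : channels % 4;
  for (int i = 0; i < remainder; ++i) mask[i] = 1.0f;
  return mask;
}
```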
@ -220,6 +207,7 @@ kernel void ComputeFunction(
ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
const GpuInfo& gpu_info) {
ComputeTaskDescriptor desc(definition);
desc.tensors_as_args = true;
desc.shader_source = GetSoftmax1x1Code(gpu_info);
desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
@ -229,20 +217,9 @@ ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
{"constant uniforms& params",
[](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
const int src_depth = DivideRoundUp(dst_shapes[0].c, 4);
struct uniforms {
int4 size;
float4 mask;
};
uniforms params;
params.size = {src_depth, DivideRoundUp(src_depth, 32), 1, 1};
params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
int reminder = dst_shapes[0].c % 4 == 0 ? 4 : dst_shapes[0].c % 4;
for (int i = 0; i < reminder; ++i) {
params.mask[i] = 1.0f;
}
const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&params);
return std::vector<uint8_t>(ptr, ptr + sizeof(uniforms));
float4 mask = GetMaskForLastPlane(dst_shapes[0].c);
const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&mask);
return std::vector<uint8_t>(ptr, ptr + sizeof(float4));
}},
};

View File

@ -124,6 +124,8 @@ absl::Status MetalSpatialTensor::GetGPUResources(
if (!tensor_desc) {
return absl::InvalidArgumentError("Expected TensorDescriptor on input.");
}
resources->ints.push_back(
{"slice_stride", tensor_desc->GetSliceStrideSize(shape_)});
if (descriptor_.HasAxis(Axis::WIDTH)) {
resources->ints.push_back({"width", Width()});
resources->ints.push_back({"width_div2", Width() / 2});

View File

@ -310,7 +310,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading data from https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media\u0026token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8\n",
"Downloading data from https://dl.fbaipublicfiles.com/glue/data/SST-2.zip\n",
"7446528/7439277 [==============================] - 0s 0us/step\n"
]
}
@ -318,7 +318,7 @@
"source": [
"data_dir = tf.keras.utils.get_file(\n",
" fname='SST-2.zip',\n",
" origin='https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media\u0026token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',\n",
" origin='https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',\n",
" extract=True)\n",
"data_dir = os.path.join(os.path.dirname(data_dir), 'SST-2')"
]

View File

@ -587,11 +587,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
status = kTfLiteError;
}
size_t dims_signature_rank = 0;
const int* dims_signature_data = nullptr;
std::vector<int> dims_signature = {};
if (tensor->shape_signature()) {
dims_signature_rank = tensor->shape_signature()->size();
dims_signature_data = tensor->shape_signature()->data();
dims_signature = FlatBufferIntArrayToVector(tensor->shape_signature());
}
bool is_variable = tensor->is_variable();
@ -623,7 +621,7 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
} else {
if (subgraph->SetTensorParametersReadWrite(
i, type, get_name(tensor), dims, quantization, is_variable,
dims_signature_rank, dims_signature_data) != kTfLiteOk) {
dims_signature) != kTfLiteOk) {
error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
i);
status = kTfLiteError;

View File

@ -80,6 +80,9 @@ TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
case kTfLiteInt32:
copyCast(in, out->data.i32, num_elements);
break;
case kTfLiteInt16:
copyCast(in, out->data.i16, num_elements);
break;
case kTfLiteUInt8:
copyCast(in, out->data.uint8, num_elements);
break;
@ -113,6 +116,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return copyToTensor(context, input->data.i64, output, num_elements);
case kTfLiteInt32:
return copyToTensor(context, input->data.i32, output, num_elements);
case kTfLiteInt16:
return copyToTensor(context, input->data.i16, output, num_elements);
case kTfLiteUInt8:
return copyToTensor(context, input->data.uint8, output, num_elements);
case kTfLiteFloat32:

View File

@ -46,6 +46,22 @@ class CastOpModel : public SingleOpModel {
int output_;
};
TEST(CastOpModel, CastInt16ToFloat) {
CastOpModel m({TensorType_INT16, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
m.PopulateTensor<int16_t>(m.input(), {100, 200, 300, 400, 500, 600});
m.Invoke();
EXPECT_THAT(m.ExtractVector<float>(m.output()),
ElementsAreArray({100.f, 200.f, 300.f, 400.f, 500.f, 600.f}));
}
TEST(CastOpModel, CastInt16ToInt32) {
CastOpModel m({TensorType_INT16, {2, 3}}, {TensorType_INT32, {2, 3}});
m.PopulateTensor<int16_t>(m.input(), {100, 200, 300, 400, 500, 600});
m.Invoke();
EXPECT_THAT(m.ExtractVector<int32_t>(m.output()),
ElementsAreArray({100, 200, 300, 400, 500, 600}));
}
TEST(CastOpModel, CastInt32ToFloat) {
CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
m.PopulateTensor<int32_t>(m.input(), {100, 200, 300, 400, 500, 600});
@ -62,6 +78,14 @@ TEST(CastOpModel, CastFloatToInt32) {
ElementsAreArray({100, 20, 3, 0, 0, 1}));
}
TEST(CastOpModel, CastFloatToInt16) {
CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT16, {3, 2}});
m.PopulateTensor<float>(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
m.Invoke();
EXPECT_THAT(m.ExtractVector<int16_t>(m.output()),
ElementsAreArray({100, 20, 3, 0, 0, 1}));
}
TEST(CastOpModel, CastInt64ToFloat) {
CastOpModel m({TensorType_INT64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
m.PopulateTensor<int64_t>(m.input(), {100, 200, 300, 400, 500, 600});

View File

@ -121,7 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should
// allocate and populate these during Prepare().
// TODO(ycling): Activation function parameter is ignored. For now we dont have
// TODO(ycling): Activation function parameter is ignored. For now we don't have
// a model with a Concatenation with fused activation function.
#define TF_LITE_CONCATENATION(scalar) \
{ \

View File

@ -494,6 +494,7 @@ cc_library(
"reference/resize_nearest_neighbor.h",
"reference/round.h",
"reference/softmax.h",
"reference/space_to_depth.h",
"reference/strided_slice.h",
"reference/sub.h",
"reference/svdf.h",
@ -511,13 +512,14 @@ cc_library(
}),
compatible_with = get_compatible_with_portable(),
copts = tflite_copts(),
# We are disabling parse_headers for the tf_lite_static_memory build to
# allow it to be consistent with the OSS bazel build. See b/175817116
# for more details.
features = select({
":tf_lite_static_memory": ["-parse_headers"],
"//conditions:default": [],
}),
# We are disabling parse_headers for this header-only target so that the
# external and internal builds are consistent. The primary issue here is
# that parse_headers is not supported with bazel and the TFLM team would
# really like to have all build errors in shared Micro/Lite code be
# reproducible from the OSS build as well.
#
# See b/175817116 for more details.
features = ["-parse_headers"],
deps = [
":common",
":compatibility",
@ -588,6 +590,7 @@ cc_library(
"reference/resize_nearest_neighbor.h",
"reference/round.h",
"reference/softmax.h",
"reference/space_to_depth.h",
"reference/strided_slice.h",
"reference/string_comparisons.h",
"reference/sub.h",

View File

@ -15,16 +15,13 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV_H_
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {
inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& filter_shape,
const float* filter_data, const RuntimeShape& bias_shape,

View File

@ -61,6 +61,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
#include "tensorflow/lite/kernels/internal/reference/round.h"
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
#include "tensorflow/lite/kernels/internal/reference/space_to_depth.h"
#include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
#include "tensorflow/lite/kernels/internal/reference/string_comparisons.h"
#include "tensorflow/lite/kernels/internal/reference/sub.h"
@ -126,58 +127,6 @@ inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
}
}
template <typename T>
inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
const RuntimeShape& unextended_input_shape,
const T* input_data,
const RuntimeShape& unextended_output_shape,
T* output_data) {
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int input_depth = input_shape.Dims(3);
const int input_width = input_shape.Dims(2);
const int input_height = input_shape.Dims(1);
const int input_batch = input_shape.Dims(0);
const int output_depth = output_shape.Dims(3);
const int output_width = output_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_batch = output_shape.Dims(0);
const int32 block_size = op_params.block_size;
TFLITE_DCHECK_EQ(input_width, output_width * block_size);
TFLITE_DCHECK_EQ(input_height, output_height * block_size);
TFLITE_DCHECK_EQ(input_depth * block_size * block_size, output_depth);
TFLITE_DCHECK_EQ(input_batch, output_batch);
for (int in_b = 0; in_b < input_batch; ++in_b) {
for (int in_h = 0; in_h < input_height; ++in_h) {
for (int in_w = 0; in_w < input_width; ++in_w) {
for (int in_d = 0; in_d < input_depth; ++in_d) {
const int out_d =
in_d + ((in_h % block_size) * block_size + in_w % block_size) *
input_depth;
const int out_w = in_w / block_size;
const int out_h = in_h / block_size;
const int out_b = in_b;
const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d);
const int output_index =
Offset(output_shape, out_b, out_h, out_w, out_d);
output_data[output_index] = input_data[input_index];
}
}
}
}
}
inline void Elu(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);

View File

@ -0,0 +1,78 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPACE_TO_DEPTH_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPACE_TO_DEPTH_H_
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {
template <typename T>
inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
const RuntimeShape& unextended_input_shape,
const T* input_data,
const RuntimeShape& unextended_output_shape,
T* output_data) {
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int input_depth = input_shape.Dims(3);
const int input_width = input_shape.Dims(2);
const int input_height = input_shape.Dims(1);
const int input_batch = input_shape.Dims(0);
const int output_depth = output_shape.Dims(3);
const int output_width = output_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_batch = output_shape.Dims(0);
const int32 block_size = op_params.block_size;
TFLITE_DCHECK_EQ(input_width, output_width * block_size);
TFLITE_DCHECK_EQ(input_height, output_height * block_size);
TFLITE_DCHECK_EQ(input_depth * block_size * block_size, output_depth);
TFLITE_DCHECK_EQ(input_batch, output_batch);
for (int in_b = 0; in_b < input_batch; ++in_b) {
for (int in_h = 0; in_h < input_height; ++in_h) {
for (int in_w = 0; in_w < input_width; ++in_w) {
for (int in_d = 0; in_d < input_depth; ++in_d) {
const int out_d =
in_d + ((in_h % block_size) * block_size + in_w % block_size) *
input_depth;
const int out_w = in_w / block_size;
const int out_h = in_h / block_size;
const int out_b = in_b;
const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d);
const int output_index =
Offset(output_shape, out_b, out_h, out_w, out_d);
output_data[output_index] = input_data[input_index];
}
}
}
}
}
} // namespace reference_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPACE_TO_DEPTH_H_
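
A hedged usage sketch of the relocated reference kernel; the shapes and the expected output are worked out by hand from the index math above rather than taken from a test in this change:

```
#include "tensorflow/lite/kernels/internal/reference/space_to_depth.h"
#include "tensorflow/lite/kernels/internal/types.h"

// A 1x4x4x1 input with block_size 2 becomes 1x2x2x4: each 2x2 spatial block
// is folded into the channel dimension.
void SpaceToDepthExample() {
  tflite::SpaceToDepthParams params;
  params.block_size = 2;
  const float input[16] = {0, 1, 2,  3,  4,  5,  6,  7,
                           8, 9, 10, 11, 12, 13, 14, 15};
  float output[16] = {};
  tflite::reference_ops::SpaceToDepth(
      params, tflite::RuntimeShape({1, 4, 4, 1}), input,
      tflite::RuntimeShape({1, 2, 2, 4}), output);
  // output == {0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15}
}
```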

View File

@ -170,7 +170,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
intermediate_zp.push_back(0);
}
}
// In the absense of projection, hidden becomes otuput and this intermediate
// In the absence of projection, hidden becomes output and this intermediate
// is ignored.
TfLiteTensor* hidden;
TF_LITE_ENSURE_OK(context, GetIntermediatesSafe(context, node, 4, &hidden));

View File

@ -204,7 +204,7 @@ to determine if the requested feature aligns with the TFLM roadmap.
1. Run all the tests for x86, and any other platform that you are modifying.
```
tensorflow/lite/micro/tools/make/tools/ci_build/test_x86.sh
tensorflow/lite/micro/tools/ci_build/test_x86.sh
```
Please check the READMEs in the optimized kernel directories for specific

View File

@ -1,2 +1,2 @@
numpy==1.16.2
tensorflow==2.0.0-beta1
tensorflow==2.4.0

View File

@ -33,7 +33,7 @@ set +e
# The pigweed scripts only work from a git repository and the Tensorflow CI
# infrastructure does not always guarantee that. As an ugly workaround, we
# create our own git repo when running on the CI servers.
pushd tensorflow/lite/micro/
pushd tensorflow/lite/
if [[ ${1} == "PRESUBMIT" ]]; then
git init .
git config user.email "tflm@google.com"
@ -43,9 +43,12 @@ if [[ ${1} == "PRESUBMIT" ]]; then
fi
# Check for license with the necessary exclusions.
tools/make/downloads/pigweed/pw_presubmit/py/pw_presubmit/pigweed_presubmit.py \
. \
micro/tools/make/downloads/pigweed/pw_presubmit/py/pw_presubmit/pigweed_presubmit.py \
kernels/internal/reference/ \
micro/ \
-p copyright_notice \
-e kernels/internal/reference/integer_ops/ \
-e kernels/internal/reference/reference_ops.h \
-e tools/make/downloads \
-e tools/make/targets/ecm3531 \
-e BUILD\
@ -66,8 +69,11 @@ LICENSE_CHECK_RESULT=$?
# Python files (with yapf as the formatter) because that needs additional setup.
# We are also ignoring the markdown files to allow for a more gradual rollout of
# this presubmit check.
tools/make/downloads/pigweed/pw_presubmit/py/pw_presubmit/format_code.py \
. \
micro/tools/make/downloads/pigweed/pw_presubmit/py/pw_presubmit/format_code.py \
kernels/internal/reference/ \
micro/ \
-e kernels/internal/reference/integer_ops/ \
-e kernels/internal/reference/reference_ops.h \
-e "\.inc" \
-e "\.md" \
-e "\.py"
@ -76,7 +82,7 @@ CLANG_FORMAT_RESULT=$?
popd
if [[ ${1} == "PRESUBMIT" ]]; then
rm -rf tensorflow/lite/micro/.git
rm -rf tensorflow/lite/.git
fi
# Re-enable exit on error now that we are done with the temporary git repo.

View File

@ -116,7 +116,7 @@ download_and_extract() {
local tempdir=$(mktemp -d)
local tempdir2=$(mktemp -d)
local tempfile=${tempdir}/temp_file
local curl_retries=3
local curl_retries=5
# Destination already downloaded.
if [ -d ${dir} ]; then
@ -131,24 +131,21 @@ download_and_extract() {
mkdir -p "${dir}"
# We've been seeing occasional 56 errors from valid URLs, so set up a retry
# loop to attempt to recover from them.
for (( i=1; i<=$curl_retries; ++i ))
do
for (( i=1; i<=$curl_retries; ++i )); do
# We have to use this approach because we normally halt the script when
# there's an error, and instead we want to catch errors so we can retry.
set +e
curl -Ls --fail --retry 5 "${url}" > ${tempfile}
set +ex
curl -LsS --fail --retry 5 "${url}" > ${tempfile}
CURL_RESULT=$?
set -e
set -ex
# Was the command successful? If so, continue.
if [[ $CURL_RESULT -eq 0 ]]
then
if [[ $CURL_RESULT -eq 0 ]]; then
break
fi
# Keep trying if we see the '56' error code.
if [[ ( $CURL_RESULT -ne 56 ) || ( $i -eq $curl_retries ) ]]
then
if [[ ( $CURL_RESULT -ne 56 ) || ( $i -eq $curl_retries ) ]]; then
echo "Error $CURL_RESULT downloading '${url}'"
exit 1
fi

View File

@ -110,7 +110,7 @@ class OpsSet(enum.Enum):
"EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8"
def __str__(self):
return self.value
return str(self.value)
@staticmethod
def get_options():
@ -394,22 +394,22 @@ def build_toco_convert_protos(input_tensors,
input_tensors: List of input tensors. Type and shape are computed using
`foo.shape` and `foo.dtype`.
output_tensors: List of output tensors (only .name is used from this).
inference_type: Target data type of real-number arrays in the output file.
Must be `{tf.float32, tf.uint8, tf.int8}`. (default tf.float32)
inference_input_type: Target data type of real-number input arrays. Allows
for a different type for input arrays in the case of quantization. Must be
`{tf.float32, tf.uint8, tf.int8}`. (default `inference_type`)
input_format: Type of data to read Currently must be
`{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
input_shapes: Input array shape. It needs to be a list of the same length as
`input_tensors`, or None. (default None)
output_format: Output file format. Currently must be `{TFLITE,
GRAPHVIZ_DOT}`. (default TFLITE)
quantized_input_stats: List of tuples of floats representing the mean and
standard deviation. Each tuple maps to the corresponding input tensor.
Only need if `inference_input_type` is `QUANTIZED_UINT8` or `INT8`.
real_input_value = (quantized_input_value - mean_value) / std_dev_value.
(default None)
inference_type: Data type of numeric arrays, excluding the input layer.
(default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
inference_input_type: Data type of the numeric arrays in the input layer. If
`inference_input_type` is in {tf.int8, tf.uint8}, then
`quantized_input_stats` must be provided. (default is the value assigned
to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8})
input_format: Type of data to read.
(default TENSORFLOW_GRAPHDEF, must be in {TENSORFLOW_GRAPHDEF})
input_shapes: Input array shape. (default None, must be None or a list of
the same length as `input_tensors`.)
output_format: Output file format. (default TFLITE, must be in
{TFLITE, GRAPHVIZ_DOT})
quantized_input_stats: Map of input tensor names to a tuple of floats
representing the mean and standard deviation of the training data.
(e.g., {"foo" : (0., 1.)}). Required if `inference_input_type` is tf.int8
or tf.uint8. (default None)
default_ranges_stats: Tuple of integers representing (min, max) range values
for all arrays without a specified range. Intended for experimenting with
quantization via "dummy quantization". (default None)
@ -574,8 +574,10 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
if _requires_input_stats(toco_flags):
if (("quantized_input_stats" not in kwargs) or
(not kwargs["quantized_input_stats"])):
raise ValueError("std_dev and mean must be defined when inference_type "
"or inference_input_type is QUANTIZED_UINT8 or INT8.")
raise ValueError(
"The `quantized_input_stats` flag must be defined when either "
"`inference_type` flag or `inference_input_type` flag is set to "
"tf.int8 or tf.uint8.")
input_array.mean_value, input_array.std_value = kwargs[
"quantized_input_stats"][idx]
input_array.name = name
@ -661,7 +663,7 @@ def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
Typically this function is used to convert from TensorFlow GraphDef to TFLite.
Conversion can be customized by providing arguments that are forwarded to
`build_toco_convert_protos` (see documentation for details). This function has
been deprecated. Please use `lite.TFLiteConverter` instead.
been deprecated. Please use `tf.lite.TFLiteConverter` instead.
Args:
input_data: Input data (i.e. often `sess.graph_def`),

View File

@ -137,7 +137,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
self.assertEqual("output", output_details[0]["name"])
self.assertEqual(np.uint8, output_details[0]["dtype"])
self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all())
self.assertTrue(output_details[0]["quantization"][0] > 0) # scale
self.assertGreater(output_details[0]["quantization"][0], 0) # scale
def testGraphDefQuantizationInvalid(self):
with ops.Graph().as_default():
@ -159,9 +159,9 @@ class ConvertTest(test_util.TensorFlowTestCase):
enable_mlir_converter=False,
inference_type=dtypes.uint8)
self.assertEqual(
"std_dev and mean must be defined when inference_type or "
"inference_input_type is QUANTIZED_UINT8 or INT8.",
str(error.exception))
"The `quantized_input_stats` flag must be defined when either "
"`inference_type` flag or `inference_input_type` flag is set to "
"tf.int8 or tf.uint8.", str(error.exception))
class ConvertTestOpHint(test_util.TensorFlowTestCase):

View File

@ -61,6 +61,7 @@ from tensorflow.lite.python.util import get_debug_info as _get_debug_info
from tensorflow.lite.python.util import get_grappler_config as _get_grappler_config
from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
from tensorflow.lite.python.util import get_tf_type_name as _get_tf_type_name
from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
from tensorflow.lite.python.util import model_input_signature as _model_input_signature
from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type
@ -89,19 +90,14 @@ from tensorflow.python.util.tf_export import tf_export as _tf_export
@_tf_export("lite.Optimize")
class Optimize(enum.Enum):
"""Enum defining the optimizations to apply when generating tflite graphs.
Some optimizations may come at the cost of accuracy.
"""Enum defining the optimizations to apply when generating a tflite model.
DEFAULT
Default optimization strategy.
Converter will do its best to improve size and latency based on the
information provided.
Enhanced optimizations are gained by providing a representative_dataset.
This is recommended, and is currently equivalent to the modes below.
Currently, weights will be quantized and if representative_dataset is
provided, activations for quantizable operations will also be quantized.
Default optimization strategy that quantizes model weights. Enhanced
optimizations are gained by providing a representative dataset that
quantizes biases and activations as well.
Converter will do its best to reduce size and latency, while minimizing
the loss in accuracy.
OPTIMIZE_FOR_SIZE
Deprecated. Does the same as DEFAULT.
@ -110,14 +106,11 @@ class Optimize(enum.Enum):
Deprecated. Does the same as DEFAULT.
"""
# Default optimization strategy.
#
# Converter will do its best to improve size and latency based on the
# information provided.
# Enhanced optimizations can be gained by providing a representative_dataset.
# This is recommended, and is currently equivalent to the modes below.
# Currently, weights will be quantized and if representative_dataset is
# provided, activations for quantizable operations will also be quantized.
# Default optimization strategy that quantizes model weights. Enhanced
# optimizations are gained by providing a representative dataset that
# quantizes biases and activations as well.
# Converter will do its best to reduce size and latency, while minimizing
# the loss in accuracy.
DEFAULT = "DEFAULT"
# Deprecated. Does the same as DEFAULT.
@ -132,48 +125,47 @@ class Optimize(enum.Enum):
@_tf_export("lite.RepresentativeDataset")
class RepresentativeDataset(object):
"""Representative dataset to evaluate optimizations.
"""Representative dataset used to optimize the model.
A representative dataset that can be used to evaluate optimizations by the
converter. E.g. converter can use these examples to estimate (min, max) ranges
by calibrating the model on inputs. This can allow converter to quantize a
converted floating point model.
This is a generator function that provides a small dataset to calibrate or
estimate the range, i.e, (min, max) of all floating-point arrays in the model
(such as model input, activation outputs of intermediate layers, and model
output) for quantization. Usually, this is a small subset of a few hundred
samples randomly chosen, in no particular order, from the training or
evaluation dataset.
"""
def __init__(self, input_gen):
"""Creates a representative dataset.
Args:
input_gen: an input generator that can be used to generate input samples
for the model. This must be a callable object that returns an object
that supports the `iter()` protocol (e.g. a generator function). The
elements generated must have same type and shape as inputs to the model.
input_gen: A generator function that generates input samples for the
model and has the same order, type and shape as the inputs to the model.
Usually, this is a small subset of a few hundred samples randomly
chosen, in no particular order, from the training or evaluation dataset.
"""
self.input_gen = input_gen
@_tf_export("lite.TargetSpec")
class TargetSpec(object):
"""Specification of target device.
Details about target device. Converter optimizes the generated model for
specific device.
"""Specification of target device used to optimize the model.
Attributes:
supported_ops: Experimental flag, subject to change. Set of OpsSet options
supported by the device. (default set([OpsSet.TFLITE_BUILTINS]))
supported_types: List of types for constant values on the target device.
Frequently, an optimization choice is driven by the most compact
(i.e. smallest) type in this list (default [tf.float32])
supported_ops: Experimental flag, subject to change. Set of `tf.lite.OpsSet`
options, where each option represents a set of operators supported by the
target device. (default {tf.lite.OpsSet.TFLITE_BUILTINS}))
supported_types: Set of `tf.dtypes.DType` data types supported on the target
device. If initialized, optimization might be driven by the smallest type
in this set. (default set())
experimental_select_user_tf_ops: Experimental flag, subject to change. Set
of the user's TensorFlow operator names that are required in the TensorFlow
Lite runtime. These ops will be exported as select TensorFlow ops in the
model (in conjunction with the OpsSet.SELECT_TF_OPS flag). This is an
advanced feature that should only be used if the client is using TF ops
model (in conjunction with the tf.lite.OpsSet.SELECT_TF_OPS flag). This is
an advanced feature that should only be used if the client is using TF ops
that may not be linked in by default with the TF ops that are provided
when using the SELECT_TF_OPS path. The client is responsible for linking
these ops into the target runtime.
"""
def __init__(self,
@ -181,17 +173,17 @@ class TargetSpec(object):
supported_types=None,
experimental_select_user_tf_ops=None):
if supported_ops is None:
supported_ops = set([OpsSet.TFLITE_BUILTINS])
supported_ops = {OpsSet.TFLITE_BUILTINS}
self.supported_ops = supported_ops
if supported_types is None:
supported_types = []
supported_types = set()
self.supported_types = supported_types
if experimental_select_user_tf_ops is None:
self.experimental_select_user_tf_ops = []
self.experimental_select_user_tf_ops = set()
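A hedged sketch of how the set-based `TargetSpec` fields are typically combined to force full integer quantization; `model` and `representative_dataset` are the stand-ins from the sketch above.

```python
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = {tf.lite.Optimize.DEFAULT}
converter.representative_dataset = representative_dataset
# Restricting the target to the int8 built-in ops (and int8 I/O) disallows
# float fallback; this corresponds to the post_training_int8_no_float path
# in QuantizationMode below.
converter.target_spec.supported_ops = {tf.lite.OpsSet.TFLITE_BUILTINS_INT8}
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
int8_model = converter.convert()
```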
class QuantizationMode(object):
"""QuantizationMode determines the quantized conversion from user options."""
"""QuantizationMode determines the quantization type from user options."""
def __init__(self, optimizations, target_spec, representative_dataset,
graph_def):
@ -205,7 +197,6 @@ class QuantizationMode(object):
# TODO(b/162537905): Refactor the following quantization functions -
# re-organize and refactor for better readability.
def post_training_int8_no_float(self):
"""Post training int8 quantize, disallow float fallback."""
return (self._any_optimization_enabled() and
self._is_int8_target_required() and
not self._is_int16x8_target_required() and
@ -213,19 +204,16 @@ class QuantizationMode(object):
self._representative_dataset is not None)
def post_training_int8_allow_float(self):
"""Post training int8 quantize, allow float fallback."""
return (self._any_optimization_enabled() and
not self._is_int16x8_target_required() and
self._representative_dataset is not None and
self._smallest_supported_type() == _dtypes.int8)
def is_post_training_integer_quantize_8(self):
"""Post training integer 8 quantization."""
return (self.post_training_int8_no_float() or
self.post_training_int8_allow_float())
def is_post_training_integer_quantize_16x8(self):
"""Post training integer 16x8 quantization."""
return (self.post_training_int16x8_no_float() or
self.post_training_int16x8_allow_float())
@ -239,7 +227,6 @@ class QuantizationMode(object):
self.contains_training_quant_op())
def post_training_int16x8_no_float(self):
"""Post training int16x8 quantize, disallow float fallback."""
return (self._any_optimization_enabled() and
not self._is_int8_target_required() and
self._is_int16x8_target_required() and
@ -247,13 +234,11 @@ class QuantizationMode(object):
self._representative_dataset is not None)
def post_training_int16x8_allow_float(self):
"""Post training int16x8 quantize, allow float fallback."""
return (self._any_optimization_enabled() and
self._is_int16x8_target_required() and
self._is_allow_float())
def post_training_dynamic_range_int8(self):
"""Post training int8 const, on-the-fly int8 quantize of dynamic tensors."""
# Post-training dynamic range quantization is only enabled if post-training
# int8 quantization and training time quantization was not done.
return (self._any_optimization_enabled() and
@ -262,7 +247,6 @@ class QuantizationMode(object):
self._smallest_supported_type() == _dtypes.int8)
def post_training_fp16(self):
"""Post training fp16 quantize."""
return (self._any_optimization_enabled() and
self._smallest_supported_type() == _dtypes.float16)
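For contrast with the int8 paths, a minimal sketch of the configuration matched by the `post_training_fp16` predicate above, again reusing the assumed `model` from the earlier sketch.

```python
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = {tf.lite.Optimize.DEFAULT}
# float16 being the smallest supported type selects post-training fp16
# quantization; no representative dataset is required for this mode.
converter.target_spec.supported_types = {tf.float16}
fp16_model = converter.convert()
```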
@ -416,21 +400,20 @@ class TFLiteConverterBase(object):
"""Converter subclass to share functionality between V1 and V2 converters."""
def __init__(self):
self.allow_custom_ops = False
self.target_spec = TargetSpec()
self.optimizations = []
self.optimizations = set()
self.representative_dataset = None
self.target_spec = TargetSpec()
self.allow_custom_ops = False
self.experimental_new_converter = True
self._experimental_new_quantizer = False
self._experimental_calibrate_only = False
# The 'GraphDebugInfo' contains the stack traces of all the original nodes
# in the `GraphDef` to the converter.
self._debug_info = None
self._experimental_sparsify_model = False
self._debug_info = None # contains the stack traces of all the original
# nodes in the `GraphDef` to the converter.
self.saved_model_dir = None
self._saved_model_tags = None
self._saved_model_version = 0
self._saved_model_exported_names = []
self._experimental_sparsify_model = False
def _grappler_config(self, optimizers=None):
"""Creates a tf.compat.v1.ConfigProto for configuring Grappler.
@ -684,9 +667,9 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
saved_model_dir: Directory of the SavedModel.
saved_model_tags: Set of tags identifying the MetaGraphDef within the
SavedModel to analyze. All tags in the tag set must be present. (default
set(SERVING)).
saved_model_exported_names: Names to be exported (default: export all)
when the saved model import path is on.
{tf.saved_model.SERVING}).
saved_model_exported_names: Names to be exported when the saved model
import path is on.
trackable_obj: tf.AutoTrackable object associated with `funcs`. A
reference to this object needs to be maintained so that Variables do not
get garbage collected since functions have a weak reference to
@ -972,20 +955,21 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
"""Converts a TensorFlow model into TensorFlow Lite model.
Attributes:
allow_custom_ops: Boolean indicating whether to allow custom operations.
When False, any unknown operation is an error. When True, custom ops are
created for any op that is unknown. The developer needs to provide these
to the TensorFlow Lite runtime with a custom resolver. (default False)
optimizations: Experimental flag, subject to change. A list of optimizations
to apply when converting the model. E.g. `[Optimize.DEFAULT]`
representative_dataset: A representative dataset that can be used to
generate input and output samples for the model. The converter can use the
dataset to evaluate different optimizations. Note that this is an optional
attribute but it is necessary if INT8 is the only support builtin ops in
target ops.
optimizations: Experimental flag, subject to change. Set of optimizations
to apply. e.g. {tf.lite.Optimize.DEFAULT}. (default None, must be None or a
set of values of type `tf.lite.Optimize`)
representative_dataset: A generator function used for integer quantization
where each generated sample has the same order, type and shape as the
inputs to the model. Usually, this is a small subset of a few hundred
samples randomly chosen, in no particular order, from the training or
evaluation dataset. This is an optional attribute, but required for full
integer quantization, i.e., if `tf.int8` is the only supported type in
`target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`.
(default None)
target_spec: Experimental flag, subject to change. Specifications of target
device, including the supported ops set, supported types and a set of
user-defined TensorFlow operators required in the TensorFlow Lite runtime.
Refer to `tf.lite.TargetSpec`.
inference_input_type: Data type of the input layer. Note that integer types
(tf.int8 and tf.uint8) are currently only supported for post training
integer quantization and quantization aware training. (default tf.float32,
@ -994,8 +978,13 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
types (tf.int8 and tf.uint8) are currently only supported for post
training integer quantization and quantization aware training. (default
tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
allow_custom_ops: Boolean indicating whether to allow custom operations.
When False, any unknown operation is an error. When True, custom ops are
created for any op that is unknown. The developer needs to provide these
to the TensorFlow Lite runtime with a custom resolver. (default False)
experimental_new_converter: Experimental flag, subject to change. Enables
MLIR-based conversion instead of TOCO conversion. (default True)
Example usage:
```python
@ -1063,7 +1052,8 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
`signatures` attribute of the MetaGraphdef is used. (default
saved_model.signatures)
tags: Set of tags identifying the MetaGraphDef within the SavedModel to
analyze. All tags in the tag set must be present. (default set(SERVING))
analyze. All tags in the tag set must be present. (default
{tf.saved_model.SERVING} or {'serve'})
Returns:
TFLiteConverter object.
@ -1209,9 +1199,13 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
if (requires_quantized_input_stats and
not converter_kwargs["quantized_input_stats"]):
raise ValueError("The `quantized_input_stats` flag must be defined when "
"either `inference_type` flag or `inference_input_type` "
"flag is set to tf.uint8 or tf.int8.")
raise ValueError(
"The `quantized_input_stats` flag must be defined when either "
"`inference_type` flag or `inference_input_type` flag is set to "
"tf.int8 or tf.uint8. Currently, `inference_type={}` and "
"`inference_input_type={}`.".format(
_get_tf_type_name(converter_kwargs["inference_type"]),
_get_tf_type_name(converter_kwargs["inference_input_type"])))
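A hedged sketch of a TF1 conversion that satisfies this check; the frozen graph path, the array names and the (mean, std) pair are placeholders, not values taken from this change.

```python
import tensorflow as tf

# Placeholder graph file and tensor names; substitute your own frozen model.
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    "frozen_graph.pb", input_arrays=["input"], output_arrays=["output"])
converter.inference_type = tf.uint8
# real_input_value = (quantized_input_value - mean) / std, so each quantized
# input tensor needs its (mean, std) supplied here.
converter.quantized_input_stats = {"input": (127.5, 127.5)}
tflite_model = converter.convert()
```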
def convert(self):
"""Converts a TensorFlow GraphDef based on instance variables.
@ -1424,9 +1418,9 @@ class TFLiteSavedModelConverter(TFLiteConverterBaseV1):
saved_model_dir: Directory of the SavedModel.
saved_model_tags: Set of tags identifying the MetaGraphDef within the
SavedModel to analyze. All tags in the tag set must be present. (default
set(SERVING)).
saved_model_exported_names: Names to be exported (default: export all)
when the saved model import path is on.
{tf.saved_model.SERVING}).
saved_model_exported_names: Names to be exported when the saved model
import path is on.
experimental_debug_info_func: An experimental function to retrieve the
graph debug info for a set of nodes from the `graph_def`.
@ -1645,33 +1639,42 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
model into either a TFLite FlatBuffer or graph visualization.
Attributes:
inference_type: Target data type of real-number arrays in the output file.
Must be `{tf.float32, tf.uint8}`. If `optimizations` are provided, this
parameter is ignored. (default tf.float32)
inference_input_type: Target data type of real-number input arrays. Allows
for a different type for input arrays. If an integer type is provided and
`optimizations` are not used, `quantized_input_stats` must be provided. If
`inference_type` is tf.uint8, signaling conversion to a fully quantized
model from a quantization-aware trained input model, then
`inference_input_type` defaults to tf.uint8. In all other cases,
`inference_input_type` defaults to tf.float32. Must be `{tf.float32,
tf.uint8, tf.int8}`
inference_output_type: Target data type of real-number output arrays. Allows
for a different type for output arrays. If `inference_type` is tf.uint8,
signaling conversion to a fully quantized model from a quantization-aware
trained output model, then `inference_output_type` defaults to tf.uint8.
In all other cases, `inference_output_type` must be tf.float32, an error
will be thrown otherwise. Must be `{tf.float32, tf.uint8, tf.int8}`
output_format: Output file format. Currently must be `{TFLITE,
GRAPHVIZ_DOT}`. (default TFLITE)
quantized_input_stats: Dict of strings representing input tensor names
mapped to tuple of floats representing the mean and standard deviation
of the training data (e.g., {"foo" : (0., 1.)}). Only needed if
`inference_input_type` is `QUANTIZED_UINT8`. real_input_value =
(quantized_input_value - mean_value) / std_dev_value. (default {})
default_ranges_stats: Tuple of integers representing (min, max) range values
for all arrays without a specified range. Intended for experimenting with
quantization via "dummy quantization". (default None)
optimizations: Experimental flag, subject to change. Set of optimizations to
apply. e.g. {tf.lite.Optimize.DEFAULT}. (default None, must be None or a
set of values of type `tf.lite.Optimize`)
representative_dataset: A generator function used for integer quantization
where each generated sample has the same order, type and shape as the
inputs to the model. Usually, this is a small subset of a few hundred
samples randomly chosen, in no particular order, from the training or
evaluation dataset. This is an optional attribute, but required for full
integer quantization, i.e., if `tf.int8` is the only supported type in
`target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`.
(default None)
target_spec: Experimental flag, subject to change. Specifications of target
device, including the supported ops set, supported types and a set of
user-defined TensorFlow operators required in the TensorFlow Lite runtime.
Refer to `tf.lite.TargetSpec`.
inference_type: Data type of numeric arrays, excluding the input layer.
(default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
inference_input_type: Data type of the numeric arrays in the input layer. If
`inference_input_type` is in {tf.int8, tf.uint8}, then
`quantized_input_stats` must be provided. (default is the value assigned
to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8})
inference_output_type: Data type of the numeric arrays in the output layer.
(default is the value assigned to `inference_type`, must be in
{tf.float32, tf.int8, tf.uint8})
quantized_input_stats: Map of input tensor names to a tuple of floats
representing the mean and standard deviation of the training data.
(e.g., {"foo" : (0., 1.)}). Required if `inference_input_type` is tf.int8
or tf.uint8. (default None)
default_ranges_stats: Tuple of integers (min, max) representing range values
for all numeric arrays without a specified range. Intended for
experimenting with quantization via "dummy quantization". (default None)
allow_custom_ops: Boolean indicating whether to allow custom operations.
When False, any unknown operation is an error. When True, custom ops are
created for any op that is unknown. The developer will need to provide
these to the TensorFlow Lite runtime with a custom resolver. (default
False)
drop_control_dependency: Boolean indicating whether to drop control
dependencies silently. This is due to TFLite not supporting control
dependencies. (default True)
@ -1683,37 +1686,25 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
change_concat_input_ranges: Boolean to change behavior of min/max ranges for
inputs and outputs of the concat operator for quantized models. Changes
the ranges of concat operator overlap when true. (default False)
allow_custom_ops: Boolean indicating whether to allow custom operations.
When false any unknown operation is an error. When true, custom ops are
created for any op that is unknown. The developer will need to provide
these to the TensorFlow Lite runtime with a custom resolver. (default
False)
post_training_quantize: Deprecated. Please specify `[Optimize.DEFAULT]` for
`optimizations` instead. Boolean indicating whether to quantize the
weights of the converted float model. Model size will be reduced and
there will be latency improvements (at the cost of accuracy). (default
False)
output_format: Output file format. (default
tf.compat.v1.lite.constants.TFLITE, must be in
{tf.compat.v1.lite.constants.TFLITE,
tf.compat.v1.lite.constants.GRAPHVIZ_DOT})
dump_graphviz_dir: Full filepath of folder to dump the graphs at various
stages of processing GraphViz .dot files. Preferred over
--output_format=GRAPHVIZ_DOT in order to keep the requirements of the
output file. (default None)
dump_graphviz_video: Boolean indicating whether to dump the graph after
every graph transformation. (default False)
conversion_summary_dir: A string indicating the path to the generated
conversion logs.
target_ops: Deprecated. Please specify `target_spec.supported_ops` instead.
Set of OpsSet options indicating which converter to use. (default
set([OpsSet.TFLITE_BUILTINS]))
target_spec: Experimental flag, subject to change. Specifications of target
device, including supported ops set, supported types and a set of user's
defined TensorFlow operators required in the TensorFlow Lite runtime.
optimizations: Experimental flag, subject to change. A list of optimizations
to apply when converting the model. E.g. `[Optimize.DEFAULT]`
representative_dataset: A representative dataset that can be used to
generate input and output samples for the model. The converter can use the
dataset to evaluate different optimizations.
`output_format=tf.compat.v1.lite.constants.GRAPHVIZ_DOT` in order to keep
the requirements of the output file. (default None)
dump_graphviz_video: Boolean indicating whether to dump the GraphViz .dot
files after every graph transformation. Requires the `dump_graphviz_dir`
flag to be specified. (default False)
conversion_summary_dir: Full path of the directory to store conversion logs.
(default None)
target_ops: Deprecated. Please use `target_spec.supported_ops` instead.
post_training_quantize: Deprecated. Please use `optimizations` instead and
set it to `{tf.lite.Optimize.DEFAULT}`. (default False)
experimental_new_converter: Experimental flag, subject to change. Enables
MLIR-based conversion instead of TOCO conversion. (default True)
Example usage:
```python
@ -1911,9 +1902,10 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
output_arrays: List of output tensors to freeze graph with. Uses output
arrays from SignatureDef when none are provided. (default None)
tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
analyze. All tags in the tag set must be present. (default set("serve"))
analyze. All tags in the tag set must be present. (default
{tf.saved_model.SERVING})
signature_key: Key identifying SignatureDef containing inputs and outputs.
(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)
(default tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
Returns:
TFLiteConverter class.


@ -1239,18 +1239,22 @@ class FromSessionTest(TestModels, parameterized.TestCase):
quantized_converter.inference_type = quantized_type
quantized_converter.convert()
self.assertEqual(
'The `quantized_input_stats` flag must be defined when '
'either `inference_type` flag or `inference_input_type` '
'flag is set to tf.uint8 or tf.int8.', str(error.exception))
'The `quantized_input_stats` flag must be defined when either '
'`inference_type` flag or `inference_input_type` flag is set to '
'tf.int8 or tf.uint8. Currently, `inference_type=tf.{}` and '
'`inference_input_type=None`.'.format(quantized_type.name),
str(error.exception))
with self.assertRaises(ValueError) as error:
quantized_converter.inference_type = dtypes.float32
quantized_converter.inference_input_type = quantized_type
quantized_converter.convert()
self.assertEqual(
'The `quantized_input_stats` flag must be defined when '
'either `inference_type` flag or `inference_input_type` '
'flag is set to tf.uint8 or tf.int8.', str(error.exception))
'The `quantized_input_stats` flag must be defined when either '
'`inference_type` flag or `inference_input_type` flag is set to '
'tf.int8 or tf.uint8. Currently, `inference_type=tf.float32` and '
'`inference_input_type=tf.{}`.'.format(quantized_type.name),
str(error.exception))
quantized_converter.inference_type = quantized_type
quantized_converter.inference_input_type = quantized_type


@ -127,9 +127,9 @@ def _convert_tflite_enum_type_to_tf_type(tflite_enum_type):
return tf_type
def _get_tf_type_name(tf_type):
def get_tf_type_name(tf_type):
"""Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32")."""
return "tf." + tf_type.name
return "tf." + tf_type.name if tf_type else None
def get_tensor_name(tensor):
@ -674,7 +674,7 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
raise ValueError(
"Initial model input type must be tf.float32. Expected type for "
"tensor with name '{}' is tf.float32, instead type is {}".format(
float_tensor.name, _get_tf_type_name(float_type)))
float_tensor.name, get_tf_type_name(float_type)))
# If found, validate that the operator output is quantized and compatible
# with the final model input type
quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
@ -683,17 +683,17 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
"Initial model input is not quantized. Expected type for "
"tensor with name '{}' should be in {}, instead type is {}".format(
quant_tensor.name,
tuple(_get_tf_type_name(t) for t in
tuple(get_tf_type_name(t) for t in
_MAP_QUANT_TO_IO_TYPES.keys()),
_get_tf_type_name(quant_type)))
get_tf_type_name(quant_type)))
else:
inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
if inference_input_type not in inference_io_types:
raise ValueError(
"Unsupported `inference_input_type` value. Expected to be in "
"{}, instead got {}.".format(
tuple(_get_tf_type_name(t) for t in inference_io_types),
_get_tf_type_name(inference_input_type)))
tuple(get_tf_type_name(t) for t in inference_io_types),
get_tf_type_name(inference_input_type)))
input_quant_ops.append(op)
if len(subgraph.inputs) != len(input_quant_ops):
@ -725,7 +725,7 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
else:
raise ValueError(
"Unsupported `inference_input_type` value {}.".format(
_get_tf_type_name(inference_input_type)))
get_tf_type_name(inference_input_type)))
def _modify_model_output_type(model, inference_output_type=dtypes.float32):
@ -768,7 +768,7 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
raise ValueError(
"Initial model output type must be tf.float32. Expected type for "
"tensor with name '{}' is tf.float32, instead type is {}".format(
float_tensor.name, _get_tf_type_name(float_type)))
float_tensor.name, get_tf_type_name(float_type)))
# If found, validate that the operator input is quantized and compatible
# with the final model output type
quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
@ -777,17 +777,17 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
"Initial model output is not dequantized. Expected type for "
"tensor with name '{}' should be in {}, instead type is {}".format(
quant_tensor.name,
tuple(_get_tf_type_name(t) for t in
tuple(get_tf_type_name(t) for t in
_MAP_QUANT_TO_IO_TYPES.keys()),
_get_tf_type_name(quant_type)))
get_tf_type_name(quant_type)))
else:
inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
if inference_output_type not in inference_io_types:
raise ValueError(
"Unsupported `inference_output_type` value. Expected to be in "
"{}, instead got {}.".format(
tuple(_get_tf_type_name(t) for t in inference_io_types),
_get_tf_type_name(inference_output_type)))
tuple(get_tf_type_name(t) for t in inference_io_types),
get_tf_type_name(inference_output_type)))
output_dequant_ops.append(op)
if len(subgraph.outputs) != len(output_dequant_ops):
@ -834,7 +834,7 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
else:
raise ValueError(
"Unsupported `inference_output_type` value {}.".format(
_get_tf_type_name(inference_output_type)))
get_tf_type_name(inference_output_type)))
def modify_model_io_type(


@ -26,7 +26,26 @@ from tensorflow.lite.testing.zip_test_utils import register_make_test_function
@register_make_test_function()
def make_cast_tests(options):
"""Generate examples for cast."""
test_parameters = [{
if options.use_experimental_converter:
test_parameters = [
{
"input_dtype": [tf.float32],
"output_dtype": [tf.int16],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.int16],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.int32],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}]
else:
test_parameters = [
{
"input_dtype": [tf.int32],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],


@ -763,6 +763,7 @@ py_library(
"//tensorflow/python/util:_pywrap_tfprof",
"//tensorflow/python/util:_pywrap_transform_graph",
"//tensorflow/python/util:_pywrap_util_port",
"//tensorflow/python/platform:_pywrap_tf2",
":_pywrap_utils",
":composite_tensor",
":config",
@ -5266,7 +5267,10 @@ pywrap_tensorflow_macro(
tf_additional_plugin_deps() +
tf_additional_profiler_deps()) + if_xla_available([
"//tensorflow/compiler/aot:tfcompile_lib",
]) + if_static(extra_deps = ["//tensorflow/core/platform:tensor_float_32_utils"]),
]) + if_static(extra_deps = [
"//tensorflow/core/platform:tensor_float_32_utils",
"//tensorflow/core/platform:enable_tf2_utils",
]),
)
# ** Targets for Windows build (start) **
@ -6779,6 +6783,8 @@ py_test(
":client_testlib",
":framework_combinations",
":tf2",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/data/kernel_tests:test_base",
],
)


@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export
# This value changes every day with an automatic CL. It can be modified in code
# via `forward_compatibility_horizon()` or with the environment variable
# TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2021, 1, 2)
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2021, 1, 5)
_FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
_FORWARD_COMPATIBILITY_DATE_NUMBER = None


@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python.compat import v2_compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.platform import _pywrap_tf2
from tensorflow.python.platform import test
@ -29,9 +30,13 @@ class DisableV2BehaviorTest(test.TestCase):
def test_basic(self):
t = constant_op.constant([1, 2, 3]) # creates a hidden context
self.assertTrue(isinstance(t, ops.EagerTensor))
t = _pywrap_tf2.is_enabled()
self.assertTrue(t)
v2_compat.disable_v2_behavior()
t = constant_op.constant([1, 2, 3])
self.assertFalse(isinstance(t, ops.EagerTensor))
t = _pywrap_tf2.is_enabled()
self.assertFalse(t)
if __name__ == '__main__':


@ -36,7 +36,10 @@ py_library(
py_library(
name = "trt_convert_py",
srcs = ["trt_convert.py"],
srcs = [
"trt_convert.py",
"utils.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/compiler/tf2tensorrt:_pywrap_py_utils",


@ -35,7 +35,9 @@ from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import get_linked_tensorrt
from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.compiler.tensorrt import utils as trt_utils
from tensorflow.python.eager import def_function
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import ops
@ -331,17 +333,21 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
"""Get config proto based on specific settings."""
conversion_params = self.GetConversionParams(run_params)
max_batch_size = self.GetMaxBatchSize(run_params)
if graph_state == GraphState.INFERENCE and run_params.convert_online:
rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
conversion_params,
is_dynamic_op=run_params.dynamic_engine,
max_batch_size=max_batch_size)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
max_batch_size=max_batch_size,
disable_non_trt_optimizers=self._disable_non_trt_optimizers)
else:
graph_options = config_pb2.GraphOptions()
rewriter_cfg = rewriter_config_pb2.RewriterConfig()
if self._disable_non_trt_optimizers:
trt_utils.disable_non_trt_optimizers_in_rewriter_config(rewriter_cfg)
config = config_pb2.ConfigProto(
gpu_options=self._GetGPUOptions(), graph_options=graph_options)
gpu_options=self._GetGPUOptions(),
graph_options=config_pb2.GraphOptions(rewrite_options=rewriter_cfg))
return config
def _GetFeedNames(self):


@ -30,6 +30,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.compiler.tensorrt import utils as trt_utils
from tensorflow.python.eager import context
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import convert_to_constants
@ -271,9 +272,12 @@ def _get_tensorrt_rewriter_config(conversion_params,
# need to run constant folding again.
rewriter_config_with_trt.optimizers.extend(
["constfold", "layout", "constfold"])
rewriter_config_with_trt.meta_optimizer_iterations = (
rewriter_config_pb2.RewriterConfig.ONE)
optimizer = rewriter_config_with_trt.custom_optimizers.add()
if not disable_non_trt_optimizers:
# Add a constfold optimizer to cleanup the unused Const nodes.
rewriter_config_with_trt.custom_optimizers.add().name = "constfold"
@ -295,25 +299,11 @@ def _get_tensorrt_rewriter_config(conversion_params,
optimizer.parameter_map["max_batch_size"].i = max_batch_size
optimizer.parameter_map["use_implicit_batch"].b = use_implicit_batch
# Disabling optimizers should happen after CopyFrom the template
# Disabling optimizers should happen after defining the TF-TRT grappler pass
# otherwise the template can overwrite the disablement.
if disable_non_trt_optimizers:
off = rewriter_config_pb2.RewriterConfig.OFF
rewriter_config_with_trt.layout_optimizer = off
rewriter_config_with_trt.constant_folding = off
rewriter_config_with_trt.shape_optimization = off
rewriter_config_with_trt.remapping = off
rewriter_config_with_trt.arithmetic_optimization = off
rewriter_config_with_trt.dependency_optimization = off
rewriter_config_with_trt.loop_optimization = off
rewriter_config_with_trt.function_optimization = off
rewriter_config_with_trt.debug_stripper = off
rewriter_config_with_trt.disable_model_pruning = True
rewriter_config_with_trt.scoped_allocator_optimization = off
rewriter_config_with_trt.memory_optimization = (
rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
rewriter_config_with_trt.pin_to_host_optimization = off
rewriter_config_with_trt.auto_parallel.enable = False
trt_utils.disable_non_trt_optimizers_in_rewriter_config(
rewriter_config_with_trt)
return rewriter_config_with_trt
@ -652,10 +642,19 @@ class TrtGraphConverter(object):
return_elements=fetch_names,
name="")
calibrate_rewriter_cfg = rewriter_config_pb2.RewriterConfig()
if self._test_only_disable_non_trt_optimizers:
trt_utils.disable_non_trt_optimizers_in_rewriter_config(
calibrate_rewriter_cfg)
# Set allow_soft_placement=True to run the graph for calibration so that
# OPs that are supported by TensorRT but lack a GPU implementation are allowed
# to execute on CPU.
calibrate_config = config_pb2.ConfigProto(allow_soft_placement=True)
calibrate_config = config_pb2.ConfigProto(
allow_soft_placement=True,
graph_options=config_pb2.GraphOptions(
rewrite_options=calibrate_rewriter_cfg))
with session.Session(
graph=self._calibration_graph,
config=calibrate_config) as calibration_sess:


@ -0,0 +1,47 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Exposes the Python wrapper conversion to trt_graph."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import rewriter_config_pb2
def disable_non_trt_optimizers_in_rewriter_config(rewriter_config):
"""Modifies rewriter_config to disable all non-TRT optimizations."""
off = rewriter_config_pb2.RewriterConfig.OFF
rewriter_config.arithmetic_optimization = off
rewriter_config.auto_mixed_precision = off
rewriter_config.auto_parallel.enable = False
rewriter_config.constant_folding = off
rewriter_config.debug_stripper = off
rewriter_config.dependency_optimization = off
# This one needs to be ON to allow TF-TRT
rewriter_config.disable_meta_optimizer = False
rewriter_config.disable_model_pruning = True
rewriter_config.function_optimization = off
rewriter_config.implementation_selector = off
rewriter_config.layout_optimizer = off
rewriter_config.loop_optimization = off
rewriter_config.memory_optimization = (
rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
rewriter_config.min_graph_nodes = -1
rewriter_config.pin_to_host_optimization = off
rewriter_config.remapping = off
rewriter_config.scoped_allocator_optimization = off
rewriter_config.shape_optimization = off
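A short usage sketch of the helper above, mirroring how the calibration path now builds its session config; it assumes only the protobuf modules already imported in this file.

```python
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2

# Build a ConfigProto whose Grappler pipeline keeps only what TF-TRT needs.
rewriter_cfg = rewriter_config_pb2.RewriterConfig()
disable_non_trt_optimizers_in_rewriter_config(rewriter_cfg)
config = config_pb2.ConfigProto(
    graph_options=config_pb2.GraphOptions(rewrite_options=rewriter_cfg))
```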


@ -442,7 +442,8 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
results = {}
for _ in range(elements_to_read):
val = next(it).numpy()
if val not in results: results[val] = 0
if val not in results:
results[val] = 0
results[val] += 1
for i in range(num_elements):
self.assertGreater(results[i], elements_to_read / num_elements / 2)
@ -527,6 +528,37 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
data_service_ops.distribute(
processing_mode="invalid", service="grpc://localhost:5000"))
@combinations.generate(test_base.eager_only_combinations())
def testZipDifferentProcessingModesDatasets(self):
cluster = self.create_cluster(num_workers=1)
num_elements = 100
ds1 = dataset_ops.Dataset.range(num_elements)
ds1 = self.make_distributed_dataset(
ds1, cluster, processing_mode="distributed_epoch")
ds2 = dataset_ops.Dataset.range(num_elements)
ds2 = self.make_distributed_dataset(
ds2, cluster, processing_mode="parallel_epochs")
ds = dataset_ops.Dataset.zip((ds1, ds2))
self.assertDatasetProduces(
ds,
list(zip(range(num_elements), range(num_elements))),
assert_items_equal=True)
@combinations.generate(test_base.eager_only_combinations())
def testZipDifferentProcessingModesDatasetsSharedJobName(self):
cluster = self.create_cluster(num_workers=1)
num_elements = 100
ds1 = dataset_ops.Dataset.range(num_elements)
ds1 = self.make_distributed_dataset(
ds1, cluster, processing_mode="distributed_epoch", job_name="job_name")
ds2 = dataset_ops.Dataset.range(num_elements)
ds2 = self.make_distributed_dataset(
ds2, cluster, processing_mode="parallel_epochs", job_name="job_name")
ds = dataset_ops.Dataset.zip((ds1, ds2))
with self.assertRaisesRegex(errors.FailedPreconditionError,
"but there is already an existing job"):
self.getDatasetOutput(ds)
@combinations.generate(test_base.eager_only_combinations())
def testFromDatasetId(self):
cluster = self.create_cluster(num_workers=1)


@ -88,7 +88,8 @@ class DispatchServer(object):
>>> dispatcher = tf.data.experimental.service.DispatchServer()
>>> dispatcher_address = dispatcher.target.split("://")[1]
>>> worker = tf.data.experimental.service.WorkerServer(WorkerConfig(
>>> worker = tf.data.experimental.service.WorkerServer(
... tf.data.experimental.service.WorkerConfig(
... dispatcher_address=dispatcher_address))
>>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(


@ -169,7 +169,7 @@ class _WorkerContext(object):
def _get_master_target(self):
"""Return the master target for a task."""
# If cluster_spec is None or empty, we use local master.
if not self._cluster_spec:
if not self._cluster_spec or self._task_type == _TaskType.EVALUATOR:
return ""
# If task_type is None, then it is in-graph replicated training. In this
@ -842,7 +842,8 @@ def run_distribute_coordinator(worker_fn,
session_config, cluster_spec,
task_type, task_id)
if not getattr(strategy.extended, "_std_server_started", False):
if (task_type != _TaskType.EVALUATOR and
not getattr(strategy.extended, "_std_server_started", False)):
# Right now, with eager mode, context is configured with a std server at
# the very beginning while with graph mode the std server is started when
# distribute coordinator is called. We should consolidate these two paths.


@ -589,8 +589,7 @@ class DistributeCoordinatorTestStandaloneMode(DistributeCoordinatorTestBase):
# and distributed_mode.
self.assertEqual(self._worker_context["None"][0], (_strip_protocol(
_bytes_to_str(self._workers[0].target)), 3, True, True))
self.assertEqual(self._worker_context[EVALUATOR][0],
("fake_evaluator", 3, True, False))
self.assertEqual(self._worker_context[EVALUATOR][0], ("", 3, True, False))
class DistributeCoordinatorTestIndependentWorkerMode(
@ -755,19 +754,15 @@ class DistributeCoordinatorTestIndependentWorkerMode(
# and distributed_mode.
self.assertEqual(self._worker_context["None"][0],
(_bytes_to_str(cluster_spec[WORKER][0]), 3, True, True))
self.assertEqual(self._worker_context[EVALUATOR][0],
(cluster_spec[EVALUATOR][0], 3, True, False))
self.assertEqual(self._worker_context[EVALUATOR][0], ("", 3, True, False))
# Make sure each worker runs a std server.
self.assertEqual(len(self._std_servers), 2)
self.assertEqual(len(self._std_servers), 1)
self.assertTrue(WORKER in self._std_servers)
self.assertTrue(EVALUATOR in self._std_servers)
self.assertEqual(len(self._std_servers[WORKER]), 3)
self.assertEqual(len(self._std_servers[EVALUATOR]), 1)
self.assertFalse(self._std_servers[WORKER][0].joined)
self.assertTrue(self._std_servers[WORKER][1].joined)
self.assertTrue(self._std_servers[WORKER][2].joined)
self.assertFalse(self._std_servers[EVALUATOR][0].joined)
def testRunStdServerInGoogleEnvironment(self):
cluster_spec = {"worker": ["fake_worker"], "ps": ["localhost:0"]}


@ -1007,6 +1007,7 @@ class GradientTape(object):
Raises:
RuntimeError: If called on a used, non-persistent tape.
RuntimeError: If called inside the context of the tape.
TypeError: If the target is None.
ValueError: If the target is a variable or if unconnected gradients is
called with an unknown value.
"""
@ -1028,6 +1029,11 @@ class GradientTape(object):
"gradient in order to compute higher order "
"derivatives.", 1)
if target is None:
raise TypeError("Target should be a list or nested structure"
" of Tensors or Variables to be differentiated,"
" but received %r" % (target))
num_ndarrays = 0
flat_targets = []
for t in nest.flatten(target):
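An illustrative, standalone snippet of the behaviour the new check introduces; it is a sketch, not code from this change.

```python
import tensorflow as tf

x = tf.Variable(3.0)
with tf.GradientTape() as tape:
  y = x * x

# With this change, a None target fails fast with a clear TypeError instead
# of a more confusing error further down the gradient computation.
try:
  tape.gradient(None, x)
except TypeError as e:
  print(e)
```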


@ -1000,38 +1000,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3)
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
def test_experimental_get_tracing_count_function(self):
@def_function.function
def double(a):
return a + a
double(constant_op.constant(1))
double(constant_op.constant(2))
self.assertAllEqual(double.experimental_get_tracing_count(), 1)
double(constant_op.constant('a'))
self.assertAllEqual(double.experimental_get_tracing_count(), 2)
def test_experimental_get_tracing_count_method(self):
class TestClass():
@def_function.function
def testDouble(self, a):
return a + a
obj1 = TestClass()
obj1.testDouble(constant_op.constant(1))
obj1.testDouble(constant_op.constant(2))
obj1.testDouble(constant_op.constant(1.1))
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
obj2 = TestClass()
obj2.testDouble(constant_op.constant(1))
obj2.testDouble(constant_op.constant(1.1))
obj2.testDouble(constant_op.constant('a'))
self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3)
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
if __name__ == '__main__':
ops.enable_eager_execution()


@ -2047,6 +2047,10 @@ class TensorFlowTestCase(googletest.TestCase):
self._tempdir = None
self._cached_session = None
self._test_start_time = None
# This flag provides the ability to control whether the graph mode gets
# initialized for TF1 or not. Initializing for TF1, which is what was
# happening earlier, was preventing enablement of 'eager mode' in the test.
self._set_default_seed = True
def setUp(self):
super(TensorFlowTestCase, self).setUp()
@ -2061,6 +2065,7 @@ class TensorFlowTestCase(googletest.TestCase):
# cleared first.
ops._default_graph_stack.reset() # pylint: disable=protected-access
ops.reset_default_graph()
if self._set_default_seed:
random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
# Reset summary writer in case another test used set_as_default() with their
# summary writer.


@ -18,68 +18,73 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl.testing import parameterized
from tensorflow.python import tf2
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import combinations
from tensorflow.python.platform import _pywrap_tf2
from tensorflow.python.platform import test
def set_environ():
os.environ['TF2_BEHAVIOR'] = '1'
def unset_environ():
os.environ['TF2_BEHAVIOR'] = '0'
class EnablingTF2Behavior(test.TestCase, parameterized.TestCase):
def setUp(self):
super(EnablingTF2Behavior, self).setUp()
tf2._force_enable = None
if 'TF2_BEHAVIOR' in os.environ:
del os.environ['TF2_BEHAVIOR']
def __init__(self, methodName):
super().__init__(methodName)
self._set_default_seed = False
actions = [tf2.enable, tf2.disable, set_environ, unset_environ]
@combinations.generate(test_base.v1_only_combinations())
def test_tf1_enable_tf2_behaviour(self):
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())
@combinations.generate(
combinations.combine(
action_0=actions, action_1=actions,
action_2=actions, action_3=actions))
def test_scenarios(self, action_0, action_1, action_2, action_3):
v2_compat.enable_v2_behavior()
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())
def state(action, enabled, disabled):
"""Returns bool tuple (tf2_enabled, force_enabled, force_disabled)."""
if action is tf2.enable:
return True, True, False
elif action is tf2.disable:
return False, False, True
elif action is set_environ:
return not disabled, enabled, disabled
elif action is unset_environ:
return enabled, enabled, disabled
else:
raise ValueError('Unexpected action {}. {} are supported'.format(
action, EnablingTF2Behavior.actions))
v2_compat.disable_v2_behavior()
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())
action_0()
expected, enabled, disabled = state(action_0, False, False)
self.assertEqual(tf2.enabled(), expected)
@combinations.generate(test_base.v1_only_combinations())
def test_tf1_disable_tf2_behaviour(self):
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())
action_1()
expected, enabled, disabled = state(action_1, enabled, disabled)
self.assertEqual(tf2.enabled(), expected)
v2_compat.disable_v2_behavior()
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())
action_2()
expected, enabled, disabled = state(action_2, enabled, disabled)
self.assertEqual(tf2.enabled(), expected)
v2_compat.enable_v2_behavior()
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())
action_3()
expected, enabled, disabled = state(action_3, enabled, disabled)
self.assertEqual(tf2.enabled(), expected)
@combinations.generate(test_base.v2_only_combinations())
def test_tf2_enable_tf2_behaviour(self):
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())
v2_compat.enable_v2_behavior()
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())
v2_compat.disable_v2_behavior()
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())
@combinations.generate(test_base.v2_only_combinations())
def test_tf2_disable_tf2_behaviour(self):
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())
v2_compat.disable_v2_behavior()
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())
v2_compat.enable_v2_behavior()
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())
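The public-API equivalent of what these tests assert, as a hedged sketch run in a fresh process; the internal `_pywrap_tf2` module is replaced here by the observable eager-execution state.

```python
import tensorflow as tf

# Disabling v2 behavior switches the runtime to TF1 semantics (graph mode);
# re-enabling it restores eager execution, as the tests above toggle.
tf.compat.v1.disable_v2_behavior()
print(tf.executing_eagerly())  # False

tf.compat.v1.enable_v2_behavior()
print(tf.executing_eagerly())  # True
```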
if __name__ == '__main__':


@ -41,6 +41,7 @@ def index_directory(directory,
directory: The target directory (string).
labels: Either "inferred"
(labels are generated from the directory structure),
None (no labels),
or a list/tuple of integer labels of the same size as the number of
valid files found in the directory. Labels should be sorted according
to the alphanumeric order of the image file paths
@ -61,19 +62,24 @@ def index_directory(directory,
labels: list of matching integer labels (same length as file_paths)
class_names: names of the classes corresponding to these labels, in order.
"""
inferred_class_names = []
if labels is None:
# in the no-label case, index from the parent directory down.
subdirs = ['']
class_names = subdirs
else:
subdirs = []
for subdir in sorted(os.listdir(directory)):
if os.path.isdir(os.path.join(directory, subdir)):
inferred_class_names.append(subdir)
subdirs.append(subdir)
if not class_names:
class_names = inferred_class_names
class_names = subdirs
else:
if set(class_names) != set(inferred_class_names):
if set(class_names) != set(subdirs):
raise ValueError(
'The `class_names` passed did not match the '
'names of the subdirectories of the target directory. '
'Expected: %s, but received: %s' %
(inferred_class_names, class_names))
(subdirs, class_names))
class_indices = dict(zip(class_names, range(len(class_names))))
# Build an index of the files
@ -81,7 +87,8 @@ def index_directory(directory,
pool = multiprocessing.pool.ThreadPool()
results = []
filenames = []
for dirpath in (os.path.join(directory, subdir) for subdir in class_names):
for dirpath in (os.path.join(directory, subdir) for subdir in subdirs):
results.append(
pool.apply_async(index_subdirectory,
(dirpath, class_indices, follow_links, formats)))
@ -90,7 +97,7 @@ def index_directory(directory,
partial_filenames, partial_labels = res.get()
labels_list.append(partial_labels)
filenames += partial_filenames
if labels != 'inferred':
if labels not in ('inferred', None):
if len(labels) != len(filenames):
raise ValueError('Expected the lengths of `labels` to match the number '
'of files in the target directory. len(labels) is %s '
@ -103,6 +110,9 @@ def index_directory(directory,
labels[i:i + len(partial_labels)] = partial_labels
i += len(partial_labels)
if labels is None:
print('Found %d files.' % (len(filenames),))
else:
print('Found %d files belonging to %d classes.' %
(len(filenames), len(class_names)))
pool.close()


@ -74,6 +74,7 @@ def image_dataset_from_directory(directory,
Otherwise, the directory structure is ignored.
labels: Either "inferred"
(labels are generated from the directory structure),
None (no labels),
or a list/tuple of integer labels of the same size as the number of
image files found in the directory. Labels should be sorted according
to the alphanumeric order of the image file paths
@ -139,7 +140,7 @@ def image_dataset_from_directory(directory,
- if `color_mode` is `rgba`,
there are 4 channels in the image tensors.
"""
if labels != 'inferred':
if labels not in ('inferred', None):
if not isinstance(labels, (list, tuple)):
raise ValueError(
'`labels` argument should be a list/tuple of integer labels, of '
@ -156,6 +157,9 @@ def image_dataset_from_directory(directory,
raise ValueError(
'`label_mode` argument must be one of "int", "categorical", "binary", '
'or None. Received: %s' % (label_mode,))
if labels is None or label_mode is None:
labels = None
label_mode = None
if color_mode == 'rgb':
num_channels = 3
elif color_mode == 'rgba':
@ -188,6 +192,8 @@ def image_dataset_from_directory(directory,
image_paths, labels = dataset_utils.get_training_or_validation_split(
image_paths, labels, validation_split, subset)
if not image_paths:
raise ValueError('No images found.')
dataset = paths_and_labels_to_dataset(
image_paths=image_paths,


@ -82,7 +82,7 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
# Save images to the paths
i = 0
for img in self._get_images(color_mode=color_mode, count=count):
path = paths[count % len(paths)]
path = paths[i % len(paths)]
if color_mode == 'rgb':
ext = 'jpg'
else:
@ -92,6 +92,32 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
i += 1
return temp_dir
def test_image_dataset_from_directory_standalone(self):
# Test retrieving images without labels from a directory and its subdirs.
if PIL is None:
return # Skip test if PIL is not available.
# Save a few extra images in the parent directory.
directory = self._prepare_directory(count=7, num_classes=2)
for i, img in enumerate(self._get_images(3)):
filename = 'image_%s.jpg' % (i,)
img.save(os.path.join(directory, filename))
dataset = image_dataset.image_dataset_from_directory(
directory, batch_size=5, image_size=(18, 18), labels=None)
batch = next(iter(dataset))
# We return plain images
self.assertEqual(batch.shape, (5, 18, 18, 3))
self.assertEqual(batch.dtype.name, 'float32')
# Count samples
batch_count = 0
sample_count = 0
for batch in dataset:
batch_count += 1
sample_count += batch.shape[0]
self.assertEqual(batch_count, 2)
self.assertEqual(sample_count, 10)
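For reference, a hedged sketch of the new no-label mode exercised above; `"images_dir"` is an assumed directory of images, not a path from this change.

```python
import tensorflow as tf

# With labels=None the directory (and its subdirectories) is indexed without
# classes, and the dataset yields plain image batches.
dataset = tf.keras.preprocessing.image_dataset_from_directory(
    "images_dir",
    labels=None,
    image_size=(64, 64),
    batch_size=8)
for images in dataset.take(1):
  print(images.shape)  # e.g. (8, 64, 64, 3)
```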
def test_image_dataset_from_directory_binary(self):
if PIL is None:
return # Skip test if PIL is not available.
@ -253,6 +279,11 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
sample_count += batch.shape[0]
self.assertEqual(sample_count, 25)
def test_image_dataset_from_directory_no_images(self):
directory = self._prepare_directory(num_classes=2, count=0)
with self.assertRaisesRegex(ValueError, 'No images found.'):
_ = image_dataset.image_dataset_from_directory(directory)
def test_image_dataset_from_directory_errors(self):
if PIL is None:
return # Skip test if PIL is not available.
@ -261,7 +292,7 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
with self.assertRaisesRegex(ValueError, '`labels` argument should be'):
_ = image_dataset.image_dataset_from_directory(
directory, labels=None)
directory, labels='other')
with self.assertRaisesRegex(ValueError, '`label_mode` argument must be'):
_ = image_dataset.image_dataset_from_directory(


@ -66,6 +66,7 @@ def text_dataset_from_directory(directory,
Otherwise, the directory structure is ignored.
labels: Either "inferred"
(labels are generated from the directory structure),
None (no labels),
or a list/tuple of integer labels of the same size as the number of
text files found in the directory. Labels should be sorted according
to the alphanumeric order of the text file paths
@ -114,7 +115,7 @@ def text_dataset_from_directory(directory,
of shape `(batch_size, num_classes)`, representing a one-hot
encoding of the class index.
"""
if labels != 'inferred':
if labels not in ('inferred', None):
if not isinstance(labels, (list, tuple)):
raise ValueError(
'`labels` argument should be a list/tuple of integer labels, of '
@ -131,6 +132,9 @@ def text_dataset_from_directory(directory,
raise ValueError(
'`label_mode` argument must be one of "int", "categorical", "binary", '
'or None. Received: %s' % (label_mode,))
if labels is None or label_mode is None:
labels = None
label_mode = None
dataset_utils.check_validation_split_arg(
validation_split, subset, shuffle, seed)
@ -152,6 +156,8 @@ def text_dataset_from_directory(directory,
file_paths, labels = dataset_utils.get_training_or_validation_split(
file_paths, labels, validation_split, subset)
if not file_paths:
raise ValueError('No text files found.')
dataset = paths_and_labels_to_dataset(
file_paths=file_paths,


@ -58,7 +58,7 @@ class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):
paths += class_paths
for i in range(count):
path = paths[count % len(paths)]
path = paths[i % len(paths)]
filename = os.path.join(path, 'text_%s.txt' % (i,))
f = open(os.path.join(temp_dir, filename), 'w')
text = ''.join([random.choice(string.printable) for _ in range(length)])
@ -66,6 +66,32 @@ class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):
f.close()
return temp_dir
def test_text_dataset_from_directory_standalone(self):
# Test retrieving txt files without labels from a directory and its subdirs.
# Save a few extra files in the parent directory.
directory = self._prepare_directory(count=7, num_classes=2)
for i in range(3):
filename = 'text_%s.txt' % (i,)
f = open(os.path.join(directory, filename), 'w')
text = ''.join([random.choice(string.printable) for _ in range(20)])
f.write(text)
f.close()
dataset = text_dataset.text_dataset_from_directory(
directory, batch_size=5, label_mode=None, max_length=10)
batch = next(iter(dataset))
# We just return the texts, no labels
self.assertEqual(batch.shape, (5,))
self.assertEqual(batch.dtype.name, 'string')
# Count samples
batch_count = 0
sample_count = 0
for batch in dataset:
batch_count += 1
sample_count += batch.shape[0]
self.assertEqual(batch_count, 2)
self.assertEqual(sample_count, 10)
def test_text_dataset_from_directory_binary(self):
directory = self._prepare_directory(num_classes=2)
dataset = text_dataset.text_dataset_from_directory(
@ -172,12 +198,17 @@ class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):
sample_count += batch.shape[0]
self.assertEqual(sample_count, 25)
def test_text_dataset_from_directory_no_files(self):
directory = self._prepare_directory(num_classes=2, count=0)
with self.assertRaisesRegex(ValueError, 'No text files found.'):
_ = text_dataset.text_dataset_from_directory(directory)
def test_text_dataset_from_directory_errors(self):
directory = self._prepare_directory(num_classes=3, count=5)
with self.assertRaisesRegex(ValueError, '`labels` argument should be'):
_ = text_dataset.text_dataset_from_directory(
directory, labels=None)
directory, labels='other')
with self.assertRaisesRegex(ValueError, '`label_mode` argument must be'):
_ = text_dataset.text_dataset_from_directory(


@ -3246,10 +3246,12 @@ cuda_py_test(
srcs = ["extract_volume_patches_grad_test.py"],
shard_count = 50,
tags = [
"no_gpu", # b/171837334
"no_oss", # Test times out on oss-nightly cpu builds
"no_pip",
"nogpu", # http://b/171837334
"nomac", # http://b/139946976
"notap", # http://b/31080670
"nogpu", # b/171837334
"nomac", # b/139946976
"notap", # b/31080670
],
deps = [
"//tensorflow/python:array_ops",


@ -1557,6 +1557,14 @@ class AssertTypeTest(test.TestCase):
with self.assertRaisesRegexp(TypeError, "must be of type.*float32"):
check_ops.assert_type(sparse_float16, dtypes.float32)
def test_raise_when_tf_type_is_not_dtype(self):
# Test case for GitHub issue:
# https://github.com/tensorflow/tensorflow/issues/45975
value = constant_op.constant(0.0)
with self.assertRaisesRegexp(TypeError,
"Cannot convert.*to a TensorFlow DType"):
check_ops.assert_type(value, (dtypes.float32,))
class AssertShapesTest(test.TestCase):

Some files were not shown because too many files have changed in this diff.