Merge branch 'master'
Merge branch 'master' of https://github.com/tensorflow/tensorflow into feature-micro-add-op-depth-to-space-pr1
commit c20ac67cb1
@ -114,6 +114,143 @@ This release contains contributions from many people at Google, as well as:

<INSERT>, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>

# Release 2.3.2

## Bug Fixes and Other Changes
* Fixes an access to uninitialized memory in Eigen code
  ([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266))
* Fixes a security vulnerability caused by lack of validation in
  `tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap`
  ([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
* Fixes a vulnerability caused by attempting to write to immutable memory region in
  `tf.raw_ops.ImmutableConst`
  ([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
* Fixes a `CHECK`-fail in LSTM with zero-length input
  ([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270))
* Fixes a security vulnerability caused by accessing heap data outside of bounds
  when loading a specially crafted `SavedModel`
  ([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271))
* Solves an OOM issue on TPUs when XLA contexts use fused average updates
* Updates `libjpeg-turbo` to `2.0.5` to handle
  [CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790).
* Updates `junit` to `4.13.1` to handle
  [CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250).
* Updates `PCRE` to `8.44` to handle
  [CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838)
  and
  [CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155).
* Updates `sqlite3` to `3.44.0` to keep in sync with master branch.

# Release 2.2.2

## Bug Fixes and Other Changes
* Fixes an access to uninitialized memory in Eigen code
  ([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266))
* Fixes a security vulnerability caused by lack of validation in
  `tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap`
  ([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
* Fixes a vulnerability caused by attempting to write to immutable memory region in
  `tf.raw_ops.ImmutableConst`
  ([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
* Fixes a `CHECK`-fail in LSTM with zero-length input
  ([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270))
* Fixes a security vulnerability caused by accessing heap data outside of bounds
  when loading a specially crafted `SavedModel`
  ([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271))
* Prevents memory leaks in loading `SavedModel`s that import functions
* Updates `libjpeg-turbo` to `2.0.5` to handle
  [CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790).
* Updates `junit` to `4.13.1` to handle
  [CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250).
* Updates `PCRE` to `8.44` to handle
  [CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838)
  and
  [CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155).
* Updates `sqlite3` to `3.44.0` to keep in sync with master branch.

# Release 2.1.3

## Bug Fixes and Other Changes
* Fixes an access to uninitialized memory in Eigen code
  ([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266))
* Fixes a security vulnerability caused by lack of validation in
  `tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap`
  ([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
* Fixes a vulnerability caused by attempting to write to immutable memory region in
  `tf.raw_ops.ImmutableConst`
  ([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
* Fixes a `CHECK`-fail in LSTM with zero-length input
  ([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270))
* Fixes a security vulnerability caused by accessing heap data outside of bounds
  when loading a specially crafted `SavedModel`
  ([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271))
* Updates `libjpeg-turbo` to `2.0.5` to handle
  [CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790).
* Updates `junit` to `4.13.1` to handle
  [CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250).
* Updates `PCRE` to `8.44` to handle
  [CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838)
  and
  [CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155).
* Updates `sqlite3` to `3.44.0` to keep in sync with master branch.
* Newer ROCm versions are supported on the 2.1 branch.

# Release 2.0.4

Note that this is the last patch release for the TensorFlow 2.0.x series.

## Bug Fixes and Other Changes
* Fixes an access to uninitialized memory in Eigen code
  ([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266))
* Fixes a security vulnerability caused by lack of validation in
  `tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap`
  ([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
* Fixes a vulnerability caused by attempting to write to immutable memory region in
  `tf.raw_ops.ImmutableConst`
  ([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
* Fixes a `CHECK`-fail in LSTM with zero-length input
  ([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270))
* Fixes a security vulnerability caused by accessing heap data outside of bounds
  when loading a specially crafted `SavedModel`
  ([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271))
* Updates `libjpeg-turbo` to `2.0.5` to handle
  [CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790).
* Updates `junit` to `4.13.1` to handle
  [CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250).
* Updates `PCRE` to `8.44` to handle
  [CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838)
  and
  [CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155).
* Updates `sqlite3` to `3.44.0` to keep in sync with master branch.

# Release 1.15.5

Note that this is the last patch release for the TensorFlow 1.x series.

## Bug Fixes and Other Changes
* Fixes an access to uninitialized memory in Eigen code
  ([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266))
* Fixes a security vulnerability caused by lack of validation in
  `tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap`
  ([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
* Fixes a vulnerability caused by attempting to write to immutable memory region in
  `tf.raw_ops.ImmutableConst`
  ([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
* Fixes a `CHECK`-fail in LSTM with zero-length input
  ([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270))
* Fixes a security vulnerability caused by accessing heap data outside of bounds
  when loading a specially crafted `SavedModel`
  ([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271))
* Updates `libjpeg-turbo` to `2.0.5` to handle
  [CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790).
* Updates `junit` to `4.13.1` to handle
  [CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250).
* Updates `PCRE` to `8.44` to handle
  [CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838)
  and
  [CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155).
* Updates `sqlite3` to `3.44.0` to keep in sync with master branch.

# Release 2.4.0

## Major Features and Improvements
@ -163,7 +300,7 @@ This release contains contributions from many people at Google, as well as:

## Breaking Changes

* TF Core:
  * Certain float32 ops run in lower precsion on Ampere based GPUs, including
  * Certain float32 ops run in lower precision on Ampere based GPUs, including
    matmuls and convolutions, due to the use of [TensorFloat-32]
    (https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/).
    Specifically, inputs to such ops are rounded from 23 bits of precision to 10
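For readers affected by the TensorFloat-32 breaking change noted above, a minimal, hedged sketch of the opt-out path follows. It assumes the `tf.config.experimental` TF32 toggles that ship with the 2.4 release; it is illustrative only and is not part of this commit's diff.

```python
# Illustrative sketch (not part of this commit): restoring full float32
# precision on Ampere GPUs after the TF 2.4 TensorFloat-32 change.
import tensorflow as tf

# TF32 execution is on by default in TF 2.4 on Ampere; check the current setting.
print(tf.config.experimental.tensor_float_32_execution_enabled())

# Opt out so float32 matmuls/convolutions keep the full 23-bit mantissa.
tf.config.experimental.enable_tensor_float_32_execution(False)
```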
@ -1282,7 +1282,8 @@ class DynamicReshapeOpNotActuallyDynamic
|
||||
void DynamicReshapeOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList& results, MLIRContext* context) {
|
||||
results.insert<DynamicReshapeOpNotActuallyDynamic,
|
||||
RemoveRedundantDynamicReshape, ShapeOfDynamicReshape>(context);
|
||||
RemoveRedundantDynamicBroadcast, RemoveRedundantDynamicReshape,
|
||||
ShapeOfDynamicReshape>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -33,3 +33,10 @@ def UnaryEinsumToEinsum : Pat<
|
||||
def RemoveRedundantDynamicReshape : Pat<
|
||||
(HLO_DynamicReshapeOp (HLO_DynamicReshapeOp $operand, $shape1), $shape2),
|
||||
(HLO_DynamicReshapeOp $operand, $shape2)>;
|
||||
|
||||
// A dynamic broadcast of a dynamic reshape with the same shape operand
|
||||
// is a dynamic reshape.
|
||||
def RemoveRedundantDynamicBroadcast : Pat<
|
||||
(HLO_DynamicBroadcastInDimOp
|
||||
(HLO_DynamicReshapeOp $operand, $shape), $shape, $dims),
|
||||
(HLO_DynamicReshapeOp $operand, $shape)>;
|
||||
|
@ -1540,3 +1540,14 @@ func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf3
|
||||
return %1 : tensor<128xf32>
|
||||
// CHECK: return %arg0 : tensor<128xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @broadcast_of_reshape
|
||||
func @broadcast_of_reshape(%arg: tensor<?xf32>, %shape: tensor<2xindex>) -> tensor<?x?xf32> {
|
||||
%0 = "mhlo.dynamic_reshape"(%arg, %shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
%1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) {
|
||||
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
||||
} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
return %1 : tensor<?x?xf32>
|
||||
}
|
||||
// CHECK: [[RESHAPE:%.*]] = "mhlo.dynamic_reshape"
|
||||
// CHECK: return [[RESHAPE]]
|
||||
|
@ -133,7 +133,7 @@ class TFL_OperandsHaveSameShapesOrBroadcastableShape<
|
||||
TFL_RuntimePredOpTrait<"operands do not have the same shape or "
|
||||
"broadcastable shapes within the rank " # max_bcast_rank,
|
||||
CPred<"TFL::VerifyOperandsHaveSameShapesOrBroadcastableShape("
|
||||
"$_op, llvm::ArrayRef<unsigned>({" # StrJoinInt<indices>.result #
|
||||
"$_op, llvm::ArrayRef<unsigned>({" # !interleave(indices, ", ") #
|
||||
"}), " # max_bcast_rank # ")">>;
|
||||
|
||||
// These additional types/type constraints here are used to decouple the ops
|
||||
@ -3453,10 +3453,10 @@ def TFL_CastOp : TFL_Op<"cast", [
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$input
|
||||
TFL_TensorOf<[F32, I1, I16, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$input
|
||||
);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$output);
|
||||
let results = (outs TFL_TensorOf<[F32, I1, I16, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$output);
|
||||
|
||||
// TFLite's cast op does not utilize CastOptions, instead derives types
|
||||
// from the TfLiteTensors.
|
||||
|
@ -34,7 +34,7 @@ class QuantizedType<string n, list<int> params, bit signed>
|
||||
"Q" # !if (signed, "I", "UI") # !head(params) # " type"> {
|
||||
string name = n;
|
||||
string asTraitArgsStr =
|
||||
StrJoinInt<params>.result # !if(signed, ", true", ", false");
|
||||
!interleave(params, ", ") # !if(signed, ", true", ", false");
|
||||
}
|
||||
|
||||
// Uniform quantized types. Two integers "smantissa" and "sexp" are used to
|
||||
@ -134,7 +134,7 @@ class FixedResultScale<QuantizedType qt> : NativeOpTrait<!strconcat(
|
||||
// needs a scale based on the scales of op1 and op2.
|
||||
class AccumulatorUniformScale<int bias, int op1, int op2> : NativeOpTrait<
|
||||
!strconcat("quant::AccumulatorUniformScale<",
|
||||
StrJoinInt<[bias, op1, op2]>.result,
|
||||
!interleave([bias, op1, op2], ", "),
|
||||
">::Impl")>;
|
||||
|
||||
// Specify the operand index of the coefficient operand for an affine op
|
||||
@ -142,7 +142,7 @@ class AccumulatorUniformScale<int bias, int op1, int op2> : NativeOpTrait<
|
||||
// If the quantization dimension is -1, per-axis quantization isn't supported.
|
||||
class AffineOpCoefficient<int dim, int index> : NativeOpTrait<
|
||||
!strconcat("quant::AffineOpCoefficient<",
|
||||
StrJoinInt<[dim, index]>.result,
|
||||
!interleave([dim, index], ", "),
|
||||
">::Impl")>;
|
||||
|
||||
// Specify this trait if the op doesn't have quantizable output. We shouldn't
|
||||
|
@ -1320,6 +1320,22 @@ func @cast(%arg0: tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32> {
|
||||
// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32>
|
||||
}
|
||||
|
||||
func @castFloat32ToI16(%arg0: tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xi16> {
|
||||
%0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xi16>
|
||||
return %0 : tensor<1x2x2x5xi16>
|
||||
|
||||
// CHECK-LABEL: castFloat32ToI16
|
||||
// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xi16>
|
||||
}
|
||||
|
||||
func @castI16ToFloat32(%arg0: tensor<1x2x2x5xi16>) -> tensor<1x2x2x5xf32> {
|
||||
%0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xi16>) -> tensor<1x2x2x5xf32>
|
||||
return %0 : tensor<1x2x2x5xf32>
|
||||
|
||||
// CHECK-LABEL: castI16ToFloat32
|
||||
// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi16>) -> tensor<1x2x2x5xf32>
|
||||
}
|
||||
|
||||
func @castComplex(%arg0: tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xcomplex<f32>> {
|
||||
%0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xcomplex<f32>>
|
||||
return %0 : tensor<1x2x2x5xcomplex<f32>>
|
||||
|
@ -98,11 +98,13 @@ class TF_OpIsBroadcastableToRes<int opId, int resId> : And<[
|
||||
|
||||
|
||||
class TF_AllTypesMatchPred<list<string> values> :
|
||||
CPred<"TF::AreCastCompatible(llvm::makeArrayRef({"# StrJoin<values>.result #"}))">;
|
||||
CPred<"TF::AreCastCompatible(llvm::makeArrayRef({" #
|
||||
!interleave(values, ", ") # "}))">;
|
||||
|
||||
class TF_AllTypesMatch<list<string> names> :
|
||||
PredOpTrait<
|
||||
"all of {" # StrJoin<names>.result # "} have dynamically equal types ",
|
||||
"all of {" # !interleave(names, ", ") #
|
||||
"} have dynamically equal types ",
|
||||
TF_AllTypesMatchPred<
|
||||
!foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;
|
||||
|
||||
|
@ -147,6 +147,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo_gpu",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service:backend",
|
||||
|
@ -384,7 +384,7 @@ ENTRY main {
|
||||
HloModule BatchNormForwardInference
|
||||
|
||||
// CHECK: func @main
|
||||
// CHECK: lmhlo_gpu.batch_norm_inference"
|
||||
// CHECK: "lmhlo_gpu.batch_norm_inference"
|
||||
// CHECK-SAME: epsilon = 1.000000e-03 : f32
|
||||
// CHECK-SAME: feature_index = 0 : i64
|
||||
// CHECK-SAME: (memref<2x2x2x2xf32>, memref<2xf32>, memref<2xf32>, memref<2xf32>, memref<2xf32>, memref<2x2x2x2xf32>) -> ()
|
||||
@ -400,3 +400,15 @@ ENTRY main {
|
||||
custom-call(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[] %constant, s64[] %constant_1),
|
||||
custom_call_target="__cudnn$batchNormalizationForwardInference"
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
HloModule Infeed
|
||||
|
||||
// CHECK: func @main
|
||||
// CHECK: "lmhlo.infeed"
|
||||
// CHECK-SAME: (memref<3xf32>) -> ()
|
||||
ENTRY main {
|
||||
%tok = token[] parameter(0)
|
||||
ROOT %infeed = (f32[3]{0}, token[]) infeed(token[] %tok)
|
||||
}
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "mlir/IR/AffineMap.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
@ -60,6 +61,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
||||
@ -67,6 +69,7 @@ using xla::BufferAllocation;
|
||||
using xla::BufferAssignment;
|
||||
using xla::HloComputation;
|
||||
using xla::HloCustomCallInstruction;
|
||||
using xla::HloInfeedInstruction;
|
||||
using xla::HloInstruction;
|
||||
using xla::HloModule;
|
||||
using xla::HloModuleProto;
|
||||
@ -199,14 +202,16 @@ class XlaHloToLhloPass
|
||||
} // namespace
|
||||
|
||||
// Creates MLIR operands corresponding to operands and results of the XLA HLO
|
||||
// instruction. If `num_operands` is not -1, then only the first `num_operands`
|
||||
// instruction. If `num_operands` is valid, then only the first `num_operands`
|
||||
// operands of the HLO instruction will be considered.
|
||||
Status LhloDialectEmitter::CreateOperands(
|
||||
HloInstruction* instr, llvm::SmallVectorImpl<Value>& operands,
|
||||
size_t& num_arguments, size_t& num_results,
|
||||
absl::optional<xla::int64> num_operands) {
|
||||
HloInstruction* instr, absl::optional<xla::int64> num_operands,
|
||||
llvm::SmallVectorImpl<Value>& operands, size_t& num_arguments,
|
||||
size_t& num_results) {
|
||||
if (num_operands.value_or(0) > instr->operand_count())
|
||||
return xla::InvalidArgument("num_operands must be <= operand count");
|
||||
for (xla::int64 i = 0; i < num_operands.value_or(instr->operand_count());
|
||||
i++) {
|
||||
++i) {
|
||||
TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(i), &operands));
|
||||
}
|
||||
num_arguments = operands.size();
|
||||
@ -215,19 +220,23 @@ Status LhloDialectEmitter::CreateOperands(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename OpType>
|
||||
OpType LhloDialectEmitter::CreateOpWithoutAttrs(HloInstruction* instr,
|
||||
ValueRange operands) {
|
||||
Location loc = getLocation(instr);
|
||||
NamedAttribute attrs[] = {{Identifier::get("name", builder_.getContext()),
|
||||
builder_.getStringAttr(instr->name())}};
|
||||
return builder_.create<OpType>(loc, llvm::None, operands, attrs);
|
||||
}
|
||||
|
||||
template <typename OpType>
|
||||
StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(
|
||||
HloInstruction* instr, size_t& num_arguments, size_t& num_results,
|
||||
absl::optional<xla::int64> num_operands) {
|
||||
Location loc = getLocation(instr);
|
||||
std::pair<Identifier, Attribute> attrs[] = {
|
||||
{Identifier::get("name", builder_.getContext()),
|
||||
builder_.getStringAttr(instr->name())},
|
||||
};
|
||||
llvm::SmallVector<Value, 4> operands;
|
||||
TF_RETURN_IF_ERROR(CreateOperands(instr, operands, num_arguments, num_results,
|
||||
num_operands));
|
||||
return builder_.create<OpType>(loc, llvm::None, operands, attrs);
|
||||
TF_RETURN_IF_ERROR(CreateOperands(instr, num_operands, operands,
|
||||
num_arguments, num_results));
|
||||
return CreateOpWithoutAttrs<OpType>(instr, operands);
|
||||
}
|
||||
|
||||
StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
|
||||
@ -273,6 +282,8 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
|
||||
return CreateOpWithoutAttrs<lmhlo::FloorOp>(instr);
|
||||
case HloOpcode::kImag:
|
||||
return CreateOpWithoutAttrs<lmhlo::ImagOp>(instr);
|
||||
case HloOpcode::kInfeed:
|
||||
return EmitInfeedOp(instr);
|
||||
case HloOpcode::kIsFinite:
|
||||
return CreateOpWithoutAttrs<lmhlo::IsFiniteOp>(instr);
|
||||
case HloOpcode::kLog:
|
||||
@ -387,7 +398,7 @@ StatusOr<Value> LhloDialectEmitter::RewriteFusionOperand(
|
||||
::xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) {
|
||||
if (shape.IsTuple()) {
|
||||
llvm::SmallVector<Value, 4> values;
|
||||
for (int i = 0; i < shape.tuple_shapes_size(); i++) {
|
||||
for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
|
||||
shape_index->push_back(i);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index,
|
||||
@ -423,7 +434,7 @@ StatusOr<lmhlo::FusionOp> LhloDialectEmitter::EmitFusionOp(
|
||||
auto region_builder = OpBuilder::atBlockBegin(&fusion.region().front());
|
||||
|
||||
llvm::SmallVector<Value, 8> arguments;
|
||||
for (int i = 0; i < instr->operands().size(); i++) {
|
||||
for (int i = 0; i < instr->operands().size(); ++i) {
|
||||
const HloInstruction* operand = instr->operand(i);
|
||||
xla::ShapeIndex shape_index;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
@ -982,6 +993,19 @@ StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
|
||||
return all_reduce_op;
|
||||
}
|
||||
|
||||
StatusOr<lmhlo::InfeedOp> LhloDialectEmitter::EmitInfeedOp(
|
||||
HloInstruction* instr) {
|
||||
HloInfeedInstruction* infeed = ::xla::Cast<HloInfeedInstruction>(instr);
|
||||
// HLO Infeed instruction has a single operand of token type and a tuple
|
||||
// with buffers and a token as its output. LMHLO Infeed operation does not
|
||||
// need the token operand or result, so drop it.
|
||||
SmallVector<Value, 2> operands;
|
||||
TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{0}));
|
||||
auto infeed_op = CreateOpWithoutAttrs<lmhlo::InfeedOp>(instr, operands);
|
||||
infeed_op.configAttr(builder_.getStringAttr(infeed->infeed_config()));
|
||||
return infeed_op;
|
||||
}
|
||||
|
||||
StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
|
||||
const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
|
||||
const ::xla::ShapeIndex& shape_index) {
|
||||
@ -1055,7 +1079,7 @@ Status LhloDialectEmitter::GetOrCreateViewImpl(
|
||||
const HloInstruction* instr, const Shape& current_shape,
|
||||
::xla::ShapeIndex* current_shape_index, SmallVectorImpl<Value>* values) {
|
||||
if (current_shape.IsTuple()) {
|
||||
for (int i = 0; i < current_shape.tuple_shapes().size(); i++) {
|
||||
for (int i = 0; i < current_shape.tuple_shapes().size(); ++i) {
|
||||
current_shape_index->push_back(i);
|
||||
TF_RETURN_IF_ERROR(GetOrCreateViewImpl(
|
||||
instr, current_shape.tuple_shapes(i), current_shape_index, values));
|
||||
@ -1063,19 +1087,26 @@ Status LhloDialectEmitter::GetOrCreateViewImpl(
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto v, GetOrCreateArrayView(instr, current_shape, *current_shape_index));
|
||||
if (current_shape.IsArray()) {
|
||||
TF_ASSIGN_OR_RETURN(auto v, GetOrCreateArrayView(instr, current_shape,
|
||||
*current_shape_index));
|
||||
values->push_back(v);
|
||||
return Status::OK();
|
||||
}
|
||||
return xla::InternalError("Unexpected shape kind for %s and shape index %s",
|
||||
instr->ToString(), current_shape_index->ToString());
|
||||
}
|
||||
|
||||
// Returns a view for the result of an instruction.
|
||||
// We first get a view for the slice in the allocation, and then may need to
|
||||
// create another view to adjust the slice for the shape of the instruction.
|
||||
Status LhloDialectEmitter::GetOrCreateView(const HloInstruction* instr,
|
||||
SmallVectorImpl<Value>* values) {
|
||||
::xla::ShapeIndex shape_index;
|
||||
return GetOrCreateViewImpl(instr, instr->shape(), &shape_index, values);
|
||||
Status LhloDialectEmitter::GetOrCreateView(
|
||||
const HloInstruction* instr, SmallVectorImpl<Value>* values,
|
||||
const xla::ShapeIndex& result_subset) {
|
||||
::xla::ShapeIndex shape_index = result_subset;
|
||||
const Shape& sub_shape =
|
||||
::xla::ShapeUtil::GetSubshape(instr->shape(), shape_index);
|
||||
return GetOrCreateViewImpl(instr, sub_shape, &shape_index, values);
|
||||
}
|
||||
|
||||
Status LhloDialectEmitter::Initialize() {
|
||||
|
@ -27,6 +27,8 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
@ -79,6 +81,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
|
||||
|
||||
::xla::StatusOr<lmhlo::CompareOp> EmitCompareOp(::xla::HloInstruction* instr);
|
||||
|
||||
::xla::StatusOr<lmhlo::InfeedOp> EmitInfeedOp(::xla::HloInstruction* instr);
|
||||
::xla::StatusOr<lmhlo::MapOp> EmitMapOp(::xla::HloInstruction* instr);
|
||||
|
||||
::xla::StatusOr<lmhlo::ReducePrecisionOp> EmitReducePrecisionOp(
|
||||
@ -87,10 +90,16 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
|
||||
::xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
|
||||
::xla::HloInstruction* instr);
|
||||
|
||||
::xla::Status CreateOperands(
|
||||
::xla::HloInstruction* instr, SmallVectorImpl<Value>& operands,
|
||||
size_t& num_arguments, size_t& num_results,
|
||||
absl::optional<xla::int64> num_operands = absl::nullopt);
|
||||
// Create LHLO operation operands given an XLA HLO instruction. By default,
|
||||
// all XLA HLO operands and results are converted to MLIR and appended to
|
||||
// `operands`. If `num_operands` is specified, only the first `num_operands`
|
||||
// operands of the instruction are converted to MLIR. The function returns the
|
||||
// actual number of operands and results generated for MLIR in `num_arguments`
|
||||
// and `num_results`.
|
||||
::xla::Status CreateOperands(::xla::HloInstruction* instr,
|
||||
absl::optional<xla::int64> num_operands,
|
||||
SmallVectorImpl<Value>& operands,
|
||||
size_t& num_arguments, size_t& num_results);
|
||||
|
||||
template <typename OpType>
|
||||
::xla::StatusOr<OpType> CreateOpWithoutAttrs(
|
||||
@ -105,6 +114,10 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
|
||||
::xla::HloInstruction* instr, size_t& num_arguments, size_t& num_results,
|
||||
absl::optional<xla::int64> num_operands = absl::nullopt);
|
||||
|
||||
template <typename OpType>
|
||||
OpType CreateOpWithoutAttrs(::xla::HloInstruction* instr,
|
||||
ValueRange operands);
|
||||
|
||||
template <typename T>
|
||||
DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) {
|
||||
return builder_.getI64TensorAttr(
|
||||
@ -140,9 +153,14 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
|
||||
SmallVectorImpl<Value>* values);
|
||||
|
||||
// Helper function to create view/tuple of views to a buffer for a given
|
||||
// instruction result.
|
||||
// instruction result. `result_subset` can be used for instructions that
|
||||
// have a tuple result and MLIR conversion needs to convert only one of the
|
||||
// tuple elements. Note that if needed, this can be extended to take a list of
|
||||
// ShapeIndex values in case we need finer control on what elements of the
|
||||
// output tuple to be converted to MLIR.
|
||||
tensorflow::Status GetOrCreateView(const ::xla::HloInstruction* instr,
|
||||
SmallVectorImpl<Value>* values);
|
||||
SmallVectorImpl<Value>* values,
|
||||
const xla::ShapeIndex& result_subset = {});
|
||||
|
||||
::xla::StatusOr<Value> GetOrCreateArrayView(
|
||||
const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
|
||||
|
@ -51,10 +51,10 @@
|
||||
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/xla/tutorials/compile\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
|
||||
" \u003c/td\u003e\n",
|
||||
" \u003ctd\u003e\n",
|
||||
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
|
||||
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/jit_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
|
||||
" \u003c/td\u003e\n",
|
||||
" \u003ctd\u003e\n",
|
||||
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
|
||||
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/jit_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
|
||||
" \u003c/td\u003e\n",
|
||||
"\u003c/table\u003e"
|
||||
]
|
||||
|
@ -89,7 +89,7 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> PyClient::PjRtBufferFromPyval(
|
||||
const pybind11::object& argument, PjRtDevice* device, bool force_copy,
|
||||
pybind11::handle argument, PjRtDevice* device, bool force_copy,
|
||||
PjRtClient::HostBufferSemantics host_buffer_semantics) {
|
||||
if (device == nullptr) {
|
||||
TF_RET_CHECK(!pjrt_client_->local_devices().empty());
|
||||
@ -123,7 +123,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PyClient::PjRtBufferFromPyval(
|
||||
return buffer;
|
||||
}
|
||||
StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
|
||||
const pybind11::object& argument, PjRtDevice* device, bool force_copy,
|
||||
pybind11::handle argument, PjRtDevice* device, bool force_copy,
|
||||
PjRtClient::HostBufferSemantics host_buffer_semantics) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<PjRtBuffer> buffer,
|
||||
|
@ -124,10 +124,10 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBufferFromPyval(
|
||||
const pybind11::object& argument, PjRtDevice* device, bool force_copy,
|
||||
pybind11::handle argument, PjRtDevice* device, bool force_copy,
|
||||
PjRtClient::HostBufferSemantics host_buffer_semantics);
|
||||
StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval(
|
||||
const pybind11::object& argument, PjRtDevice* device, bool force_copy,
|
||||
pybind11::handle argument, PjRtDevice* device, bool force_copy,
|
||||
PjRtClient::HostBufferSemantics host_buffer_semantics);
|
||||
|
||||
StatusOr<std::shared_ptr<PyExecutable>> Compile(
|
||||
|
@ -71,6 +71,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
|
||||
@ -3168,7 +3169,22 @@ Status IrEmitterUnnested::HandleAllToAll(HloInstruction* hlo) {
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) {
|
||||
return ThunkEmitter(this).HandleInfeed(xla_infeed);
|
||||
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(xla_infeed));
|
||||
|
||||
auto infeed_op = mlir::dyn_cast<mlir::lmhlo::InfeedOp>(input.op);
|
||||
|
||||
std::vector<InfeedThunk::ShapedSlice> dest_slices;
|
||||
dest_slices.reserve(infeed_op.outputs().size());
|
||||
|
||||
for (mlir::Value output : infeed_op.outputs()) {
|
||||
TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(output));
|
||||
const Shape& shape = TypeToShape(output.getType());
|
||||
dest_slices.push_back(InfeedThunk::ShapedSlice{slice, shape});
|
||||
}
|
||||
|
||||
AddThunkToThunkSequence(
|
||||
absl::make_unique<InfeedThunk>(input.thunk_info, std::move(dest_slices)));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) {
|
||||
|
@ -800,8 +800,16 @@ Status AMDGPUTargetModuleLinker(llvm::Module* module, GpuVersion gpu_version,
|
||||
std::unique_ptr<llvm::TargetMachine> AMDGPUGetTargetMachine(
|
||||
llvm::Triple target_triple, int amdgpu_version,
|
||||
const HloModuleConfig& hlo_module_config) {
|
||||
string feature_str = "+code-object-v3";
|
||||
#if TF_ROCM_VERSION >= 30900
|
||||
// code-object-v3 is default, so no need to explicitly specify it
// in the feature string. Also, starting with ROCm 4.0, this feature string
// is deprecated, and we get a warning to that effect. So removing that
// feature string.
|
||||
feature_str = "";
|
||||
#endif
|
||||
return GetTargetMachine(target_triple, absl::StrCat("gfx", amdgpu_version),
|
||||
hlo_module_config, "+code-object-v3");
|
||||
hlo_module_config, feature_str);
|
||||
}
|
||||
|
||||
void AMDGPUBackendInit(const HloModuleConfig& hlo_module_config) {
|
||||
|
@ -115,34 +115,6 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
|
||||
/*implements_whole_instruction=*/true);
|
||||
}
|
||||
|
||||
std::unique_ptr<Thunk> ThunkEmitter::BuildInfeedThunk(
|
||||
const HloInstruction* inst) {
|
||||
CHECK_EQ(HloOpcode::kInfeed, inst->opcode());
|
||||
|
||||
std::vector<ShapeUtil::IndexedShape> leaf_shapes =
|
||||
ShapeUtil::GetLeafShapes(inst->shape());
|
||||
|
||||
// For an infeed HLO, the output is a 2 element tuple where the first element
|
||||
// of the tuple is all the infeed buffers and the second element is a token.
|
||||
// The infeed thunk does not need to handle this token output, so just drop
|
||||
// it.
|
||||
leaf_shapes.pop_back();
|
||||
|
||||
std::vector<InfeedThunk::ShapedSlice> dest_slices;
|
||||
dest_slices.reserve(leaf_shapes.size());
|
||||
|
||||
for (ShapeUtil::IndexedShape& indexed_shape : leaf_shapes) {
|
||||
BufferAllocation::Slice slice =
|
||||
GetAllocationSlice(*inst, indexed_shape.index);
|
||||
const Shape& shape =
|
||||
ShapeUtil::GetSubshape(inst->shape(), indexed_shape.index);
|
||||
dest_slices.emplace_back(InfeedThunk::ShapedSlice{slice, shape});
|
||||
}
|
||||
|
||||
return absl::make_unique<InfeedThunk>(context_->GetThunkInfo(inst),
|
||||
std::move(dest_slices));
|
||||
}
|
||||
|
||||
std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
|
||||
const HloInstruction* inst) {
|
||||
CHECK_EQ(HloOpcode::kOutfeed, inst->opcode());
|
||||
@ -258,11 +230,6 @@ Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ThunkEmitter::HandleInfeed(HloInstruction* infeed) {
|
||||
AddThunkToThunkSequence(BuildInfeedThunk(infeed));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ThunkEmitter::HandleOutfeed(HloInstruction* outfeed) {
|
||||
AddThunkToThunkSequence(BuildOutfeedThunk(outfeed));
|
||||
return Status::OK();
|
||||
|
@ -46,7 +46,6 @@ class ThunkEmitter {
|
||||
Status HandleCustomCall(HloInstruction* custom_call);
|
||||
Status HandleFft(HloInstruction* fft);
|
||||
Status HandleTriangularSolve(HloInstruction* hlo);
|
||||
Status HandleInfeed(HloInstruction* xla_infeed);
|
||||
Status HandleOutfeed(HloInstruction* outfeed);
|
||||
|
||||
private:
|
||||
|
@ -367,7 +367,6 @@ xla_test(
|
||||
"conv_depthwise_test.cc",
|
||||
],
|
||||
shard_count = 50,
|
||||
tags = ["no_rocm"], # ROCm 3.9 regression
|
||||
deps = [
|
||||
":conv_depthwise_common",
|
||||
":test_macros_header",
|
||||
@ -389,7 +388,6 @@ xla_test(
|
||||
timeout = "long",
|
||||
srcs = ["conv_depthwise_backprop_filter_test.cc"],
|
||||
shard_count = 40,
|
||||
tags = ["no_rocm"], # ROCm 3.9 regression
|
||||
deps = [
|
||||
":test_macros_header",
|
||||
"//tensorflow/compiler/xla:execution_options_util",
|
||||
@ -414,7 +412,6 @@ xla_test(
|
||||
"cpu",
|
||||
],
|
||||
shard_count = 50,
|
||||
tags = ["no_rocm"], # ROCm 3.9 regression
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
":hlo_test_base",
|
||||
@ -924,7 +921,6 @@ xla_test(
|
||||
srcs = ["dot_operation_test.cc"],
|
||||
shard_count = 20,
|
||||
tags = [
|
||||
"no_rocm", # ROCm 3.9 regression
|
||||
"optonly",
|
||||
],
|
||||
deps = [
|
||||
@ -958,7 +954,6 @@ xla_test(
|
||||
backends = ["gpu"],
|
||||
shard_count = 20,
|
||||
tags = [
|
||||
"no_rocm", # ROCm 3.9 regression
|
||||
"optonly",
|
||||
# TODO(b/151340488): Timed out on 2020-03-12.
|
||||
"nozapfhahn",
|
||||
@ -1025,7 +1020,6 @@ xla_test(
|
||||
},
|
||||
shard_count = 20,
|
||||
tags = [
|
||||
"no_rocm", # ROCm 3.9 regression
|
||||
"optonly",
|
||||
],
|
||||
deps = [
|
||||
@ -1253,7 +1247,6 @@ xla_test(
|
||||
"cpu": ["nomsan"],
|
||||
},
|
||||
shard_count = 30,
|
||||
tags = ["no_rocm"], # ROCm 3.9 regression
|
||||
deps = [
|
||||
":test_macros_header",
|
||||
"//tensorflow/compiler/xla:array3d",
|
||||
@ -1278,7 +1271,6 @@ xla_test(
|
||||
timeout = "long",
|
||||
srcs = ["convolution_dimension_numbers_test.cc"],
|
||||
shard_count = 20,
|
||||
tags = ["no_rocm"], # ROCm 3.9 regression
|
||||
deps = [
|
||||
":test_macros_header",
|
||||
"//tensorflow/compiler/xla:array4d",
|
||||
@ -2322,7 +2314,6 @@ xla_test(
|
||||
name = "multioutput_fusion_test",
|
||||
srcs = ["multioutput_fusion_test.cc"],
|
||||
backends = ["gpu"],
|
||||
tags = ["no_rocm"], # ROCm 3.9 regression
|
||||
deps = [
|
||||
":test_macros_header",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
|
@ -480,13 +480,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
{csinfo_.fused_batch_norm_grad_v3,
|
||||
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
|
||||
CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
rinfo_.push_back({csinfo_.fused_batch_norm_ex,
|
||||
native_fmt ? csinfo_.mkl_native_fused_batch_norm_ex
|
||||
: csinfo_.mkl_fused_batch_norm_ex,
|
||||
CopyAttrsAll, FusedBatchNormExRewrite,
|
||||
GetRewriteCause()});
|
||||
#endif
|
||||
rinfo_.push_back({csinfo_.fused_conv2d,
|
||||
native_fmt ? csinfo_.mkl_native_fused_conv2d
|
||||
: csinfo_.mkl_fused_conv2d,
|
||||
@ -672,14 +670,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
rinfo_.push_back({csinfo_.requantize,
|
||||
mkl_op_registry::GetMklOpName(csinfo_.requantize),
|
||||
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
// Optimized TanhGrad support exists only in DNNL 1.x.
|
||||
rinfo_.push_back({csinfo_.tanh, mkl_op_registry::GetMklOpName(csinfo_.tanh),
|
||||
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
|
||||
rinfo_.push_back({csinfo_.tanh_grad,
|
||||
mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
|
||||
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
rinfo_.push_back({csinfo_.reshape,
|
||||
mkl_op_registry::GetMklOpName(csinfo_.reshape),
|
||||
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
|
||||
|
@ -53,7 +53,6 @@ static void InitGraph(const string& s, Graph* graph,
|
||||
GraphDef graph_def;
|
||||
|
||||
auto parser = protobuf::TextFormat::Parser();
|
||||
// parser.AllowRelaxedWhitespace(true);
|
||||
CHECK(parser.MergeFromString(s, &graph_def)) << s;
|
||||
GraphConstructorOptions opts;
|
||||
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
|
||||
@ -66,7 +65,6 @@ static void InitGraph(const string& s, Graph* graph,
|
||||
class MklLayoutPassTest : public ::testing::Test {
|
||||
public:
|
||||
MklLayoutPassTest() : graph_(OpRegistry::Global()) {}
|
||||
// Ashraf added
|
||||
Node* FindNode(const string& name) {
|
||||
for (Node* node : graph_.nodes()) {
|
||||
if (node->name() == name) return node;
|
||||
@ -3087,8 +3085,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_LeakyReluGrad_Negative);
|
||||
REGISTER_TEST_ALL_TYPES(NodeRewrite_LeakyReluLeakyReluGrad_Positive);
|
||||
#undef REGISTER_TEST
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
|
||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
||||
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
|
||||
@ -3146,7 +3142,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_TanhGrad_Positive);
|
||||
}
|
||||
REGISTER_TEST_ALL_TYPES(NodeRewrite_TanhTanhGrad_Positive);
|
||||
#undef REGISTER_TEST
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
||||
@ -3513,7 +3508,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormGradV3_5D_Negative_2);
|
||||
#undef DATA_FORMAT
|
||||
#undef REGISTER_TEST
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||
TEST_F(MklLayoutPassTest, NAME##_##T) { \
|
||||
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
|
||||
@ -3603,7 +3597,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative1);
|
||||
}
|
||||
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative2);
|
||||
#undef REGISTER_TEST
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizedDepthwiseConv2D_Positive) {
|
||||
InitGraph(
|
||||
@ -5184,8 +5177,8 @@ static void BM_MklLayoutRewritePass(int iters, int op_nodes) {
|
||||
|
||||
bool first = true;
|
||||
while (iters > 0) {
|
||||
Graph* graph = new Graph(OpRegistry::Global());
|
||||
InitGraph(s, graph);
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
InitGraph(s, graph.get());
|
||||
int N = graph->num_node_ids();
|
||||
if (first) {
|
||||
testing::SetLabel(strings::StrCat("Per graph node. Nodes: ", N));
|
||||
@ -5193,13 +5186,12 @@ static void BM_MklLayoutRewritePass(int iters, int op_nodes) {
|
||||
}
|
||||
{
|
||||
testing::StartTiming();
|
||||
std::unique_ptr<Graph> ug(graph);
|
||||
std::unique_ptr<Graph> ug(graph.get());
|
||||
RunMklLayoutRewritePass(&ug);
|
||||
testing::StopTiming();
|
||||
}
|
||||
iters -= N; // Our benchmark units are individual graph nodes,
|
||||
// not whole graphs
|
||||
// delete graph;
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);
|
||||
|
@ -37,6 +37,7 @@ load(
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow/core:__subpackages__",
|
||||
"//tensorflow/security/fuzzing:__subpackages__",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
@ -622,7 +623,10 @@ cc_library(
|
||||
name = "bfloat16",
|
||||
srcs = ["bfloat16.cc"],
|
||||
hdrs = ["bfloat16.h"],
|
||||
visibility = ["//tensorflow/core:__subpackages__"],
|
||||
visibility = [
|
||||
"//tensorflow/core:__subpackages__",
|
||||
"//tensorflow/security/fuzzing:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
":numeric_types",
|
||||
"//tensorflow/core/platform:byte_order",
|
||||
|
@ -269,6 +269,9 @@ Status MetaOptimizer::InitializeOptimizers(
|
||||
if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
|
||||
optimizers->push_back(MakeUnique<PinToHostOptimizer>());
|
||||
}
|
||||
if (cfg_.remapping() != RewriterConfig::OFF) {
|
||||
optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
|
||||
}
|
||||
if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
|
||||
optimizers->push_back(
|
||||
MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
|
||||
@ -278,9 +281,6 @@ Status MetaOptimizer::InitializeOptimizers(
|
||||
/*optimization level*/ cfg_.layout_optimizer(),
|
||||
/*CPU layout conversion*/ cfg_.cpu_layout_conversion()));
|
||||
}
|
||||
if (cfg_.remapping() != RewriterConfig::OFF) {
|
||||
optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
|
||||
}
|
||||
if (cfg_.loop_optimization() != RewriterConfig::OFF) {
|
||||
optimizers->push_back(
|
||||
MakeUnique<LoopOptimizer>(cfg_.loop_optimization(), cpu_device_));
|
||||
|
@ -29,8 +29,14 @@ REGISTER5(BinaryOp, CPU, "DivNoNan", functor::div_no_nan, Eigen::half, float,
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
// ROCM TODO: re-enable complex64 / complex128 after compiler fix
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
|
||||
REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
|
||||
uint16, int16, int64, complex64, complex128);
|
||||
#else
|
||||
REGISTER4(BinaryOp, GPU, "Div", functor::div, uint8, uint16, complex64,
|
||||
complex128);
|
||||
#endif
|
||||
REGISTER4(BinaryOp, GPU, "TruncateDiv", functor::div, uint8, uint16, int16,
|
||||
int64);
|
||||
REGISTER5(BinaryOp, GPU, "RealDiv", functor::div, float, Eigen::half, double,
|
||||
|
@ -30,8 +30,13 @@ REGISTER(BinaryOp, CPU, "Sub", functor::sub, int32);
|
||||
#endif // __ANDROID_TYPES_SLIM__
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
|
||||
REGISTER7(BinaryOp, GPU, "Sub", functor::sub, float, Eigen::half, double, int64,
|
||||
complex64, complex128, uint32);
|
||||
#else
|
||||
REGISTER3(BinaryOp, GPU, "Sub", functor::sub, complex64, complex128, uint32);
|
||||
#endif
|
||||
|
||||
// A special GPU kernel for int32.
|
||||
// TODO(b/25387198): Also enable int32 in device memory. This kernel
|
||||
|
@ -85,7 +85,7 @@ void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
// clang-format off
|
||||
absl::flat_hash_map<string, uint64> live_experiments = {
|
||||
{"enable_gradient_descent", 0},
|
||||
{"map_parallelization", 0}
|
||||
{"map_parallelization", 1}
|
||||
};
|
||||
// clang-format on
|
||||
auto hash_func = [](const string& str) { return Hash64(str); };
|
||||
|
@ -43,7 +43,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/util/mkl_types.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
#include "tensorflow/core/util/padding.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
@ -65,7 +64,7 @@ struct MklConvFwdParams {
|
||||
memory::dims dilations;
|
||||
memory::dims padding_left;
|
||||
memory::dims padding_right;
|
||||
MKL_TENSOR_FORMAT tf_fmt;
|
||||
MklTensorFormat tf_fmt;
|
||||
bool native_format;
|
||||
string dtypes = string("");
|
||||
struct PostOpParam {
|
||||
@ -80,7 +79,7 @@ struct MklConvFwdParams {
|
||||
memory::dims bias_dims, memory::dims dst_dims,
|
||||
memory::dims strides, memory::dims dilations,
|
||||
memory::dims padding_left, memory::dims padding_right,
|
||||
MKL_TENSOR_FORMAT tf_fmt, bool native_format)
|
||||
MklTensorFormat tf_fmt, bool native_format)
|
||||
: src_dims(src_dims),
|
||||
filter_dims(filter_dims),
|
||||
bias_dims(bias_dims),
|
||||
@ -99,7 +98,7 @@ template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput>
|
||||
class MklConvFwdPrimitive : public MklPrimitive {
|
||||
public:
|
||||
explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
|
||||
: MklPrimitive(engine(ENGINE_CPU, 0)) {
|
||||
: MklPrimitive(engine(engine::kind::cpu, 0)) {
|
||||
// Create convolution primitive
|
||||
if (context_.conv_fwd == nullptr) {
|
||||
Setup(convFwdDims);
|
||||
@ -115,8 +114,8 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
||||
void Execute(const Tinput* src_data, const Tfilter* filter_data,
|
||||
const Tbias* bias_data, const Toutput* dst_data,
|
||||
std::shared_ptr<stream> fwd_stream) {
|
||||
// TODO: Create a common function and avoid the duplicate code
|
||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||
// TODO: Create a common function and avoid the duplicate code
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
|
||||
context_.filter_mem->set_data_handle(
|
||||
@ -139,16 +138,13 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
||||
context_.dst_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<Toutput*>(dst_data)));
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
|
||||
DCHECK_EQ(context_.fwd_primitives.size(),
|
||||
context_.fwd_primitives_args.size());
|
||||
for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
|
||||
context_.fwd_primitives.at(i).execute(*fwd_stream,
|
||||
context_.fwd_primitives_args.at(i));
|
||||
}
|
||||
#else
|
||||
fwd_stream->submit(context_.fwd_primitives);
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
// After execution, set data handle back
|
||||
context_.src_mem->set_data_handle(DummyData);
|
||||
@ -168,13 +164,6 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
||||
Execute(src_data, filter_data, nullptr, dst_data, fwd_stream);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
// In MKL-DNN v1.x, memory format tags only provide a partial description
|
||||
// of the memory layout. Hence, these functions are disabled for v1.x.
|
||||
memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
|
||||
memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; }
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
std::shared_ptr<ConvFwdPd> GetPrimitiveDesc() const {
|
||||
return context_.fwd_pd;
|
||||
}
|
||||
@ -182,12 +171,6 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
||||
private:
|
||||
// Primitive reuse context for Conv2D Fwd op
|
||||
struct ConvFwdContext {
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
// Expected memory format for this primitive instance
|
||||
memory::format src_fmt;
|
||||
memory::format filter_fmt;
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
// MKL-DNN memory
|
||||
std::shared_ptr<mkldnn::memory> src_mem;
|
||||
std::shared_ptr<mkldnn::memory> filter_mem;
|
||||
@ -208,18 +191,10 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
||||
std::shared_ptr<mkldnn::primitive> conv_fwd;
|
||||
|
||||
std::vector<mkldnn::primitive> fwd_primitives;
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
ConvFwdContext()
|
||||
:
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
src_fmt(memory::format::any),
|
||||
filter_fmt(memory::format::any),
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
src_mem(nullptr),
|
||||
: src_mem(nullptr),
|
||||
filter_mem(nullptr),
|
||||
bias_mem(nullptr),
|
||||
dst_mem(nullptr),
|
||||
@ -228,52 +203,45 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
||||
filter_md(nullptr),
|
||||
bias_md(nullptr),
|
||||
fwd_pd(nullptr),
|
||||
conv_fwd(nullptr) {
|
||||
}
|
||||
conv_fwd(nullptr) {}
|
||||
};
|
||||
|
||||
void Setup(const MklConvFwdParams& convFwdDims) {
|
||||
MEMORY_FORMAT user_data_fmt;
|
||||
memory::format_tag user_data_fmt;
|
||||
if (convFwdDims.native_format) {
|
||||
user_data_fmt = MklTensorFormatToMklDnnDataFormat(convFwdDims.tf_fmt);
|
||||
} else {
|
||||
// Create memory descriptors for convolution data w/ no specified format
|
||||
user_data_fmt = MEMORY_FORMAT::any;
|
||||
user_data_fmt = memory::format_tag::any;
|
||||
}
|
||||
context_.src_md.reset(new memory::desc(
|
||||
{convFwdDims.src_dims}, MklDnnType<Tinput>(), user_data_fmt));
|
||||
|
||||
context_.filter_md.reset(new memory::desc(
|
||||
{convFwdDims.filter_dims}, MklDnnType<Tfilter>(), MEMORY_FORMAT::any));
|
||||
context_.filter_md.reset(new memory::desc({convFwdDims.filter_dims},
|
||||
MklDnnType<Tfilter>(),
|
||||
memory::format_tag::any));
|
||||
|
||||
context_.dst_md.reset(new memory::desc(
|
||||
{convFwdDims.dst_dims}, MklDnnType<Toutput>(), user_data_fmt));
|
||||
|
||||
if (!convFwdDims.bias_dims.empty())
|
||||
context_.bias_md.reset(new memory::desc(
|
||||
{convFwdDims.bias_dims}, MklDnnType<Tbias>(), MEMORY_FORMAT::any));
|
||||
context_.bias_md.reset(new memory::desc({convFwdDims.bias_dims},
|
||||
MklDnnType<Tbias>(),
|
||||
memory::format_tag::any));
|
||||
|
||||
// Create a convolution descriptor
|
||||
if (!convFwdDims.bias_dims.empty()) {
|
||||
context_.fwd_desc.reset(new convolution_forward::desc(
|
||||
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
|
||||
*context_.filter_md, *context_.bias_md, *context_.dst_md,
|
||||
convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
convFwdDims.padding_right, padding_kind::zero));
|
||||
#else
|
||||
convFwdDims.padding_right));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
prop_kind::forward, mkldnn::algorithm::convolution_direct,
|
||||
*context_.src_md, *context_.filter_md, *context_.bias_md,
|
||||
*context_.dst_md, convFwdDims.strides, convFwdDims.dilations,
|
||||
convFwdDims.padding_left, convFwdDims.padding_right));
|
||||
} else {
|
||||
context_.fwd_desc.reset(new convolution_forward::desc(
|
||||
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
|
||||
*context_.filter_md, *context_.dst_md, convFwdDims.strides,
|
||||
convFwdDims.dilations, convFwdDims.padding_left,
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
convFwdDims.padding_right, padding_kind::zero));
|
||||
#else
|
||||
prop_kind::forward, mkldnn::algorithm::convolution_direct,
|
||||
*context_.src_md, *context_.filter_md, *context_.dst_md,
|
||||
convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
|
||||
convFwdDims.padding_right));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
}
|
||||
|
||||
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
@ -314,54 +282,32 @@ class MklConvFwdPrimitive : public MklPrimitive {
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
}

#ifndef ENABLE_MKLDNN_V1
// Store the expected memory format
context_.src_fmt = static_cast<mkldnn::memory::format>(
context_.fwd_pd.get()->src_primitive_desc().desc().data.format);

context_.filter_fmt = static_cast<mkldnn::memory::format>(
context_.fwd_pd.get()->weights_primitive_desc().desc().data.format);
#endif // !ENABLE_MKLDNN_V1

// Create memory primitive based on dummy data
context_.src_mem.reset(new MEMORY_CONSTRUCTOR(
context_.fwd_pd.get()->PRIMITIVE_DESC_SRC, cpu_engine_, DummyData));
context_.filter_mem.reset(new MEMORY_CONSTRUCTOR(
context_.fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_, DummyData));
context_.dst_mem.reset(new MEMORY_CONSTRUCTOR(
context_.fwd_pd.get()->PRIMITIVE_DESC_DST, cpu_engine_, DummyData));
context_.src_mem.reset(
new memory(context_.fwd_pd.get()->src_desc(), cpu_engine_, DummyData));
context_.filter_mem.reset(new memory(context_.fwd_pd.get()->weights_desc(),
cpu_engine_, DummyData));
context_.dst_mem.reset(
new memory(context_.fwd_pd.get()->dst_desc(), cpu_engine_, DummyData));

// Create convolution primitive and add it to net
if (!convFwdDims.bias_dims.empty()) {
context_.bias_mem.reset(new MEMORY_CONSTRUCTOR_USING_MEM_PD(
convFwdDims.bias_dims, Tbias, MEMORY_FORMAT::x, cpu_engine_,
DummyData));
#ifdef ENABLE_MKLDNN_V1
context_.bias_mem.reset(new memory(
{{convFwdDims.bias_dims}, MklDnnType<Tbias>(), memory::format_tag::x},
cpu_engine_, DummyData));
context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
context_.fwd_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
{MKLDNN_ARG_BIAS, *context_.bias_mem},
{ MKLDNN_ARG_DST,
*context_.dst_mem }});
{MKLDNN_ARG_DST, *context_.dst_mem}});
} else {
context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
context_.fwd_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
{ MKLDNN_ARG_DST,
*context_.dst_mem }});
{MKLDNN_ARG_DST, *context_.dst_mem}});
}
#else
context_.conv_fwd.reset(new convolution_forward(
*context_.fwd_pd, *context_.src_mem, *context_.filter_mem,
*context_.bias_mem, *context_.dst_mem));
} else {
context_.conv_fwd.reset(
new convolution_forward(*context_.fwd_pd, *context_.src_mem,
*context_.filter_mem, *context_.dst_mem));
}
#endif // ENABLE_MKLDNN_V1
context_.fwd_primitives.push_back(*context_.conv_fwd);
}
@ -650,12 +596,10 @@ class MklConvOp : public OpKernel {
auto tf_fmt = is_conv2d ? TFDataFormatToMklDnnDataFormat(data_format_)
: TFDataFormatToMklDnn3DDataFormat(data_format_);

#ifdef ENABLE_MKLDNN_V1
auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt);
// NOTE: `mkl_fmt_tag` will be `format_tag::undef` for ReLU
OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef,
errors::InvalidArgument("Invalid data format"));
#endif // ENABLE_MKLDNN_V1

// If input is in MKL layout, then simply grab the layout; otherwise,
// construct TF layout for input.
@ -667,19 +611,15 @@ class MklConvOp : public OpKernel {
auto src_md =
src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
#ifdef ENABLE_MKLDNN_V1
: memory::desc(src_dims, MklDnnType<Tinput>(), mkl_fmt_tag);
#else
: memory::desc(src_dims, MklDnnType<Tinput>(), tf_fmt);
#endif // ENABLE_MKLDNN_V1
src.SetUsrMem(src_md, &src_tensor);

// Although filter shape (filter_dims) required is in MKL-DNN order,
// the layout is Tensorflow's layout (HWIO) and (HWIGO) for
// depthwise/group convolutions.
auto filter_format = is_conv2d ? (is_depthwise ? MEMORY_FORMAT::hwigo
: MEMORY_FORMAT::hwio)
: MEMORY_FORMAT::dhwio;
auto filter_format = is_conv2d ? (is_depthwise ? memory::format_tag::hwigo
: memory::format_tag::hwio)
: memory::format_tag::dhwio;

DCHECK(!filter_mkl_shape.IsMklTensor());
auto filter_md =
@ -738,12 +678,9 @@ class MklConvOp : public OpKernel {
// Check whether src and filter need to be reordered.
Tinput* src_data = nullptr;
if (IS_SRC_REORDER_NEEDED(src_md, conv_fwd_pd, conv_fwd)) {
if (src_md != conv_fwd_pd->src_desc()) {
src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(GET_SRC_DESC_FROM_OP_PD(conv_fwd_pd),
cpu_engine_),
context);
src.CheckReorderToOpMem(conv_fwd_pd->src_desc(), cpu_engine_, context);
src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
} else {
src_data = static_cast<Tinput*>(
@ -751,7 +688,7 @@ class MklConvOp : public OpKernel {
}
|
||||
|
||||
Tfilter* filter_data = nullptr;
|
||||
if (IS_FILTER_REORDER_NEEDED(filter_md, conv_fwd_pd, conv_fwd)) {
|
||||
if (filter_md != conv_fwd_pd->weights_desc()) {
|
||||
bool is_filter_cached = false;
|
||||
// If filter is a constant, we can avoid the conversion of filter from
|
||||
// Tensorflow format to MKL format by caching the filter when it is
|
||||
@ -761,28 +698,20 @@ class MklConvOp : public OpKernel {
|
||||
if (IsFilterCacheEmpty(context)) {
|
||||
// Cache filter if it is not already cached.
|
||||
CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor,
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
filter, filter_md, filter_mkl_shape);
|
||||
#else
|
||||
filter, filter_md);
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
}
|
||||
filter_data = GetCachedFilter(
|
||||
context, GET_WEIGHTS_FORMAT_FROM_OP_PD(conv_fwd_pd, conv_fwd));
|
||||
filter_data = GetCachedFilter(context, conv_fwd_pd->weights_desc());
|
||||
is_filter_cached = (filter_data != nullptr);
|
||||
}
|
||||
if (!is_filter_cached) {
|
||||
filter.SetUsrMem(filter_md, &filter_tensor);
|
||||
if (filter_out_tensor == nullptr) {
|
||||
filter.CheckReorderToOpMem(
|
||||
MEMORY_PD_WITHOUT_DATA(GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd),
|
||||
cpu_engine_),
|
||||
filter.CheckReorderToOpMem(conv_fwd_pd->weights_desc(), cpu_engine_,
|
||||
context);
|
||||
} else {
|
||||
filter.CheckReorderToOpMem(
|
||||
GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd),
|
||||
DATA_WITH_ENGINE(filter.GetTensorBuffer(filter_out_tensor),
|
||||
cpu_engine_),
|
||||
conv_fwd_pd->weights_desc(),
|
||||
filter.GetTensorBuffer(filter_out_tensor), cpu_engine_,
|
||||
context);
|
||||
}
|
||||
filter_data =
|
||||
@ -897,7 +826,8 @@ class MklConvOp : public OpKernel {
|
||||
// NOTE: Fusion of BiasAdd is handled directly inside MklConvOp by
|
||||
// checking `fuse_biasadd_` flag.
|
||||
if (fuse_add_) {
|
||||
params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}, ""});
|
||||
params.post_op_params.push_back(
|
||||
{"sum", mkldnn::algorithm::undef, {1.0}, ""});
|
||||
}
|
||||
if (fuse_activation_) {
|
||||
params.post_op_params.push_back(
|
||||
@ -918,35 +848,27 @@ class MklConvOp : public OpKernel {
|
||||
virtual void AllocateOutputTensor(OpKernelContext* context,
|
||||
const ConvFwdPd& conv_prim_desc,
|
||||
const memory::dims& output_dims_mkl_order,
|
||||
MKL_TENSOR_FORMAT output_tf_format,
|
||||
MklTensorFormat output_tf_format,
|
||||
MklDnnShape* output_mkl_shape,
|
||||
Tensor** output_tensor) {
|
||||
DCHECK(output_tensor);
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
auto dst_md = conv_prim_desc.dst_desc();
|
||||
#else
|
||||
auto dst_pd = conv_prim_desc.dst_primitive_desc();
|
||||
auto dst_md = dst_pd.desc();
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
if (!std::is_same<Ttemp_output, Toutput>::value) {
|
||||
dst_md.data.data_type =
|
||||
static_cast<mkldnn_data_type_t>(MklDnnType<Toutput>());
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
dst_pd = memory::primitive_desc(dst_md, cpu_engine_);
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
}
|
||||
|
||||
// Allocate shape of MKL tensor
|
||||
output_mkl_shape->SetMklTensor(true);
|
||||
output_mkl_shape->SetMklLayout(&DST_MD);
|
||||
output_mkl_shape->SetMklLayout(&dst_md);
|
||||
output_mkl_shape->SetElemType(MklDnnType<Toutput>());
|
||||
output_mkl_shape->SetTfLayout(output_dims_mkl_order.size(),
|
||||
output_dims_mkl_order, output_tf_format);
|
||||
|
||||
// Allocate shape of TF tensor
|
||||
TensorShape output_tf_shape;
|
||||
output_tf_shape.AddDim((DST_MD.get_size() / sizeof(Toutput)));
|
||||
output_tf_shape.AddDim((dst_md.get_size() / sizeof(Toutput)));
|
||||
if (native_format) {
|
||||
output_tf_shape = output_mkl_shape->GetTfShape();
|
||||
}
|
||||
@ -972,23 +894,16 @@ class MklConvOp : public OpKernel {
|
||||
AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
|
||||
output_tf_shape, *output_mkl_shape,
|
||||
native_format);
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
auto output_format_tag = MklTensorFormatToMklDnnDataFormat(
|
||||
output_mkl_shape->GetTfDataFormat());
|
||||
OP_REQUIRES(context, output_format_tag != memory::format_tag::undef,
|
||||
errors::InvalidArgument(
|
||||
"MklConvOp: AddN fusion: Invalid data format"));
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
auto add_md =
|
||||
add_mkl_shape.IsMklTensor()
|
||||
? add_mkl_shape.GetMklLayout()
|
||||
: memory::desc(output_dims_mkl_order, MklDnnType<Toutput>(),
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
output_format_tag);
|
||||
#else
|
||||
output_mkl_shape->GetTfDataFormat());
|
||||
auto add_pd = memory::primitive_desc(add_md, this->cpu_engine_);
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
void* add_buf = static_cast<void*>(
|
||||
const_cast<Toutput*>(add_tensor.flat<Toutput>().data()));
|
||||
void* dst_buf =
|
||||
@ -996,16 +911,14 @@ class MklConvOp : public OpKernel {
|
||||
if (native_format) {
|
||||
// We are simply deep copying the add_tensor to output_tensor without
|
||||
// changing memory layout, hence using same memory descriptor.
|
||||
ADD_MD = DST_MD =
|
||||
add_md = dst_md =
|
||||
memory::desc({add_tensor.NumElements()}, MklDnnType<Toutput>(),
|
||||
mkldnn::memory::format_tag::x);
|
||||
}
|
||||
fuse_add_src_.reset(
|
||||
new MEMORY_CONSTRUCTOR(ADD_MD, this->cpu_engine_, add_buf));
|
||||
fuse_add_dst_.reset(
|
||||
new MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf));
|
||||
fuse_add_src_.reset(new memory(add_md, this->cpu_engine_, add_buf));
|
||||
fuse_add_dst_.reset(new memory(dst_md, this->cpu_engine_, dst_buf));
|
||||
auto reorder_desc =
|
||||
REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
|
||||
ReorderPd(this->cpu_engine_, add_md, this->cpu_engine_, dst_md);
|
||||
|
||||
CreateAndExecuteReorder(reorder_desc, *fuse_add_src_, *fuse_add_dst_,
|
||||
this->cpu_engine_, context);
|
||||
@ -1017,7 +930,7 @@ class MklConvOp : public OpKernel {
|
||||
}
|
||||
}
|
||||
|
||||
engine cpu_engine_ = engine(ENGINE_CPU, 0);
|
||||
engine cpu_engine_ = engine(engine::kind::cpu, 0);
|
||||
|
||||
private:
|
||||
std::shared_ptr<mkldnn::memory> fuse_add_src_;
|
||||
@ -1041,7 +954,7 @@ class MklConvOp : public OpKernel {
|
||||
// This variable is used for alpha in leakyrelu or upper bound in relu6
|
||||
// depending on the context
|
||||
float alpha_or_upbound_ = 0.0;
|
||||
mkldnn::algorithm activation_alg_ = ALGORITHM_UNDEF;
|
||||
mkldnn::algorithm activation_alg_ = mkldnn::algorithm::undef;
|
||||
|
||||
int input_index_pad_ = 2;
|
||||
|
||||
@ -1050,15 +963,10 @@ class MklConvOp : public OpKernel {
|
||||
const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1;
|
||||
const int kDilationH = 0, kDilationW = 1;
|
||||
|
||||
MKL_TENSOR_FORMAT_IN_C GetFilterTfDataFormat(
|
||||
const MklDnnShape* filter_mkl_shape,
|
||||
MklTensorFormat GetFilterTfDataFormat(const MklDnnShape* filter_mkl_shape,
|
||||
const ConvFwdPd& conv_prim_desc) const {
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
DCHECK(filter_mkl_shape);
|
||||
return filter_mkl_shape->GetTfDataFormat();
|
||||
#else
|
||||
return conv_prim_desc.weights_primitive_desc().desc().data.format;
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
}
|
||||
|
||||
// Allocate persistent tensors for cached filter data and
|
||||
@ -1070,23 +978,13 @@ class MklConvOp : public OpKernel {
|
||||
DCHECK(filter_tensor);
|
||||
TensorShape filter_tf_shape;
|
||||
filter_tf_shape.AddDim(
|
||||
(conv_prim_desc.PRIMITIVE_DESC_WEIGHTS.get_size() / sizeof(Tfilter)));
|
||||
(conv_prim_desc.weights_desc().get_size() / sizeof(Tfilter)));
|
||||
OP_REQUIRES_OK(context, context->allocate_persistent(
|
||||
DataTypeToEnum<Tfilter>::value, filter_tf_shape,
|
||||
&cached_filter_data_ptensor_, filter_tensor));
|
||||
|
||||
Tensor* second_tensor = nullptr;
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
TensorShape filter_mkl_format;
|
||||
filter_mkl_format.AddDim(
|
||||
sizeof(GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc)) /
|
||||
sizeof(DT_INT32));
|
||||
OP_REQUIRES_OK(context, context->allocate_persistent(
|
||||
DT_INT32, filter_mkl_format,
|
||||
&cached_filter_md_ptensor_, &second_tensor));
|
||||
second_tensor->scalar<int32>()() = static_cast<int32>(
|
||||
GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc));
|
||||
#else
|
||||
|
||||
// There is no tensor format in DNNL 1.x. So we cache the complete filter
|
||||
// descriptor as flat byte array.
|
||||
TensorShape cached_filter_md_shape;
|
||||
@ -1100,7 +998,6 @@ class MklConvOp : public OpKernel {
|
||||
&cached_filter_md_ptensor_, &second_tensor));
|
||||
*reinterpret_cast<memory::desc*>(second_tensor->flat<uint8>().data()) =
|
||||
weights_desc;
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
}
|
||||
|
||||
void AllocatePersistentTensor(OpKernelContext* context,
|
||||
@ -1114,7 +1011,7 @@ class MklConvOp : public OpKernel {
|
||||
const memory::dims& filter_dims_tf_order,
|
||||
Tensor** filter_tensor) {
|
||||
DCHECK(filter_tensor);
|
||||
auto filter_md = conv_prim_desc.PRIMITIVE_DESC_WEIGHTS;
|
||||
auto filter_md = conv_prim_desc.weights_desc();
|
||||
|
||||
// Allocate shape of MKL tensor
|
||||
MklDnnShape filter_mkl_shape;
|
||||
@ -1127,7 +1024,7 @@ class MklConvOp : public OpKernel {
|
||||
// is stored in the MKL data.
|
||||
filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(),
|
||||
filter_dims_tf_order,
|
||||
MKL_TENSOR_FORMAT_BLOCKED);
|
||||
MklTensorFormat::FORMAT_BLOCKED);
|
||||
|
||||
// Allocate the data space for the filter to propagate as TF tensor.
|
||||
TensorShape filter_tf_shape;
|
||||
@ -1150,17 +1047,15 @@ class MklConvOp : public OpKernel {
|
||||
// Create reorders between user layout and MKL layout if it is needed and
|
||||
// add it to the net before convolution. No need to check for output
|
||||
// reorder as we propagate output layout to the next layer.
|
||||
src->CheckReorderToOpMem(
|
||||
MEMORY_PD_WITHOUT_DATA(conv_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_));
|
||||
src->CheckReorderToOpMem(conv_prim_desc.src_desc(), cpu_engine_);
|
||||
|
||||
// Rather than re-ordering to a temp buffer, reorder directly to the
|
||||
// filter output tensor
|
||||
filter->CheckReorderToOpMem(conv_prim_desc.PRIMITIVE_DESC_WEIGHTS,
|
||||
filter->CheckReorderToOpMem(conv_prim_desc.weights_desc(),
|
||||
filter->GetTensorBuffer(filter_out_tensor));
|
||||
|
||||
// Create convolution primitive and add it to net.
|
||||
std::vector<primitive> net;
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
std::vector<std::unordered_map<int, memory>> net_args;
|
||||
if (bias) {
|
||||
DCHECK(fuse_biasadd_);
|
||||
@ -1168,31 +1063,15 @@ class MklConvOp : public OpKernel {
|
||||
net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()},
|
||||
{MKLDNN_ARG_WEIGHTS, filter->GetOpMem()},
|
||||
{MKLDNN_ARG_BIAS, bias->GetOpMem()},
|
||||
{ MKLDNN_ARG_DST,
|
||||
output->GetOpMem() }});
|
||||
{MKLDNN_ARG_DST, output->GetOpMem()}});
|
||||
} else {
|
||||
DCHECK(!fuse_biasadd_);
|
||||
net.push_back(convolution_forward(conv_prim_desc));
|
||||
net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()},
|
||||
{MKLDNN_ARG_WEIGHTS, filter->GetOpMem()},
|
||||
{ MKLDNN_ARG_DST,
|
||||
output->GetOpMem() }});
|
||||
{MKLDNN_ARG_DST, output->GetOpMem()}});
|
||||
}
|
||||
ExecutePrimitive(net, &net_args, cpu_engine_);
|
||||
#else
|
||||
if (bias) {
|
||||
DCHECK(fuse_biasadd_);
|
||||
net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
|
||||
filter->GetOpMem(), bias->GetOpMem(),
|
||||
output->GetOpMem()));
|
||||
} else {
|
||||
DCHECK(!fuse_biasadd_);
|
||||
net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
|
||||
filter->GetOpMem(),
|
||||
output->GetOpMem()));
|
||||
}
|
||||
ExecutePrimitive(net, nullptr, cpu_engine_);
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
}
|
||||
|
||||
// TF_LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot
|
||||
@ -1208,7 +1087,6 @@ class MklConvOp : public OpKernel {
|
||||
|
||||
// Cache the converted filter in a persistent tensor.
|
||||
// Only one thread can execute this method at any given time.
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
void CacheFilter(OpKernelContext* context,
|
||||
const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
|
||||
Tfilter* filter_data, const Tensor& filter_tensor,
|
||||
@ -1254,37 +1132,8 @@ class MklConvOp : public OpKernel {
|
||||
return true;
|
||||
}
|
||||
|
||||
#else
|
||||
void CacheFilter(OpKernelContext* context,
|
||||
const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
|
||||
Tfilter* filter_data, const Tensor& filter_tensor,
|
||||
MklDnnData<Tfilter>& filter, const memory::desc& filter_md)
|
||||
TF_LOCKS_EXCLUDED(mu_) {
|
||||
mutex_lock lock(mu_);
|
||||
const Tensor& cached_filter_data_tensor =
|
||||
*cached_filter_data_ptensor_.AccessTensor(context);
|
||||
|
||||
// If filter is already cached, there's nothing to do.
|
||||
if (cached_filter_data_tensor.NumElements() > 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, cache filter
|
||||
filter.SetUsrMem(filter_md, &filter_tensor);
|
||||
filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_primitive_desc());
|
||||
filter_data = static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
|
||||
|
||||
Tensor* filter_tensor_ptr = nullptr;
|
||||
AllocatePersistentTensor(context, *conv_fwd_pd, &filter_tensor_ptr);
|
||||
void* cached_filter_data = filter.GetTensorBuffer(filter_tensor_ptr);
|
||||
size_t cached_filter_data_size =
|
||||
filter.GetOpMem().get_primitive_desc().get_size();
|
||||
memcpy(cached_filter_data, filter_data, cached_filter_data_size);
|
||||
}
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
Tfilter* GetCachedFilter(OpKernelContext* context,
|
||||
const MEMORY_DESC& filter_md)
|
||||
const memory::desc& filter_md)
|
||||
TF_LOCKS_EXCLUDED(mu_) {
|
||||
tf_shared_lock lock(mu_);
|
||||
const Tensor& cached_filter_data =
|
||||
@ -1292,15 +1141,10 @@ class MklConvOp : public OpKernel {
|
||||
const Tensor& cached_filter_md =
|
||||
*cached_filter_md_ptensor_.AccessTensor(context);
|
||||
|
||||
// Check if the memory descriptor of the cached weights is same as
|
||||
// Check if the memory descriptor of the cached weights is the same as
|
||||
// filter_md. If so, we can use the cached weights; otherwise
|
||||
// return nullptr.
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
if (filter_md == *static_cast<memory::desc*>(cached_filter_md.data())) {
|
||||
#else
|
||||
if (cached_filter_md.scalar<int32>().size() &&
|
||||
cached_filter_md.scalar<int32>()() == filter_md) {
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
return static_cast<Tfilter*>(
|
||||
const_cast<Tfilter*>(cached_filter_data.flat<Tfilter>().data()));
|
||||
}
|
||||
@ -1336,31 +1180,34 @@ class MklFusedConvOp
|
||||
errors::InvalidArgument(
|
||||
"Fused Conv2D must have one extra argument: bias."));
|
||||
} else if (fused_ops == std::vector<string>{"Relu"}) {
|
||||
this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
|
||||
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
|
||||
} else if (fused_ops == std::vector<string>{"Relu6"}) {
|
||||
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
|
||||
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
|
||||
6.0);
|
||||
} else if (fused_ops == std::vector<string>{"Elu"}) {
|
||||
this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
|
||||
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
|
||||
} else if (fused_ops == std::vector<string>{"LeakyRelu"}) {
|
||||
float leakyrelu_alpha;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
|
||||
this->set_fuse_activation(true, ALGORITHM::eltwise_relu, leakyrelu_alpha);
|
||||
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
|
||||
leakyrelu_alpha);
|
||||
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
|
||||
this->set_fuse_biasadd(true);
|
||||
this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
|
||||
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
|
||||
OP_REQUIRES(context, num_args == 1,
|
||||
errors::InvalidArgument(
|
||||
"Fused Conv2D must have one extra argument: bias."));
|
||||
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
|
||||
this->set_fuse_biasadd(true);
|
||||
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
|
||||
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
|
||||
6.0);
|
||||
OP_REQUIRES(context, num_args == 1,
|
||||
errors::InvalidArgument(
|
||||
"Fused Conv2D must have one extra argument: bias."));
|
||||
} else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
|
||||
this->set_fuse_biasadd(true);
|
||||
this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
|
||||
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
|
||||
OP_REQUIRES(context, num_args == 1,
|
||||
errors::InvalidArgument(
|
||||
"Fused Conv2D must have one extra argument: bias."));
|
||||
@ -1369,7 +1216,8 @@ class MklFusedConvOp
|
||||
float leakyrelu_alpha;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
|
||||
this->set_fuse_activation(true, ALGORITHM::eltwise_relu, leakyrelu_alpha);
|
||||
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
|
||||
leakyrelu_alpha);
|
||||
OP_REQUIRES(context, num_args == 1,
|
||||
errors::InvalidArgument(
|
||||
"Fused Conv2D must have one extra argument: bias."));
|
||||
@ -1383,7 +1231,7 @@ class MklFusedConvOp
|
||||
} else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"}) {
|
||||
this->set_fuse_biasadd(true);
|
||||
this->set_fuse_add(true);
|
||||
this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
|
||||
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
|
||||
OP_REQUIRES(
|
||||
context, num_args == 2,
|
||||
errors::InvalidArgument(
|
||||
@ -1391,7 +1239,8 @@ class MklFusedConvOp
|
||||
} else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"}) {
|
||||
this->set_fuse_biasadd(true);
|
||||
this->set_fuse_add(true);
|
||||
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
|
||||
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
|
||||
6.0);
|
||||
OP_REQUIRES(
|
||||
context, num_args == 2,
|
||||
errors::InvalidArgument(
|
||||
@ -1399,7 +1248,7 @@ class MklFusedConvOp
|
||||
} else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"}) {
|
||||
this->set_fuse_biasadd(true);
|
||||
this->set_fuse_add(true);
|
||||
this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
|
||||
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
|
||||
OP_REQUIRES(
|
||||
context, num_args == 2,
|
||||
errors::InvalidArgument(
|
||||
@ -1411,7 +1260,8 @@ class MklFusedConvOp
|
||||
float leakyrelu_alpha;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
|
||||
this->set_fuse_activation(true, ALGORITHM::eltwise_relu, leakyrelu_alpha);
|
||||
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
|
||||
leakyrelu_alpha);
|
||||
OP_REQUIRES(
|
||||
context, num_args == 2,
|
||||
errors::InvalidArgument(
|
||||
@ -1459,13 +1309,14 @@ class MklFusedDepthwiseConvOp
|
||||
this->set_fuse_biasadd(true);
|
||||
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
|
||||
this->set_fuse_biasadd(true);
|
||||
this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
|
||||
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
|
||||
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
|
||||
this->set_fuse_biasadd(true);
|
||||
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
|
||||
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
|
||||
6.0);
|
||||
} else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
|
||||
this->set_fuse_biasadd(true);
|
||||
this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
|
||||
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
|
||||
} else {
|
||||
OP_REQUIRES(context, false,
|
||||
errors::Unimplemented("Fusion is not implemented: [",
|
||||
@ -1642,8 +1493,8 @@ class MklQuantizedConv2DOp
|
||||
param_key.AddAsKey<float>(max_freezed_output);
|
||||
param_key.AddAsKey<const float*>(min_filter);
|
||||
param_key.AddAsKey<const float*>(max_filter);
|
||||
params.post_op_params.push_back(
|
||||
{"output_scale", ALGORITHM_UNDEF, scales, param_key.GetKey()});
|
||||
params.post_op_params.push_back({"output_scale", mkldnn::algorithm::undef,
|
||||
scales, param_key.GetKey()});
|
||||
}
|
||||
}
|
||||
|
||||
@ -1696,31 +1547,27 @@ class MklQuantizedConv2DOp
|
||||
bias_attr.set_output_scales(1, scales_);
|
||||
}
|
||||
|
||||
auto bias_md =
|
||||
MEMORY_PD_CONSTRUCTOR(static_cast<int>(bias_tensor.NumElements()),
|
||||
Tbias, MEMORY_FORMAT::x, this->cpu_engine_);
|
||||
auto bias_md = memory::desc({static_cast<int>(bias_tensor.NumElements())},
|
||||
MklDnnType<Tbias>(), memory::format_tag::x);
|
||||
void* bias_buf = static_cast<void*>(
|
||||
const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
|
||||
if (!input_bias_) {
|
||||
input_bias_ =
|
||||
new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_, bias_buf);
|
||||
input_bias_ = new memory(bias_md, this->cpu_engine_, bias_buf);
|
||||
} else {
|
||||
input_bias_->set_data_handle(bias_buf);
|
||||
}
|
||||
|
||||
if (!scaled_bias_buf_)
|
||||
AllocTmpBuffer<Tbias>(context, &scaled_bias_tensor_,
|
||||
GET_BIAS_DESC_FROM_OP_PD(conv_fwd_pd),
|
||||
&scaled_bias_buf_);
|
||||
conv_fwd_pd->bias_desc(), &scaled_bias_buf_);
|
||||
if (!scaled_bias_) {
|
||||
scaled_bias_ = new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_,
|
||||
scaled_bias_buf_);
|
||||
scaled_bias_ = new memory(bias_md, this->cpu_engine_, scaled_bias_buf_);
|
||||
} else {
|
||||
scaled_bias_->set_data_handle(scaled_bias_buf_);
|
||||
}
|
||||
auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
|
||||
input_bias_->GET_DESC, scaled_bias_->GET_DESC, this->cpu_engine_,
|
||||
bias_attr);
|
||||
auto reorder_desc =
|
||||
ReorderPd(this->cpu_engine_, input_bias_->get_desc(),
|
||||
this->cpu_engine_, scaled_bias_->get_desc(), bias_attr);
|
||||
CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_,
|
||||
this->cpu_engine_, context);
|
||||
|
||||
@ -1754,7 +1601,7 @@ class MklQuantizedConv2DOp
|
||||
DCHECK(bias_tensor);
|
||||
TensorShape bias_tf_shape;
|
||||
bias_tf_shape.AddDim(
|
||||
(conv_prim_desc.PRIMITIVE_DESC_BIAS.get_size() / sizeof(Tbias)));
|
||||
(conv_prim_desc.bias_desc().get_size() / sizeof(Tbias)));
|
||||
OP_REQUIRES_OK(context, context->allocate_persistent(
|
||||
DataTypeToEnum<Tbias>::value, bias_tf_shape,
|
||||
&cached_bias_data_ptensor_, bias_tensor));
|
||||
@ -1787,7 +1634,7 @@ class MklQuantizedConv2DOp
|
||||
AllocatePersistentTensor(context, *conv_fwd_pd, &bias_tensor_ptr);
|
||||
void* cached_bias_data = const_cast<void*>(
|
||||
static_cast<const void*>(bias_tensor_ptr->flat<Tbias>().data()));
|
||||
size_t cached_bias_data_size = scaled_bias->GET_DESC.get_size();
|
||||
size_t cached_bias_data_size = scaled_bias->get_desc().get_size();
|
||||
memcpy(cached_bias_data, bias_data, cached_bias_data_size);
|
||||
}
|
||||
|
||||
@ -1822,7 +1669,7 @@ class MklQuantizedConv2DReluOp
|
||||
is_depthwise>::ExtendConvFwdParams(context, params);
|
||||
|
||||
params.post_op_params.push_back(
|
||||
{"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}, ""});
|
||||
{"activation", mkldnn::algorithm::eltwise_relu, {1.0, 0.0, 0.0}, ""});
|
||||
}
|
||||
};
|
||||
|
||||
@ -1868,26 +1715,30 @@ class MklQuantizedConv2DSumReluOp
|
||||
// if summand_type is also DT_QUINT8 as the scale_output,
|
||||
// the scaling factor of 255.0f cancels each other and thus is avoided.
|
||||
// If it is not then it is DT_INT8 and is scaled appropriately.
|
||||
if (summand_type == DT_QUINT8)
|
||||
params.post_op_params.push_back(
|
||||
{"sum", ALGORITHM_UNDEF, {scale_summand / scale_output}, ""});
|
||||
else
|
||||
params.post_op_params.push_back(
|
||||
{"sum",
|
||||
ALGORITHM_UNDEF,
|
||||
{255.0f * scale_summand / (scale_output * 127.0f)},
|
||||
if (summand_type == DT_QUINT8) {
|
||||
params.post_op_params.push_back({"sum",
|
||||
mkldnn::algorithm::undef,
|
||||
{scale_summand / scale_output},
|
||||
""});
|
||||
} else {
|
||||
params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}, ""});
|
||||
params.post_op_params.push_back(
|
||||
{"sum",
|
||||
mkldnn::algorithm::undef,
|
||||
{255.0f * scale_summand / (scale_output * 127.0f)},
|
||||
""});
|
||||
}
|
||||
} else {
|
||||
params.post_op_params.push_back(
|
||||
{"sum", mkldnn::algorithm::undef, {1.0}, ""});
|
||||
}
|
||||
params.post_op_params.push_back(
|
||||
{"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}, ""});
|
||||
{"activation", mkldnn::algorithm::eltwise_relu, {1.0, 0.0, 0.0}, ""});
|
||||
}
|
||||
|
||||
void AllocateOutputTensor(OpKernelContext* context,
|
||||
const ConvFwdPd& conv_prim_desc,
|
||||
const memory::dims& output_dims_mkl_order,
|
||||
MKL_TENSOR_FORMAT output_tf_format,
|
||||
MklTensorFormat output_tf_format,
|
||||
MklDnnShape* output_mkl_shape,
|
||||
Tensor** output_tensor) override {
|
||||
int summand_idx = context->num_inputs() / 2 - 1;
|
||||
@ -1966,21 +1817,17 @@ class MklQuantizedConv2DSumReluOp
|
||||
summand_mkl_shape.IsMklTensor()
|
||||
? summand_mkl_shape.GetMklLayout()
|
||||
: memory::desc(output_dims_mkl_order, MklDnnType<Tbias>(),
|
||||
MEMORY_FORMAT::nhwc);
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
auto summand_pd = memory::primitive_desc(summand_md, this->cpu_engine_);
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
memory::format_tag::nhwc);
|
||||
void* summand_buf =
|
||||
static_cast<void*>(const_cast<Tbias*>(summand.flat<Tbias>().data()));
|
||||
void* dst_buf =
|
||||
static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
|
||||
summand_.reset(
|
||||
new MEMORY_CONSTRUCTOR(SUMMAND_MD, this->cpu_engine_, summand_buf));
|
||||
dst_.reset(new MEMORY_CONSTRUCTOR(conv_prim_desc.PRIMITIVE_DESC_DST,
|
||||
this->cpu_engine_, dst_buf));
|
||||
auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
|
||||
SUMMAND_MD, conv_prim_desc.PRIMITIVE_DESC_DST, this->cpu_engine_,
|
||||
reorder_attr);
|
||||
summand_.reset(new memory(summand_md, this->cpu_engine_, summand_buf));
|
||||
dst_.reset(
|
||||
new memory(conv_prim_desc.dst_desc(), this->cpu_engine_, dst_buf));
|
||||
auto reorder_desc =
|
||||
ReorderPd(this->cpu_engine_, summand_md, this->cpu_engine_,
|
||||
conv_prim_desc.dst_desc(), reorder_attr);
|
||||
CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_,
|
||||
context);
|
||||
}
|
||||
|
@ -42,20 +42,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/util/padding.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
using mkldnn::convolution_direct;
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
using mkldnn::convolution_forward;
|
||||
using mkldnn::prop_kind;
|
||||
using mkldnn::stream;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
#define MKLDNN_SIZE_DTYPE memory::dim
|
||||
#else
|
||||
#define MKLDNN_SIZE_DTYPE int
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
using ConvFwdDesc = mkldnn::convolution_forward::desc;
|
||||
using ConvFwdPd = mkldnn::convolution_forward::primitive_desc;
|
||||
|
@ -134,6 +134,7 @@ tf_kernel_library(
|
||||
"gpu_op_bitwise_and.cc",
|
||||
"gpu_op_bitwise_or.cc",
|
||||
"gpu_op_bitwise_xor.cc",
|
||||
"gpu_op_div.cc",
|
||||
"gpu_op_equal.cc",
|
||||
"gpu_op_floor_div.cc",
|
||||
"gpu_op_greater.cc",
|
||||
@ -146,6 +147,7 @@ tf_kernel_library(
|
||||
"gpu_op_mul.cc",
|
||||
"gpu_op_not_equal.cc",
|
||||
"gpu_op_right_shift.cc",
|
||||
"gpu_op_sub.cc",
|
||||
],
|
||||
tags = [
|
||||
"manual",
|
||||
@ -155,6 +157,7 @@ tf_kernel_library(
|
||||
":bitwise_and_kernels",
|
||||
":bitwise_or_kernels",
|
||||
":bitwise_xor_kernels",
|
||||
":div_kernels",
|
||||
":equal_kernels",
|
||||
":floor_div_kernels",
|
||||
":gpu_ops_base",
|
||||
@ -170,6 +173,7 @@ tf_kernel_library(
|
||||
":mul_kernels",
|
||||
":not_equal_kernels",
|
||||
":right_shift_kernels",
|
||||
":sub_kernels",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
@ -366,8 +370,9 @@ gen_kernel_library(
|
||||
unroll_factors = "4",
|
||||
)
|
||||
|
||||
[
|
||||
gen_kernel_library(
|
||||
name = "add_v2",
|
||||
name = name,
|
||||
tile_size = "256,1,1",
|
||||
types = [
|
||||
"f16",
|
||||
@ -378,6 +383,11 @@ gen_kernel_library(
|
||||
# TODO(b/174543802): Enable once fusion heuristics is better.
|
||||
# unroll_factors = "4",
|
||||
)
|
||||
for name in [
|
||||
"add_v2",
|
||||
"sub",
|
||||
]
|
||||
]
|
||||
|
||||
gen_kernel_library(
|
||||
name = "complex",
|
||||
@ -390,6 +400,20 @@ gen_kernel_library(
|
||||
# unroll_factors = "2",
|
||||
)
|
||||
|
||||
gen_kernel_library(
|
||||
name = "div",
|
||||
tile_size = "256,1,1",
|
||||
types = [
|
||||
"f16",
|
||||
"f32",
|
||||
"f64",
|
||||
"i16",
|
||||
"i64",
|
||||
],
|
||||
# TODO(b/174543802): Enable once fusion heuristics is better.
|
||||
# unroll_factors = "4",
|
||||
)
|
||||
|
||||
gen_kernel_library(
|
||||
name = "mul",
|
||||
tile_size = "256,1,1",
|
||||
|
@ -48,14 +48,17 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
void SetOpKernel(const std::string& op_name, const TensorShape& lhs_shape,
|
||||
const absl::InlinedVector<T, 10>& lhs_input,
|
||||
const TensorShape& rhs_shape,
|
||||
const absl::InlinedVector<T, 10>& rhs_input,
|
||||
bool use_constraint) {
|
||||
const absl::InlinedVector<T, 10>& rhs_input, bool add_t,
|
||||
bool add_tout) {
|
||||
auto builder = NodeDefBuilder("some_name", op_name)
|
||||
.Input(FakeInput(DataTypeToEnum<T>::v()))
|
||||
.Input(FakeInput(DataTypeToEnum<T>::v()));
|
||||
if (use_constraint) {
|
||||
if (add_t) {
|
||||
builder.Attr("T", DataTypeToEnum<T>::v());
|
||||
}
|
||||
if (add_tout) {
|
||||
builder.Attr("Tout", DataTypeToEnum<OutT>::v());
|
||||
}
|
||||
TF_ASSERT_OK(builder.Finalize(node_def()));
|
||||
|
||||
TF_ASSERT_OK(InitOp());
|
||||
@ -73,16 +76,20 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
const absl::InlinedVector<T, 10>& rhs_input,
|
||||
const TensorShape& expected_shape,
|
||||
const absl::InlinedVector<OutT, 10>& expected_output,
|
||||
bool use_constraint = true) {
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input,
|
||||
use_constraint);
|
||||
config.add_t, config.add_tout);
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Compare output to expectation.
|
||||
Tensor expected_tensor(allocator(), DataTypeToEnum<OutT>::value,
|
||||
expected_shape);
|
||||
test::FillValues<OutT>(&expected_tensor, expected_output);
|
||||
if (config.expect_strictly_equal) {
|
||||
test::ExpectEqual(expected_tensor, *GetOutput(0));
|
||||
} else {
|
||||
test::ExpectClose(expected_tensor, *GetOutput(0));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename OutT>
|
||||
@ -91,9 +98,9 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
const absl::InlinedVector<T, 10>& lhs_input,
|
||||
const TensorShape& rhs_shape,
|
||||
const absl::InlinedVector<T, 10>& rhs_input,
|
||||
bool use_constraint = true) {
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input,
|
||||
use_constraint);
|
||||
config.add_t, config.add_tout);
|
||||
auto status = RunOpKernel();
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
|
||||
@ -105,7 +112,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
void TestIncompatibleShapes(const std::string& op_name,
|
||||
const absl::InlinedVector<T, 10>& lhs_input,
|
||||
const absl::InlinedVector<T, 10>& rhs_input,
|
||||
bool use_constraint = true) {
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
// Prepare incompatibly shaped inputs.
|
||||
TensorShape lhs_shape{3};
|
||||
TensorShape rhs_shape{2};
|
||||
@ -115,8 +122,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements());
|
||||
|
||||
RunAndExpectInvalidArgument<T, OutT>(op_name, lhs_shape, repeated_lhs_input,
|
||||
rhs_shape, repeated_rhs_input,
|
||||
use_constraint);
|
||||
rhs_shape, repeated_rhs_input, config);
|
||||
}
|
||||
|
||||
template <typename T, typename BaselineT, typename OutT,
|
||||
@ -125,7 +131,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
const absl::InlinedVector<T, 10>& lhs_input,
|
||||
const absl::InlinedVector<T, 10>& rhs_input,
|
||||
BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
|
||||
bool use_constraint = true) {
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
// Prepare inputs.
|
||||
int input_size = shape.num_elements();
|
||||
auto repeated_lhs_input =
|
||||
@ -147,7 +153,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
|
||||
RunAndExpectResult<T, OutT>(op_name, shape, repeated_lhs_input, shape,
|
||||
repeated_rhs_input, shape, expected_output,
|
||||
use_constraint);
|
||||
config);
|
||||
}
|
||||
|
||||
template <typename T, typename BaselineT, typename OutT,
|
||||
@ -156,7 +162,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
const TensorShape& other_shape,
|
||||
const absl::InlinedVector<T, 10>& other_input,
|
||||
BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
|
||||
bool use_constraint = true) {
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
// Prepare inputs.
|
||||
TensorShape scalar_shape{};
|
||||
auto repeated_other_input =
|
||||
@ -177,7 +183,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
RunAndExpectResult<T, OutT>(op_name, scalar_shape, scalar_input_vector,
|
||||
other_shape, repeated_other_input,
|
||||
/*expected_shape=*/other_shape, expected_output,
|
||||
use_constraint);
|
||||
config);
|
||||
}
|
||||
|
||||
template <typename T, typename BaselineT, typename OutT,
|
||||
@ -187,7 +193,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
const absl::InlinedVector<T, 10>& rhs_input,
|
||||
BaselineOutT (*baseline_callback)(BaselineT,
|
||||
BaselineT),
|
||||
bool use_constraint = true) {
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
// Prepare inputs.
|
||||
TensorShape lhs_shape{1};
|
||||
TensorShape rhs_shape{6};
|
||||
@ -206,7 +212,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
|
||||
RunAndExpectResult<T, OutT>(
|
||||
op_name, lhs_shape, repeated_lhs_input, rhs_shape, repeated_rhs_input,
|
||||
/*expected_shape=*/rhs_shape, expected_output, use_constraint);
|
||||
/*expected_shape=*/rhs_shape, expected_output, config);
|
||||
}
|
||||
|
||||
template <typename T, typename BaselineT, typename OutT,
|
||||
@ -216,7 +222,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
const absl::InlinedVector<T, 10>& rhs_input,
|
||||
BaselineOutT (*baseline_callback)(BaselineT,
|
||||
BaselineT),
|
||||
bool use_constraint = true) {
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
// Prepare inputs.
|
||||
TensorShape lhs_shape{3};
|
||||
TensorShape rhs_shape{2, 3};
|
||||
@ -235,7 +241,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
|
||||
RunAndExpectResult<T, OutT>(
|
||||
op_name, lhs_shape, repeated_lhs_input, rhs_shape, repeated_rhs_input,
|
||||
/*expected_shape=*/rhs_shape, expected_output, use_constraint);
|
||||
/*expected_shape=*/rhs_shape, expected_output, config);
|
||||
}
|
||||
|
||||
template <typename T, typename BaselineT, typename OutT,
|
||||
@ -244,7 +250,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
const absl::InlinedVector<T, 10>& lhs_input,
|
||||
const absl::InlinedVector<T, 10>& rhs_input,
|
||||
BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
|
||||
bool use_constraint = true) {
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
// Prepare inputs.
|
||||
TensorShape lhs_shape{2, 1};
|
||||
TensorShape rhs_shape{3};
|
||||
@ -264,7 +270,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
|
||||
RunAndExpectResult<T, OutT>(op_name, lhs_shape, repeated_lhs_input,
|
||||
rhs_shape, repeated_rhs_input, expected_shape,
|
||||
expected_output, use_constraint);
|
||||
expected_output, config);
|
||||
}
|
||||
|
||||
template <typename T, typename BaselineT, typename OutT,
|
||||
@ -272,7 +278,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
void TestEmptyShapeBroadcasting(const std::string& op_name,
|
||||
const absl::InlinedVector<T, 10>& lhs_input,
|
||||
const absl::InlinedVector<T, 10>& rhs_input,
|
||||
bool use_constraint = true) {
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
// Prepare inputs.
|
||||
TensorShape lhs_shape{2, 0, 1};
|
||||
TensorShape rhs_shape{2, 0, 5};
|
||||
@ -284,7 +290,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
|
||||
RunAndExpectResult<T, OutT>(op_name, lhs_shape, empty_input, rhs_shape,
|
||||
empty_input, expected_shape, expected_output,
|
||||
use_constraint);
|
||||
config);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -309,60 +315,60 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
// define your own test fixtures.
|
||||
|
||||
#define GENERATE_DEFAULT_TESTS_2(op_name, test_name, T, BaselineT, OutT, \
|
||||
BaselineOutT, baseline_callback, \
|
||||
use_constraint) \
|
||||
BaselineOutT, baseline_callback, config) \
|
||||
TEST_F(GpuBinaryOpTest, op_name##EqShapes##test_name) { \
|
||||
TestEqualShapes<T, BaselineT, OutT, BaselineOutT>( \
|
||||
#op_name, /*shape=*/test::DefaultInputShape(), \
|
||||
/*lhs_input=*/test::DefaultInput<T>(#op_name), \
|
||||
/*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
|
||||
use_constraint); \
|
||||
config); \
|
||||
} \
|
||||
\
|
||||
TEST_F(GpuBinaryOpTest, op_name##OneScalar##test_name) { \
|
||||
TestOneScalar<T, BaselineT, OutT, BaselineOutT>( \
|
||||
#op_name, /*scalar_input=*/test::DefaultScalarInput<T>(), \
|
||||
#op_name, /*scalar_input=*/test::DefaultInput<T>(#op_name).front(), \
|
||||
/*other_shape=*/test::DefaultInputShape(), \
|
||||
/*other_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
|
||||
use_constraint); \
|
||||
config); \
|
||||
} \
|
||||
\
|
||||
TEST_F(GpuBinaryOpTest, op_name##IncompatibleShapes##test_name) { \
|
||||
TestIncompatibleShapes<T, OutT>( \
|
||||
#op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \
|
||||
/*rhs_input=*/test::DefaultInput<T>(#op_name), use_constraint); \
|
||||
/*rhs_input=*/test::DefaultInput<T>(#op_name), config); \
|
||||
} \
|
||||
\
|
||||
TEST_F(GpuBinaryOpTest, op_name##BroadcastingExpand##test_name) { \
|
||||
TestBroadcastingExpand<T, BaselineT, OutT, BaselineOutT>( \
|
||||
#op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \
|
||||
/*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
|
||||
use_constraint); \
|
||||
config); \
|
||||
} \
|
||||
\
|
||||
TEST_F(GpuBinaryOpTest, op_name##BroadcastingInDim##test_name) { \
|
||||
TestBroadcastingInDim<T, BaselineT, OutT, BaselineOutT>( \
|
||||
#op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \
|
||||
/*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
|
||||
use_constraint); \
|
||||
config); \
|
||||
} \
|
||||
\
|
||||
TEST_F(GpuBinaryOpTest, op_name##Broadcasting##test_name) { \
|
||||
TestBroadcasting<T, BaselineT, OutT, BaselineOutT>( \
|
||||
#op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \
|
||||
/*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
|
||||
use_constraint); \
|
||||
config); \
|
||||
} \
|
||||
\
|
||||
TEST_F(GpuBinaryOpTest, op_name##EmptyShapeBroadcasting##test_name) { \
|
||||
TestEmptyShapeBroadcasting<T, BaselineT, OutT, BaselineOutT>( \
|
||||
#op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \
|
||||
/*rhs_input=*/test::DefaultInput<T>(#op_name), use_constraint); \
|
||||
/*rhs_input=*/test::DefaultInput<T>(#op_name), config); \
|
||||
}
|
||||
|
||||
#define GENERATE_DEFAULT_TESTS(op_name, test_name, T, OutT, baseline_callback) \
|
||||
GENERATE_DEFAULT_TESTS_2(op_name, test_name, T, T, OutT, OutT, \
|
||||
baseline_callback, /*use_constraint=*/false)
|
||||
baseline_callback, \
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
#define GENERATE_DEFAULT_TESTS_SAME_INPUT_AND_OUTPUT_TYPE( \
|
||||
op_name, test_name, T, baseline_callback) \
|
||||
@ -433,37 +439,23 @@ GENERATE_DEFAULT_TESTS(BitwiseXor,
|
||||
GENERATE_DEFAULT_TESTS(BitwiseXor,
|
||||
/*test_name=*/Int64, int64, int64, baseline_bitwise_xor)
|
||||
|
||||
/// Test `tf.LeftShift`.
|
||||
|
||||
/// Test `tf.Div`.
|
||||
template <typename T>
|
||||
T baseline_left_shift(T lhs, T rhs) {
|
||||
return lhs << rhs;
|
||||
T baseline_div(T lhs, T rhs) {
|
||||
return lhs / rhs;
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int8, int8, int8,
|
||||
baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int16, int16, int16,
|
||||
baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int32, int32, int32,
|
||||
baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int64, int64, int64,
|
||||
baseline_left_shift)
|
||||
|
||||
/// Test `tf.RightShift`.
|
||||
|
||||
template <typename T>
|
||||
T baseline_right_shift(T lhs, T rhs) {
|
||||
return lhs >> rhs;
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(RightShift,
|
||||
/*test_name=*/Int8, int8, int8, baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS(RightShift,
|
||||
/*test_name=*/Int16, int16, int16, baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS(RightShift,
|
||||
/*test_name=*/Int32, int32, int32, baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS(RightShift,
|
||||
/*test_name=*/Int64, int64, int64, baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS(Div,
|
||||
/*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
baseline_div);
|
||||
GENERATE_DEFAULT_TESTS(Div,
|
||||
/*test_name=*/Float, float, float, baseline_div);
|
||||
GENERATE_DEFAULT_TESTS(Div,
|
||||
/*test_name=*/Double, double, double, baseline_div);
|
||||
GENERATE_DEFAULT_TESTS(Div,
|
||||
/*test_name=*/Int16, int16, int16, baseline_div);
|
||||
GENERATE_DEFAULT_TESTS(Div,
|
||||
/*test_name=*/Int64, int64, int64, baseline_div);
|
||||
|
||||
/// Test `tf.Equal`.
|
||||
|
||||
@ -482,27 +474,25 @@ GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int8, int8, bool, baseline_equal)
|
||||
GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int16, int16, bool, baseline_equal)
|
||||
GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int64, int64, bool, baseline_equal)
|
||||
|
||||
/// Test `tf.NotEqual`.
|
||||
/// Test `tf.FloorDiv`.
|
||||
|
||||
template <typename T>
|
||||
bool baseline_not_equal(T lhs, T rhs) {
|
||||
return lhs != rhs;
|
||||
T baseline_floor_div(T lhs, T rhs) {
|
||||
return std::floor(lhs / rhs);
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Half, Eigen::half, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Float, float, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Double, double, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Bool, bool, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int8, int8, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int16, int16, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int64, int64, bool,
|
||||
baseline_not_equal)
|
||||
template <>
|
||||
Eigen::half baseline_floor_div(Eigen::half lhs, Eigen::half rhs) {
|
||||
return static_cast<Eigen::half>(std::floor(static_cast<float>(lhs / rhs)));
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(FloorDiv,
|
||||
/*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
baseline_floor_div)
|
||||
GENERATE_DEFAULT_TESTS(FloorDiv,
|
||||
/*test_name=*/Float, float, float, baseline_floor_div)
|
||||
GENERATE_DEFAULT_TESTS(FloorDiv,
|
||||
/*test_name=*/Double, double, double, baseline_floor_div)
|
||||
|
||||
/// Test `tf.Greater`.
|
||||
|
||||
@ -544,6 +534,22 @@ GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Int16, int16, bool,
|
||||
GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Int64, int64, bool,
|
||||
baseline_greater_equal)
|
||||
|
||||
/// Test `tf.LeftShift`.
|
||||
|
||||
template <typename T>
|
||||
T baseline_left_shift(T lhs, T rhs) {
|
||||
return lhs << rhs;
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int8, int8, int8,
|
||||
baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int16, int16, int16,
|
||||
baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int32, int32, int32,
|
||||
baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int64, int64, int64,
|
||||
baseline_left_shift)
|
||||
|
||||
/// Test `tf.Less`.
|
||||
|
||||
template <typename T>
|
||||
@ -586,7 +592,7 @@ bool baseline_logical_and(bool lhs, bool rhs) { return lhs && rhs; }
|
||||
GENERATE_DEFAULT_TESTS_2(LogicalAnd, /*test_name=*/Bool, /*T=*/bool,
|
||||
/*BaselineT=*/bool, /*OutT=*/bool,
|
||||
/*BaselineOutT=*/bool, baseline_logical_and,
|
||||
/*use_constraint=*/false)
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual().NoT())
|
||||
|
||||
/// Test `tf.LogicalOr`.
|
||||
|
||||
@ -595,27 +601,77 @@ bool baseline_logical_or(bool lhs, bool rhs) { return lhs || rhs; }
|
||||
GENERATE_DEFAULT_TESTS_2(LogicalOr, /*test_name=*/Bool, /*T=*/bool,
|
||||
/*BaselineT=*/bool, /*OutT=*/bool,
|
||||
/*BaselineOutT=*/bool, baseline_logical_or,
|
||||
/*use_constraint=*/false)
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual().NoT())
|
||||
|
||||
/// Test `tf.Mul`.
|
||||
|
||||
/// Test `tf.FloorDiv`.
|
||||
template <typename T>
|
||||
T baseline_floor_div(T lhs, T rhs) {
|
||||
return std::floor(lhs / rhs);
|
||||
T baseline_mul(T lhs, T rhs) {
|
||||
return lhs * rhs;
|
||||
}
|
||||
|
||||
template <>
|
||||
Eigen::half baseline_floor_div(Eigen::half lhs, Eigen::half rhs) {
|
||||
return static_cast<Eigen::half>(std::floor(static_cast<float>(lhs / rhs)));
|
||||
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
baseline_mul)
|
||||
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Float, float, float, baseline_mul)
|
||||
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Double, double, double, baseline_mul)
|
||||
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Int8, int8, int8, baseline_mul)
|
||||
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Int16, int16, int16, baseline_mul)
|
||||
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Int64, int64, int64, baseline_mul)
|
||||
|
||||
/// Test `tf.NotEqual`.
|
||||
|
||||
template <typename T>
|
||||
bool baseline_not_equal(T lhs, T rhs) {
|
||||
return lhs != rhs;
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(FloorDiv,
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Half, Eigen::half, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Float, float, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Double, double, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Bool, bool, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int8, int8, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int16, int16, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int64, int64, bool,
|
||||
baseline_not_equal)
|
||||
|
||||
/// Test `tf.RightShift`.
|
||||
|
||||
template <typename T>
|
||||
T baseline_right_shift(T lhs, T rhs) {
|
||||
return lhs >> rhs;
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(RightShift,
|
||||
/*test_name=*/Int8, int8, int8, baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS(RightShift,
|
||||
/*test_name=*/Int16, int16, int16, baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS(RightShift,
|
||||
/*test_name=*/Int32, int32, int32, baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS(RightShift,
|
||||
/*test_name=*/Int64, int64, int64, baseline_right_shift)
|
||||
|
||||
/// Test `tf.Sub`.
|
||||
|
||||
template <typename T>
|
||||
T baseline_sub(T lhs, T rhs) {
|
||||
return lhs - rhs;
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(Sub,
|
||||
/*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
baseline_floor_div);
|
||||
GENERATE_DEFAULT_TESTS(FloorDiv,
|
||||
/*test_name=*/Float, float, float, baseline_floor_div);
|
||||
GENERATE_DEFAULT_TESTS(FloorDiv,
|
||||
/*test_name=*/Double, double, double,
|
||||
baseline_floor_div);
|
||||
baseline_sub)
|
||||
GENERATE_DEFAULT_TESTS(Sub,
|
||||
/*test_name=*/Float, float, float, baseline_sub)
|
||||
GENERATE_DEFAULT_TESTS(Sub,
|
||||
/*test_name=*/Double, double, double, baseline_sub)
|
||||
GENERATE_DEFAULT_TESTS(Sub,
|
||||
/*test_name=*/Int64, int64, int64, baseline_sub)
|
||||
|
||||
} // namespace
|
||||
} // end namespace tensorflow
27
tensorflow/core/kernels/mlir_generated/gpu_op_div.cc
Normal file
@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"

namespace tensorflow {

GENERATE_AND_REGISTER_BINARY_KERNEL(Div, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, i64, DT_INT64, int64);

} // namespace tensorflow
26
tensorflow/core/kernels/mlir_generated/gpu_op_sub.cc
Normal file
@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"

namespace tensorflow {

GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, i64, DT_INT64, int64);

} // namespace tensorflow
@ -57,6 +57,7 @@ TensorShape DefaultInputShape();
|
||||
struct GpuOpsTestConfig {
|
||||
bool add_t = true;
|
||||
bool add_tout = false;
|
||||
// Only used for gpu_unary_ops_test.
|
||||
bool expect_buffer_reuse = true;
|
||||
bool expect_strictly_equal = false;
|
||||
GpuOpsTestConfig ExpectStrictlyEqual() {
|
||||
@ -119,33 +120,10 @@ absl::InlinedVector<T, 10> DefaultInputGreaterOrEqualToZero() {
|
||||
|
||||
/// Helper functions to get default input data.
|
||||
|
||||
template <typename T,
|
||||
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
|
||||
bool> = true>
|
||||
T DefaultScalarInput() {
|
||||
return static_cast<T>(3);
|
||||
}
|
||||
|
||||
template <typename T, std::enable_if_t<
|
||||
llvm::is_one_of<T, Eigen::half, float, double>::value,
|
||||
bool> = true>
|
||||
T DefaultScalarInput() {
|
||||
return static_cast<T>(2.0);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
std::enable_if_t<llvm::is_one_of<T, bool>::value, bool> = true>
|
||||
T DefaultScalarInput() {
|
||||
return static_cast<T>(true);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
|
||||
bool> = true>
|
||||
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
|
||||
if (op_name == "Abs") {
|
||||
return NearZeroAndExtremeInput<T>();
|
||||
}
|
||||
// Only generate values less than the bitwidth of the data type.
|
||||
if (op_name == "LeftShift" || op_name == "RightShift") {
|
||||
auto max_shift = sizeof(T) * 8 - 1;
|
||||
@ -153,6 +131,9 @@ absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
|
||||
for (auto i = 0; i < max_shift; ++i) v.push_back(i);
|
||||
return v;
|
||||
}
|
||||
if (op_name == "Div") {
|
||||
return InputAsVector<T, int>({-18, -9, 9, 18});
|
||||
}
|
||||
return InputAsVector<T, int>({-18, -9, -1, 0, 0, 1, 1, 2, 3, 5, 7, 9, 9, 18});
|
||||
}
|
||||
|
||||
@ -160,16 +141,7 @@ template <typename T, std::enable_if_t<
|
||||
llvm::is_one_of<T, Eigen::half, float, double>::value,
|
||||
bool> = true>
|
||||
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
|
||||
if (op_name == "Abs") {
|
||||
return NearZeroAndExtremeInput<T>();
|
||||
}
|
||||
if (op_name == "Log" || op_name == "Rsqrt") {
|
||||
return DefaultInputGreaterThanZero<T>();
|
||||
}
|
||||
if (op_name == "Sqrt") {
|
||||
return DefaultInputGreaterOrEqualToZero<T>();
|
||||
}
|
||||
if (op_name == "FloorDiv") {
|
||||
if (op_name == "Div" || op_name == "FloorDiv") {
|
||||
return InputAsVector<T, double>({-18.0, -9.0, -1e-6, -0.1, 0.1, 1e-6, 0.1,
|
||||
0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
|
||||
}
|
||||
|
@ -125,10 +125,22 @@ class GpuUnaryOpTest : public OpsTestBase {
|
||||
// define your own test fixtures.
|
||||
|
||||
#define GENERATE_DEFAULT_TEST(op_name, InT, OutT, baseline_callback, config) \
|
||||
GENERATE_DEFAULT_TEST2(op_name, InT, InT, OutT, OutT, baseline_callback, \
|
||||
GENERATE_DEFAULT_TEST_2(op_name, InT, InT, OutT, OutT, baseline_callback, \
|
||||
config)
|
||||
|
||||
#define GENERATE_DEFAULT_TEST2(op_name, InT, BaselineT, OutT, BaselineOutT, \
|
||||
#define GENERATE_DEFAULT_TEST_2(op_name, InT, BaselineT, OutT, BaselineOutT, \
|
||||
baseline_callback, config) \
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \
|
||||
op_name, InT, BaselineT, OutT, BaselineOutT, \
|
||||
test::DefaultInput<NativeT>(#op_name), baseline_callback, config)
|
||||
|
||||
#define GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES( \
|
||||
op_name, InT, OutT, input_values, baseline_callback, config) \
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \
|
||||
op_name, InT, InT, OutT, OutT, input_values, baseline_callback, config)
|
||||
|
||||
#define GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \
|
||||
op_name, InT, BaselineT, OutT, BaselineOutT, input_values, \
|
||||
baseline_callback, config) \
|
||||
TEST_F(GpuUnaryOpTest, op_name##InT) { \
|
||||
using NativeT = EnumToDataType<InT>::Type; \
|
||||
@ -136,25 +148,31 @@ class GpuUnaryOpTest : public OpsTestBase {
|
||||
using NativeOutT = EnumToDataType<OutT>::Type; \
|
||||
using NativeBaselineOutT = EnumToDataType<BaselineOutT>::Type; \
|
||||
Test<NativeT, NativeBaselineT, NativeOutT, NativeBaselineOutT>( \
|
||||
#op_name, test::DefaultInputShape(), \
|
||||
test::DefaultInput<NativeT>(#op_name), baseline_callback, config); \
|
||||
#op_name, test::DefaultInputShape(), input_values, baseline_callback, \
|
||||
config); \
|
||||
}
|
||||
|
||||
/// Test `tf.Abs`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Abs, DT_FLOAT, DT_FLOAT, std::abs,
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Abs, DT_FLOAT, DT_FLOAT, test::NearZeroAndExtremeInput<float>(), std::abs,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Abs, DT_DOUBLE, DT_DOUBLE, std::abs,
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Abs, DT_DOUBLE, DT_DOUBLE, test::NearZeroAndExtremeInput<double>(),
|
||||
std::abs, test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
Abs, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
|
||||
test::NearZeroAndExtremeInput<Eigen::half>(), std::abs,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST2(Abs, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::abs,
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Abs, DT_INT32, DT_INT32, test::NearZeroAndExtremeInput<int32>(), std::abs,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Abs, DT_INT32, DT_INT32, std::abs,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Abs, DT_INT64, DT_INT64, std::abs,
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Abs, DT_INT64, DT_INT64, test::NearZeroAndExtremeInput<int64>(), std::abs,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
/// Test `tf.Ceil`.
|
||||
@ -165,7 +183,7 @@ GENERATE_DEFAULT_TEST(Ceil, DT_FLOAT, DT_FLOAT, std::ceil,
|
||||
GENERATE_DEFAULT_TEST(Ceil, DT_DOUBLE, DT_DOUBLE, std::ceil,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST2(Ceil, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::ceil,
|
||||
GENERATE_DEFAULT_TEST_2(Ceil, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::ceil,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
/// Test `tf.Conj`.
|
||||
@ -189,7 +207,7 @@ GENERATE_DEFAULT_TEST(Cos, DT_FLOAT, DT_FLOAT, std::cos,
|
||||
GENERATE_DEFAULT_TEST(Cos, DT_DOUBLE, DT_DOUBLE, std::cos,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST2(Cos, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::cos,
|
||||
GENERATE_DEFAULT_TEST_2(Cos, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::cos,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
/// Test `tf.Exp`.
|
||||
@ -200,7 +218,7 @@ GENERATE_DEFAULT_TEST(Exp, DT_FLOAT, DT_FLOAT, std::exp,
|
||||
GENERATE_DEFAULT_TEST(Exp, DT_DOUBLE, DT_DOUBLE, std::exp,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST2(Exp, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::exp,
|
||||
GENERATE_DEFAULT_TEST_2(Exp, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::exp,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
/// Test `tf.Floor`.
|
||||
@ -211,7 +229,7 @@ GENERATE_DEFAULT_TEST(Floor, DT_FLOAT, DT_FLOAT, std::floor,
|
||||
GENERATE_DEFAULT_TEST(Floor, DT_DOUBLE, DT_DOUBLE, std::floor,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST2(Floor, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::floor,
|
||||
GENERATE_DEFAULT_TEST_2(Floor, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::floor,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
/// Test `tf.Imag`.
|
||||
@ -260,13 +278,17 @@ TEST_F(GpuUnaryOpTest, DISABLED_IsInfHalf) {
|
||||
|
||||
/// Test `tf.Log`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Log, DT_FLOAT, DT_FLOAT, std::log,
|
||||
test::GpuOpsTestConfig())
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Log, DT_FLOAT, DT_FLOAT, test::DefaultInputGreaterThanZero<float>(),
|
||||
std::log, test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Log, DT_DOUBLE, DT_DOUBLE, std::log,
|
||||
test::GpuOpsTestConfig())
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Log, DT_DOUBLE, DT_DOUBLE, test::DefaultInputGreaterThanZero<double>(),
|
||||
std::log, test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST2(Log, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::log,
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
Log, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
|
||||
test::DefaultInputGreaterThanZero<Eigen::half>(), std::log,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
/// Test `tf.LogicalNot`
|
||||
@ -290,7 +312,7 @@ GENERATE_DEFAULT_TEST(Neg, DT_FLOAT, DT_FLOAT, baseline_neg,
|
||||
GENERATE_DEFAULT_TEST(Neg, DT_DOUBLE, DT_DOUBLE, baseline_neg,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST2(Neg, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, baseline_neg,
|
||||
GENERATE_DEFAULT_TEST_2(Neg, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, baseline_neg,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Neg, DT_INT8, DT_INT8, baseline_neg,
|
||||
@ -323,15 +345,19 @@ T baseline_rsqrt(T x) {
|
||||
return 1.0 / std::sqrt(x);
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TEST(Rsqrt, DT_FLOAT, DT_FLOAT, baseline_rsqrt,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Rsqrt, DT_DOUBLE, DT_DOUBLE, baseline_rsqrt,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST2(Rsqrt, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Rsqrt, DT_FLOAT, DT_FLOAT, test::DefaultInputGreaterThanZero<float>(),
|
||||
baseline_rsqrt, test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Rsqrt, DT_DOUBLE, DT_DOUBLE, test::DefaultInputGreaterThanZero<double>(),
|
||||
baseline_rsqrt, test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
Rsqrt, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
|
||||
test::DefaultInputGreaterThanZero<Eigen::half>(), baseline_rsqrt,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
/// Test `tf.Sign`.
|
||||
|
||||
// Reference implementation
|
||||
@ -350,7 +376,7 @@ GENERATE_DEFAULT_TEST(Sign, DT_DOUBLE, DT_DOUBLE, baseline_sign,
|
||||
|
||||
// TODO(b/162577610): We should actually use ExpectStrictlyEqual()
|
||||
// here. This requires returning 0.0 for input -0.0.
|
||||
GENERATE_DEFAULT_TEST2(Sign, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
|
||||
GENERATE_DEFAULT_TEST_2(Sign, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
|
||||
baseline_sign, test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Sign, DT_INT64, DT_INT64, baseline_sign,
|
||||
@ -364,18 +390,23 @@ GENERATE_DEFAULT_TEST(Sin, DT_FLOAT, DT_FLOAT, std::sin,
|
||||
GENERATE_DEFAULT_TEST(Sin, DT_DOUBLE, DT_DOUBLE, std::sin,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST2(Sin, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::sin,
|
||||
GENERATE_DEFAULT_TEST_2(Sin, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::sin,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
/// Test `tf.Sqrt`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Sqrt, DT_FLOAT, DT_FLOAT, std::sqrt,
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Sqrt, DT_FLOAT, DT_FLOAT, test::DefaultInputGreaterOrEqualToZero<float>(),
|
||||
std::sqrt, test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Sqrt, DT_DOUBLE, DT_DOUBLE,
|
||||
test::DefaultInputGreaterOrEqualToZero<double>(), std::sqrt,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Sqrt, DT_DOUBLE, DT_DOUBLE, std::sqrt,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST2(Sqrt, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::sqrt,
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
Sqrt, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
|
||||
test::DefaultInputGreaterOrEqualToZero<Eigen::half>(), std::sqrt,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
/// Test `tf.Tanh`.
|
||||
@ -386,7 +417,7 @@ GENERATE_DEFAULT_TEST(Tanh, DT_FLOAT, DT_FLOAT, std::tanh,
|
||||
GENERATE_DEFAULT_TEST(Tanh, DT_DOUBLE, DT_DOUBLE, std::tanh,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST2(Tanh, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::tanh,
|
||||
GENERATE_DEFAULT_TEST_2(Tanh, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::tanh,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
} // namespace
|
||||
|
@ -0,0 +1,6 @@
func @Div_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
  %0 = "tf.Div"(%arg0, %arg1) {T = elem_type, device = ""}
    : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
  return %0 : tensor<*xelem_type>
}
@ -0,0 +1,6 @@
func @Sub_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
  %0 = "tf.Sub"(%arg0, %arg1) {T = elem_type, device = ""}
    : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
  return %0 : tensor<*xelem_type>
}
@ -137,4 +137,14 @@ REGISTER_KERNEL_BUILDER(
|
||||
Name("QuantizedMaxPool").Device(DEVICE_CPU).TypeConstraint<quint8>("T"),
|
||||
QuantizedMaxPoolingOp<CPUDevice, quint8>);
|
||||
|
||||
#ifdef INTEL_MKL
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("QuantizedAvgPool").Device(DEVICE_CPU).TypeConstraint<qint8>("T"),
|
||||
QuantizedAvgPoolingOp<CPUDevice, qint8>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("QuantizedMaxPool").Device(DEVICE_CPU).TypeConstraint<qint8>("T"),
|
||||
QuantizedMaxPoolingOp<CPUDevice, qint8>);
|
||||
#endif
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -232,11 +232,7 @@ REGISTER_OP("_FusedBatchNormEx")
|
||||
.Output("reserve_space_1: U")
|
||||
.Output("reserve_space_2: U")
|
||||
.Output("reserve_space_3: U")
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
.Attr("T: {half, float, bfloat16}")
|
||||
#else
|
||||
.Attr("T: {half, float}")
|
||||
#endif
|
||||
.Attr("U: {float}")
|
||||
.Attr("epsilon: float = 0.0001")
|
||||
.Attr("exponential_avg_factor: float = 1.0")
|
||||
|
@ -981,6 +981,17 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "enable_tf2_utils",
|
||||
srcs = ["enable_tf2_utils.cc"],
|
||||
hdrs = ["enable_tf2_utils.h"],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
"//tensorflow/core/util:env_var",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "profile_utils_cpu_utils",
|
||||
actual = "//tensorflow/core/platform/profile_utils:profile_utils_cpu_utils",
|
||||
@ -992,6 +1003,12 @@ filegroup(
|
||||
compatible_with = get_compatible_with_portable(),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "enable_tf2_hdr",
|
||||
srcs = ["enable_tf2_utils.h"],
|
||||
compatible_with = get_compatible_with_portable(),
|
||||
)
|
||||
|
||||
tf_cc_tests(
|
||||
name = "low_level_library_tests",
|
||||
size = "small",
|
||||
@ -1047,6 +1064,20 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "enable_tf2_utils_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"enable_tf2_utils_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":enable_tf2_utils",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/util:env_var",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_tests(
|
||||
name = "stacktrace_handler_test",
|
||||
size = "small",
|
||||
|
@ -149,9 +149,7 @@ class PosixEnv : public Env {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
|
||||
return false;
|
||||
#else
|
||||
#if defined(__GLIBC__) || defined(__FreeBSD__)
|
||||
char buf[100];
|
||||
#ifdef __FreeBSD__
|
||||
int res = 0;
|
||||
@ -164,6 +162,8 @@ class PosixEnv : public Env {
|
||||
}
|
||||
*name = buf;
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
49  tensorflow/core/platform/enable_tf2_utils.cc  Normal file
@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/platform/enable_tf2_utils.h"

#include <atomic>

#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

enum Enablement : uint8 { kFalse = 0, kTrue = 1, undefined = 2 };

// If this flag is set, we will use it as a signal to decide on whether to
// use the MLIR based TF-XLA bridge.
static std::atomic<Enablement> tf2_enabled{undefined};

// Determine whether or not the user has explicitly asked for tf2 execution.
// Will be used to determine whether to use the MLIR based bridge.
void set_tf2_execution(bool enabled) {
  tf2_enabled = (enabled) ? Enablement::kTrue : Enablement::kFalse;
}

bool tf2_execution_enabled() {
  if (tf2_enabled == Enablement::undefined) {
    static bool tf2_behavior_env_enabled = [] {
      string tf2_env;
      TF_CHECK_OK(ReadStringFromEnvVar("TF2_BEHAVIOR", "0", &tf2_env));
      return tf2_env != "0";
    }();
    tf2_enabled =
        (tf2_behavior_env_enabled) ? Enablement::kTrue : Enablement::kFalse;
  }
  return tf2_enabled;
}

} // namespace tensorflow
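A minimal usage sketch of the two new entry points (the function names and header come from this file; the caller and the bridge decision are illustrative only):

    // Hypothetical call site, for illustration only.
    #include "tensorflow/core/platform/enable_tf2_utils.h"

    void ConfigureRuntime(bool user_requested_tf2) {
      // Record the user's choice; tf2_execution_enabled() falls back to the
      // TF2_BEHAVIOR environment variable if set_tf2_execution() was never
      // called.
      tensorflow::set_tf2_execution(user_requested_tf2);
      if (tensorflow::tf2_execution_enabled()) {
        // ... choose the MLIR-based TF-XLA bridge path (illustrative) ...
      }
    }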
31  tensorflow/core/platform/enable_tf2_utils.h  Normal file
@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TF_CORE_PLATFORM_TF2_UTILS_H_
#define TF_CORE_PLATFORM_TF2_UTILS_H_

namespace tensorflow {

// Sets the tf2 execution state. This can be used to indicate whether the user
// has explicitly asked for tf2 execution.
void set_tf2_execution(bool enabled);

// Returns true or false depending on whether the user flag for tf2 execution
// has been set. The default is false.
bool tf2_execution_enabled();

} // namespace tensorflow

#endif // TF_CORE_PLATFORM_TF2_UTILS_H_
35  tensorflow/core/platform/enable_tf2_utils_test.cc  Normal file
@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Testing TF2 enablement.

#include "tensorflow/core/platform/enable_tf2_utils.h"

#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

TEST(TF2EnabledTest, enabled_behavior) {
  string tf2_env;
  TF_CHECK_OK(ReadStringFromEnvVar("TF2_BEHAVIOR", "0", &tf2_env));
  bool expected = (tf2_env != "0");
  EXPECT_EQ(tensorflow::tf2_execution_enabled(), expected);
  tensorflow::set_tf2_execution(true);
  EXPECT_TRUE(tensorflow::tf2_execution_enabled());
  tensorflow::set_tf2_execution(false);
  EXPECT_FALSE(tensorflow::tf2_execution_enabled());
}

} // namespace tensorflow
@ -108,7 +108,7 @@ limitations under the License.

#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
#define TF_GRAPH_DEF_VERSION 634 // Updated: 2021/1/2
#define TF_GRAPH_DEF_VERSION 637 // Updated: 2021/1/5

// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
//
@ -386,6 +386,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
|
||||
return ParseSoftmax(op, error_reporter, allocator, builtin_data);
|
||||
}
|
||||
|
||||
case BuiltinOperator_SPACE_TO_DEPTH: {
|
||||
return ParseSpaceToDepth(op, error_reporter, allocator, builtin_data);
|
||||
}
|
||||
|
||||
case BuiltinOperator_SPLIT: {
|
||||
return ParseSplit(op, error_reporter, allocator, builtin_data);
|
||||
}
|
||||
@ -596,16 +600,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
|
||||
*builtin_data = params.release();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
case BuiltinOperator_SPACE_TO_DEPTH: {
|
||||
auto params = safe_allocator.Allocate<TfLiteSpaceToDepthParams>();
|
||||
TF_LITE_ENSURE(error_reporter, params != nullptr);
|
||||
if (const auto* schema_params =
|
||||
op->builtin_options_as_SpaceToDepthOptions()) {
|
||||
params->block_size = schema_params->block_size();
|
||||
}
|
||||
*builtin_data = params.release();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
case BuiltinOperator_GATHER: {
|
||||
auto params = safe_allocator.Allocate<TfLiteGatherParams>();
|
||||
@ -1684,6 +1678,31 @@ TfLiteStatus ParseSoftmax(const Operator* op, ErrorReporter* error_reporter,
  return kTfLiteOk;
}

TfLiteStatus ParseSpaceToDepth(const Operator* op,
                               ErrorReporter* error_reporter,
                               BuiltinDataAllocator* allocator,
                               void** builtin_data) {
  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);

  SafeBuiltinDataAllocator safe_allocator(allocator);
  std::unique_ptr<TfLiteSpaceToDepthParams,
                  SafeBuiltinDataAllocator::BuiltinDataDeleter>
      params = safe_allocator.Allocate<TfLiteSpaceToDepthParams>();
  TF_LITE_ENSURE(error_reporter, params != nullptr);

  const auto* schema_params = op->builtin_options_as_SpaceToDepthOptions();
  if (schema_params != nullptr) {
    params->block_size = schema_params->block_size();
  } else {
    // TODO(b/157480169): We should either return kTfLiteError or fill in some
    // reasonable defaults in the params struct. We are not doing so until we
    // better understand the ramifications of changing the legacy behavior.
  }

  *builtin_data = params.release();
  return kTfLiteOk;
}
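For context on the single attribute parsed here: SPACE_TO_DEPTH folds each block_size x block_size spatial block into the channel dimension. A small illustrative example (shapes invented):

    // Illustrative only: shape effect of SPACE_TO_DEPTH with block_size = 2.
    //   input  [1, 4, 4, 3]  ->  output [1, 2, 2, 12]
    // In general: [N, H, W, C] -> [N, H / block_size, W / block_size,
    //                              C * block_size * block_size].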
|
||||
|
||||
TfLiteStatus ParseSplit(const Operator* op, ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data) {
|
||||
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
|
||||
|
@ -254,6 +254,11 @@ TfLiteStatus ParseSin(const Operator* op, ErrorReporter* error_reporter,
TfLiteStatus ParseSoftmax(const Operator* op, ErrorReporter* error_reporter,
                          BuiltinDataAllocator* allocator, void** builtin_data);

TfLiteStatus ParseSpaceToDepth(const Operator* op,
                               ErrorReporter* error_reporter,
                               BuiltinDataAllocator* allocator,
                               void** builtin_data);

TfLiteStatus ParseSplit(const Operator* op, ErrorReporter* error_reporter,
                        BuiltinDataAllocator* allocator, void** builtin_data);

@ -111,11 +111,15 @@ class Subgraph {
  inline TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantization quantization,
      bool is_variable = false, const size_t rank_dims_signature = 0,
      const int* dims_signature = nullptr) {
      bool is_variable = false, const std::vector<int>& dims_signature = {}) {
    if (dims_signature.empty()) {
      return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
                                          dims.data(), quantization, is_variable,
                                          rank_dims_signature, dims_signature);
                                          dims.data(), quantization,
                                          is_variable);
    }
    return SetTensorParametersReadWrite(
        tensor_index, type, name, dims.size(), dims.data(), quantization,
        is_variable, dims_signature.size(), dims_signature.data());
  }
  TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
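A short usage sketch of the new std::vector-based dims_signature overload above (the tensor index, name, and shapes are invented for illustration):

    // Hypothetical call: a 2-D float tensor whose first dimension is dynamic.
    // The runtime shape is {1, 4}; the signature {-1, 4} records that dim 0 is
    // unknown at conversion time. Omitting dims_signature keeps the old
    // behavior.
    subgraph.SetTensorParametersReadWrite(
        /*tensor_index=*/0, kTfLiteFloat32, "input",
        /*dims=*/{1, 4}, TfLiteQuantization(),
        /*is_variable=*/false, /*dims_signature=*/{-1, 4});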
@ -45,6 +45,56 @@ int3 GetWorkGroupsCount(int grid_dimension, const int3& grid_size,
|
||||
}
|
||||
return work_groups_count;
|
||||
}
|
||||
|
||||
std::string GetCommonOpenCLDefines(CalculationsPrecision precision) {
|
||||
std::string result;
|
||||
|
||||
result += "#define GLOBAL_ID_0 get_global_id(0)\n";
|
||||
result += "#define GLOBAL_ID_1 get_global_id(1)\n";
|
||||
result += "#define GLOBAL_ID_2 get_global_id(2)\n";
|
||||
result += "#define MAIN_FUNCTION __kernel void main_function\n";
|
||||
switch (precision) {
|
||||
case CalculationsPrecision::F32:
|
||||
result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
|
||||
result += "#define ACCUM_FLT4 float4\n";
|
||||
result += "#define FLT float\n";
|
||||
result += "#define FLT2 float2\n";
|
||||
result += "#define FLT3 float3\n";
|
||||
result += "#define FLT4 float4\n";
|
||||
result += "#define TO_FLT4 convert_float4\n";
|
||||
result += "#define TO_ACCUM_TYPE convert_float4\n";
|
||||
result += "#define TO_ACCUM_FLT convert_float\n";
|
||||
result += "#define INIT_FLT4(value) (float4)(value)\n";
|
||||
break;
|
||||
case CalculationsPrecision::F16:
|
||||
result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
|
||||
result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
|
||||
result += "#define ACCUM_FLT4 half4\n";
|
||||
result += "#define FLT half\n";
|
||||
result += "#define FLT2 half2\n";
|
||||
result += "#define FLT3 half3\n";
|
||||
result += "#define FLT4 half4\n";
|
||||
result += "#define TO_FLT4 convert_half4\n";
|
||||
result += "#define TO_ACCUM_TYPE convert_half4\n";
|
||||
result += "#define TO_ACCUM_FLT convert_half\n";
|
||||
result += "#define INIT_FLT4(value) (half4)(value)\n";
|
||||
break;
|
||||
case CalculationsPrecision::F32_F16:
|
||||
result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
|
||||
result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
|
||||
result += "#define ACCUM_FLT4 float4\n";
|
||||
result += "#define FLT half\n";
|
||||
result += "#define FLT2 half2\n";
|
||||
result += "#define FLT3 half3\n";
|
||||
result += "#define FLT4 half4\n";
|
||||
result += "#define TO_FLT4 convert_half4\n";
|
||||
result += "#define TO_ACCUM_TYPE convert_float4\n";
|
||||
result += "#define TO_ACCUM_FLT convert_float\n";
|
||||
result += "#define INIT_FLT4(value) (half4)(value)\n";
|
||||
break;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
ClOperation::ClOperation(ClOperation&& operation)
|
||||
@ -94,6 +144,9 @@ absl::Status ClOperation::UpdateParams() {
|
||||
|
||||
absl::Status ClOperation::Compile(const CreationContext& creation_context) {
|
||||
operation_->AssembleCode(creation_context.GetGpuInfo());
|
||||
operation_->code_ =
|
||||
GetCommonOpenCLDefines(operation_->definition_.precision) +
|
||||
operation_->code_;
|
||||
RETURN_IF_ERROR(cl_args_.Init(
|
||||
creation_context.GetGpuInfo(),
|
||||
{{operation_->dst_tensors_names_[0], operation_->elementwise_code_}},
|
||||
|
@ -336,6 +336,8 @@ absl::Status Tensor::GetGPUResources(const GPUObjectDescriptor* obj_ptr,
|
||||
if (!tensor_desc) {
|
||||
return absl::InvalidArgumentError("Expected TensorDescriptor on input.");
|
||||
}
|
||||
resources->ints.push_back(
|
||||
{"slice_stride", tensor_desc->GetSliceStrideSize(shape_)});
|
||||
if (descriptor_.HasAxis(Axis::WIDTH)) {
|
||||
resources->ints.push_back({"width", Width()});
|
||||
resources->ints.push_back({"width_div2", Width() / 2});
|
||||
|
@ -22,61 +22,18 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace {
|
||||
std::string GetCommonDefines(CalculationsPrecision precision) {
|
||||
std::string result;
|
||||
|
||||
switch (precision) {
|
||||
case CalculationsPrecision::F32:
|
||||
result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
|
||||
result += "#define ACCUM_FLT4 float4\n";
|
||||
result += "#define FLT float\n";
|
||||
result += "#define FLT2 float2\n";
|
||||
result += "#define FLT3 float3\n";
|
||||
result += "#define FLT4 float4\n";
|
||||
result += "#define TO_FLT4 convert_float4\n";
|
||||
result += "#define TO_ACCUM_TYPE convert_float4\n";
|
||||
result += "#define TO_ACCUM_FLT convert_float\n";
|
||||
break;
|
||||
case CalculationsPrecision::F16:
|
||||
result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
|
||||
result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
|
||||
result += "#define ACCUM_FLT4 half4\n";
|
||||
result += "#define FLT half\n";
|
||||
result += "#define FLT2 half2\n";
|
||||
result += "#define FLT3 half3\n";
|
||||
result += "#define FLT4 half4\n";
|
||||
result += "#define TO_FLT4 convert_half4\n";
|
||||
result += "#define TO_ACCUM_TYPE convert_half4\n";
|
||||
result += "#define TO_ACCUM_FLT convert_half\n";
|
||||
break;
|
||||
case CalculationsPrecision::F32_F16:
|
||||
result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
|
||||
result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
|
||||
result += "#define ACCUM_FLT4 float4\n";
|
||||
result += "#define FLT half\n";
|
||||
result += "#define FLT2 half2\n";
|
||||
result += "#define FLT3 half3\n";
|
||||
result += "#define FLT4 half4\n";
|
||||
result += "#define TO_FLT4 convert_half4\n";
|
||||
result += "#define TO_ACCUM_TYPE convert_float4\n";
|
||||
result += "#define TO_ACCUM_FLT convert_float\n";
|
||||
break;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string GetElementWiseCode(const OperationDef& op_def,
|
||||
bool check_src_slices) {
|
||||
std::string c;
|
||||
c += "__kernel void main_function(\n";
|
||||
c += "MAIN_FUNCTION(\n";
|
||||
c += "$0) {\n";
|
||||
c += " int X = get_global_id(0);\n";
|
||||
c += " int Y = get_global_id(1);\n";
|
||||
c += " int Z = get_global_id(2);\n";
|
||||
c += " int X = GLOBAL_ID_0;\n";
|
||||
c += " int Y = GLOBAL_ID_1;\n";
|
||||
c += " int Z = GLOBAL_ID_2;\n";
|
||||
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
|
||||
"Z >= args.dst_tensor.Slices()) return; \n";
|
||||
if (check_src_slices) {
|
||||
c += " FLT4 src = (FLT4)(0.0f);\n";
|
||||
c += " FLT4 src = INIT_FLT4(0.0f);\n";
|
||||
c += " if (Z < args.src_tensor.Slices()) {\n";
|
||||
c += " src = args.src_tensor.Read(X, Y, Z);\n";
|
||||
c += " }\n";
|
||||
@ -240,7 +197,6 @@ void GPUOperation::AssembleCode(const GpuInfo& gpu_info) {
|
||||
elementwise_code_ = "{\n" + code_ + "\n}\n" + elementwise_code_;
|
||||
code_ = GetElementWiseCode(definition_, check_src_channels_size_);
|
||||
}
|
||||
code_ = GetCommonDefines(definition_.precision) + code_;
|
||||
}
|
||||
|
||||
void GPUOperation::GetPossibleKernelWorkGroups(
|
||||
|
@ -94,6 +94,7 @@ TensorDescriptor& TensorDescriptor::operator=(TensorDescriptor&& desc) {
|
||||
|
||||
GPUResources TensorDescriptor::GetGPUResources() const {
|
||||
GPUResources resources;
|
||||
resources.ints.push_back("slice_stride");
|
||||
if (HasAxis(Axis::WIDTH)) {
|
||||
resources.ints.push_back("width");
|
||||
resources.ints.push_back("width_div2");
|
||||
@ -175,7 +176,7 @@ absl::Status TensorDescriptor::PerformSelector(
|
||||
*result = "slices";
|
||||
return absl::OkStatus();
|
||||
} else if (selector == "SliceStride") {
|
||||
*result = GetSliceStride();
|
||||
*result = "slice_stride";
|
||||
return absl::OkStatus();
|
||||
} else if (selector == "Channels") {
|
||||
*result = "channels";
|
||||
@ -402,7 +403,7 @@ absl::Status TensorDescriptor::PerformGetPtrWithSliceOffsetSelector(
|
||||
"GetPtrWithSliceOffset require one argument(slice coordinate), but ",
|
||||
args.size(), " was passed"));
|
||||
}
|
||||
*result = absl::StrCat("buffer + ", args[0], " * ", GetSliceStride());
|
||||
*result = absl::StrCat("buffer + ", args[0], " * slice_stride");
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
@ -644,6 +645,35 @@ bool TensorDescriptor::HasAxis(Axis axis) const {
  return false;
}

int TensorDescriptor::GetWidthSize(BHWDC shape) const {
  int width = shape.w;
  auto it1 = state_vars_.find("ElementsX2");
  if (it1 != state_vars_.end() && it1->second == "true") {
    width /= 2;
  }
  auto it2 = state_vars_.find("ElementsX4");
  if (it2 != state_vars_.end() && it2->second == "true") {
    width /= 4;
  }
  auto it = state_vars_.find("BatchedWidth");
  if (it != state_vars_.end() && it->second == "true") {
    width *= shape.b;
  }
  return width;
}

int TensorDescriptor::GetSliceStrideSize(BHWDC shape) const {
  if (IsBatchedWidth()) {
    return GetWidthSize(shape) * shape.h;
  } else {
    if (HasAxis(Axis::BATCH)) {
      return GetWidthSize(shape) * shape.h * shape.b;
    } else {
      return GetWidthSize(shape) * shape.h;
    }
  }
}
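A small worked example of the element counts computed above (numbers invented for illustration):

    // Illustrative only, for a BHWDC shape {b=2, h=3, w=5, d=1, c=8} with no
    // ElementsX2/ElementsX4 packing:
    //   - no batch folded in:             GetWidthSize = 5,  slice stride = 5 * 3     = 15
    //   - BATCH axis, width not batched:  GetWidthSize = 5,  slice stride = 5 * 3 * 2 = 30
    //   - BatchedWidth == "true":         GetWidthSize = 10, slice stride = 10 * 3    = 30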
|
||||
|
||||
void TensorDescriptor::SetAddressMode(AddressMode mode) {
|
||||
if (mode == AddressMode::kZero) {
|
||||
state_vars_["TextureMode"] = "ZERO";
|
||||
@ -719,18 +749,6 @@ std::string TensorDescriptor::GetWidth() const {
|
||||
}
|
||||
}
|
||||
|
||||
std::string TensorDescriptor::GetSliceStride() const {
|
||||
if (IsBatchedWidth()) {
|
||||
return GetWidth() + " * height";
|
||||
} else {
|
||||
if (HasAxis(Axis::BATCH)) {
|
||||
return GetWidth() + " * height * batch";
|
||||
} else {
|
||||
return GetWidth() + " * height";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AddressMode TensorDescriptor::AddressModeFromState() const {
|
||||
auto it = state_vars_.find("TextureMode");
|
||||
if (it != state_vars_.end()) {
|
||||
|
@ -70,6 +70,8 @@ struct TensorDescriptor : public GPUObjectDescriptor {
|
||||
|
||||
bool HasAxis(Axis axis) const;
|
||||
void SetAddressMode(AddressMode mode);
|
||||
int GetWidthSize(BHWDC shape) const;
|
||||
int GetSliceStrideSize(BHWDC shape) const;
|
||||
|
||||
absl::Status GetLinkingContextFromWriteSelector(
|
||||
const std::vector<std::string>& args, std::string* value_name,
|
||||
@ -136,7 +138,6 @@ struct TensorDescriptor : public GPUObjectDescriptor {
|
||||
bool IsBatchedWidth() const;
|
||||
|
||||
std::string GetWidth() const;
|
||||
std::string GetSliceStride() const;
|
||||
|
||||
AddressMode AddressModeFromState() const;
|
||||
|
||||
|
@ -677,6 +677,7 @@ cc_library(
|
||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||
"//tensorflow/lite/delegates/gpu/common:types",
|
||||
"//tensorflow/lite/delegates/gpu/common:util",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:util",
|
||||
"//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
|
||||
],
|
||||
)
|
||||
|
@ -32,57 +32,54 @@ namespace gpu {
|
||||
namespace metal {
|
||||
|
||||
std::string GetResizeBilinearCode(const Resize2DAttributes& attr) {
|
||||
std::string code = R"(
|
||||
std::string c = R"(
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
$0
|
||||
kernel void ComputeFunction(
|
||||
$1
|
||||
uint3 gid[[thread_position_in_grid]]) {
|
||||
if (int(gid.x) >= size.z || int(gid.y) >= size.w) {
|
||||
if (int(gid.x) >= args.dst_tensor.Width() || int(gid.y) >= args.dst_tensor.Height()) {
|
||||
return;
|
||||
})";
|
||||
if (attr.half_pixel_centers) {
|
||||
code += "const float2 tex_coord = (float2(gid.xy) + 0.5f) * scale - 0.5f;";
|
||||
} else {
|
||||
code += "const float2 tex_coord = float2(gid.xy) * scale;";
|
||||
}
|
||||
code += R"(
|
||||
const float2 tex_coord_floor = floor(tex_coord);
|
||||
const int2 itex_coord_floor = int2(tex_coord_floor);
|
||||
const int2 borders = size.xy - int2(1, 1);
|
||||
)";
|
||||
if (attr.half_pixel_centers) {
|
||||
c += " float2 tex_coord = (float2(gid.xy) + 0.5f) * scale - 0.5f;";
|
||||
} else {
|
||||
c += " float2 tex_coord = float2(gid.xy) * scale;";
|
||||
}
|
||||
c += R"(
|
||||
float2 tex_coord_floor = floor(tex_coord);
|
||||
int2 itex_coord_floor = int2(tex_coord_floor);
|
||||
int2 borders = int2(args.src_tensor.Width() - 1, args.src_tensor.Height() - 1);
|
||||
int4 st;
|
||||
st.xy = max(itex_coord_floor, int2(0, 0));
|
||||
st.zw = min(itex_coord_floor + int2(1, 1), borders);
|
||||
const float2 t = tex_coord - tex_coord_floor; // interpolating factors
|
||||
const int src_index0 = (gid.z * size.y + st.y) * size.x + st.x;
|
||||
const int src_index1 = (gid.z * size.y + st.y) * size.x + st.z;
|
||||
const int src_index2 = (gid.z * size.y + st.w) * size.x + st.x;
|
||||
const int src_index3 = (gid.z * size.y + st.w) * size.x + st.z;
|
||||
FLT4 tex11 = src_tensor[src_index0];
|
||||
FLT4 tex21 = src_tensor[src_index1];
|
||||
FLT4 tex12 = src_tensor[src_index2];
|
||||
FLT4 tex22 = src_tensor[src_index3];
|
||||
float2 t = tex_coord - tex_coord_floor; // interpolating factors
|
||||
FLT4 tex11 = args.src_tensor.Read(st.x, st.y, gid.z);
|
||||
FLT4 tex21 = args.src_tensor.Read(st.z, st.y, gid.z);
|
||||
FLT4 tex12 = args.src_tensor.Read(st.x, st.w, gid.z);
|
||||
FLT4 tex22 = args.src_tensor.Read(st.z, st.w, gid.z);
|
||||
// bilinear interpolation
|
||||
FLT4 value = mix(mix(tex11, tex21, static_cast<FLT>(t.x)),
|
||||
mix(tex12, tex22, static_cast<FLT>(t.x)), static_cast<FLT>(t.y));
|
||||
const int linear_index = (gid.z * size.w + gid.y) * size.z + gid.x;
|
||||
args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, gid.z);
|
||||
$2
|
||||
dst_tensor[linear_index] = value;
|
||||
args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
|
||||
}
|
||||
)";
|
||||
return code;
|
||||
return c;
|
||||
}
|
||||
|
||||
std::string GetResizeNearestCode(const Resize2DAttributes& attr) {
|
||||
std::string code = R"(
|
||||
std::string c = R"(
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
$0
|
||||
kernel void ComputeFunction(
|
||||
$1
|
||||
uint3 gid[[thread_position_in_grid]]) {
|
||||
if (int(gid.x) >= size.z || int(gid.y) >= size.w) {
|
||||
if (int(gid.x) >= args.dst_tensor.Width() || int(gid.y) >= args.dst_tensor.Height()) {
|
||||
return;
|
||||
}
|
||||
)";
|
||||
@ -99,27 +96,27 @@ std::string GetResizeNearestCode(const Resize2DAttributes& attr) {
|
||||
fxc += " + 0.5f";
|
||||
fyc += " + 0.5f";
|
||||
}
|
||||
code += " int2 coord;\n";
|
||||
code += " coord.x = static_cast<int>(" + fxc + ");\n";
|
||||
code += " coord.y = static_cast<int>(" + fyc + ");\n";
|
||||
code += " coord.x = max(0, coord.x);\n";
|
||||
code += " coord.y = max(0, coord.y);\n";
|
||||
code += " coord.x = min(coord.x, size.x - 1);\n";
|
||||
code += " coord.y = min(coord.y, size.y - 1);\n";
|
||||
code += R"(
|
||||
const int src_index = (gid.z * size.y + coord.y) * size.x + coord.x;
|
||||
FLT4 value = src_tensor[src_index];
|
||||
const int linear_index = (gid.z * size.w + gid.y) * size.z + gid.x;
|
||||
c += " int2 coord;\n";
|
||||
c += " coord.x = static_cast<int>(" + fxc + ");\n";
|
||||
c += " coord.y = static_cast<int>(" + fyc + ");\n";
|
||||
c += " coord.x = max(0, coord.x);\n";
|
||||
c += " coord.y = max(0, coord.y);\n";
|
||||
c += " coord.x = min(coord.x, args.src_tensor.Width() - 1);\n";
|
||||
c += " coord.y = min(coord.y, args.src_tensor.Height() - 1);\n";
|
||||
c += R"(
|
||||
FLT4 value = args.src_tensor.Read(coord.x, coord.y, gid.z);
|
||||
args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, gid.z);
|
||||
$2
|
||||
dst_tensor[linear_index] = value;
|
||||
args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
|
||||
}
|
||||
)";
|
||||
return code;
|
||||
return c;
|
||||
}
|
||||
|
||||
ComputeTaskDescriptor Resize(const OperationDef& definition,
|
||||
const Resize2DAttributes& attr) {
|
||||
ComputeTaskDescriptor desc(definition);
|
||||
desc.tensors_as_args = true;
|
||||
switch (attr.type) {
|
||||
case SamplingType::BILINEAR:
|
||||
desc.shader_source = GetResizeBilinearCode(attr);
|
||||
@ -136,17 +133,6 @@ ComputeTaskDescriptor Resize(const OperationDef& definition,
|
||||
desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
|
||||
|
||||
desc.uniform_buffers = {
|
||||
{"constant int4& size",
|
||||
[](const std::vector<BHWC>& src_shapes,
|
||||
const std::vector<BHWC>& dst_shapes) {
|
||||
std::vector<int> sizes = {
|
||||
src_shapes[0].w,
|
||||
src_shapes[0].h,
|
||||
dst_shapes[0].w,
|
||||
dst_shapes[0].h,
|
||||
};
|
||||
return GetByteBuffer(sizes);
|
||||
}},
|
||||
{"constant float2& scale",
|
||||
[attr](const std::vector<BHWC>& src_shapes,
|
||||
const std::vector<BHWC>& dst_shapes) {
|
||||
@ -160,10 +146,10 @@ ComputeTaskDescriptor Resize(const OperationDef& definition,
|
||||
|
||||
desc.resize_function = [](const std::vector<BHWC>& src_shapes,
|
||||
const std::vector<BHWC>& dst_shapes) {
|
||||
const uint3 groups_size{16, 16, 1};
|
||||
const uint3 groups_size{8, 8, 1};
|
||||
const int dst_layers = DivideRoundUp(dst_shapes[0].c, 4);
|
||||
int groups_x = DivideRoundUp(dst_shapes[0].w, groups_size.x);
|
||||
int groups_y = DivideRoundUp(dst_shapes[0].h, groups_size.y);
|
||||
const int dst_layers = DivideRoundUp(dst_shapes[0].c, 4);
|
||||
int groups_z = DivideRoundUp(dst_layers, groups_size.z);
|
||||
return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
|
||||
};
|
||||
|
@ -35,123 +35,146 @@ namespace gpu {
|
||||
namespace metal {
|
||||
namespace {
|
||||
|
||||
std::string GetSliceCode(const SliceAttributes& attr) {
|
||||
std::stringstream code;
|
||||
namespace {
|
||||
bool Is4Aligned(const SliceAttributes& attr) {
|
||||
return attr.strides.c == 1 && attr.starts.c % 4 == 0;
|
||||
}
|
||||
|
||||
code << R"(
|
||||
int4 GetOffset(const SliceAttributes& attr, int src_width, int src_height,
|
||||
int src_channels, int src_batch) {
|
||||
int4 offset;
|
||||
if (attr.strides.w > 0) {
|
||||
offset.x = attr.starts.w;
|
||||
} else {
|
||||
if (attr.ends.w > 0) {
|
||||
offset.x = attr.ends.w;
|
||||
} else {
|
||||
offset.x = src_width + attr.ends.w;
|
||||
}
|
||||
}
|
||||
if (attr.strides.h > 0) {
|
||||
offset.y = attr.starts.h;
|
||||
} else {
|
||||
if (attr.ends.h > 0) {
|
||||
offset.y = attr.ends.h;
|
||||
} else {
|
||||
offset.y = src_height + attr.ends.h;
|
||||
}
|
||||
}
|
||||
if (attr.strides.c > 0) {
|
||||
offset.z = attr.starts.c;
|
||||
} else {
|
||||
if (attr.ends.c > 0) {
|
||||
offset.z = attr.ends.c;
|
||||
} else {
|
||||
offset.z = src_channels + attr.ends.c;
|
||||
}
|
||||
}
|
||||
if (Is4Aligned(attr)) {
|
||||
offset.z /= 4;
|
||||
}
|
||||
if (attr.strides.b > 0) {
|
||||
offset.w = attr.starts.b;
|
||||
} else {
|
||||
if (attr.ends.b > 0) {
|
||||
offset.w = attr.ends.b;
|
||||
} else {
|
||||
offset.w = src_batch + attr.ends.b;
|
||||
}
|
||||
}
|
||||
return offset;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::string GetSliceCode(const OperationDef& op_def, bool alignedx4) {
|
||||
const std::string batch_id =
|
||||
op_def.dst_tensors[0].HasAxis(Axis::BATCH) ? "B" : "0";
|
||||
std::string c = R"(
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
struct uniforms {
|
||||
int4 src_size;
|
||||
int4 dst_size;
|
||||
int4 offset;
|
||||
int4 stride;
|
||||
};
|
||||
|
||||
constant int4 width = int4($0, $1, $2, 0);
|
||||
constant int4 height = int4($3, $4, $5, 0);
|
||||
constant int4 channels = int4($6, $7, $8, 0);
|
||||
constant FLT4 null_vec = FLT4(0.0f, 0.0f, 0.0f, 0.0f);
|
||||
|
||||
$$0
|
||||
kernel void ComputeFunction(
|
||||
$$1
|
||||
$0
|
||||
kernel void ComputeFunction($1
|
||||
uint3 gid[[thread_position_in_grid]]) {
|
||||
if (static_cast<int>(gid.x) >= params.dst_size.x ||
|
||||
static_cast<int>(gid.y) >= params.dst_size.y) {
|
||||
return;
|
||||
}
|
||||
|
||||
FLT4 value;
|
||||
short2 offset;
|
||||
)";
|
||||
if (attr.strides.w > 0) {
|
||||
code << " offset.x = width.x;" << std::endl;
|
||||
if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
|
||||
c += " int linear_id = static_cast<int>(gid.x);\n";
|
||||
c += " int X = linear_id / args.dst_tensor.Batch();\n";
|
||||
c += " int B = linear_id % args.dst_tensor.Batch();\n";
|
||||
c += " args.dst_tensor.SetBatchRef(B);\n";
|
||||
} else {
|
||||
if (attr.ends.w > 0) {
|
||||
code << " offset.x = width.z;" << std::endl;
|
||||
c += " int X = static_cast<int>(gid.x);\n";
|
||||
}
|
||||
c += " int Y = static_cast<int>(gid.y);\n";
|
||||
c += " int Z = static_cast<int>(gid.z);\n";
|
||||
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
|
||||
"Z >= args.dst_tensor.Slices()) { \n";
|
||||
c += " return; \n";
|
||||
c += " } \n";
|
||||
c += " int s_x = X * params.stride.x + params.offset.x;\n";
|
||||
c += " int s_y = Y * params.stride.y + params.offset.y;\n";
|
||||
if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) {
|
||||
c += " int s_b = " + batch_id + " * params.stride.w + params.offset.w;\n";
|
||||
c += " args.src_tensor.SetBatchRef(s_b);\n";
|
||||
}
|
||||
if (alignedx4) {
|
||||
c += " int s_z = Z + params.offset.z;\n";
|
||||
c += " FLT4 result = args.src_tensor.Read(s_x, s_y, s_z);\n";
|
||||
} else {
|
||||
code << " offset.x = params.src_size.x + width.z;" << std::endl;
|
||||
c += " FLT4 result;\n";
|
||||
const std::string postfixes[] = {"x", "y", "z", "w"};
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
c += " {\n";
|
||||
const std::string ch = "(Z * 4 + " + std::to_string(i) + ")";
|
||||
c += " int s_ch = " + ch + " * params.stride.z + params.offset.z;\n";
|
||||
c += " int s_z = min(s_ch >> 2, args.src_tensor.Slices() - 1);\n";
|
||||
c += " int s_z_rem = s_ch & 3;\n";
|
||||
c += " FLT4 t = args.src_tensor.Read(s_x, s_y, s_z);\n";
|
||||
c += " result." + postfixes[i] + " = t[s_ch & 3];\n";
|
||||
c += " }\n";
|
||||
}
|
||||
}
|
||||
if (attr.strides.h > 0) {
|
||||
code << " offset.y = height.x;" << std::endl;
|
||||
} else {
|
||||
if (attr.ends.h > 0) {
|
||||
code << " offset.y = height.z;" << std::endl;
|
||||
} else {
|
||||
code << " offset.y = params.src_size.y + height.z;" << std::endl;
|
||||
}
|
||||
}
|
||||
code << std::endl;
|
||||
code << " short2 stride = short2(width.y, height.y);" << std::endl;
|
||||
|
||||
code << " const short2 s_c = offset + short2(gid.xy) * stride;"
|
||||
<< std::endl;
|
||||
code << " bool outside = false;" << std::endl;
|
||||
code << " int step = gid.z * 4;" << std::endl;
|
||||
code << " FLT4 tmp;" << std::endl;
|
||||
code << " int buffer_index = 0;" << std::endl;
|
||||
code << " int addr = 0;" << std::endl;
|
||||
code << std::endl;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
code << " addr = step * channels.y;" << std::endl;
|
||||
if (attr.strides.c > 0) {
|
||||
code << " addr += channels.x;" << std::endl;
|
||||
} else {
|
||||
if (attr.ends.c > 0) {
|
||||
code << " addr += channels.z;" << std::endl;
|
||||
} else {
|
||||
code << " addr += params.src_size.z + channels.z;" << std::endl;
|
||||
}
|
||||
}
|
||||
code << " buffer_index = ((addr / 4) * params.src_size.y + s_c.y) * "
|
||||
"params.src_size.x + "
|
||||
"s_c.x;"
|
||||
<< std::endl;
|
||||
code << " outside = step >= params.dst_size.z;" << std::endl;
|
||||
code << " tmp = outside ? null_vec : src_tensor[buffer_index];"
|
||||
<< std::endl;
|
||||
code << " value[" << i << "] = tmp[addr % 4];" << std::endl;
|
||||
if (i != 3) {
|
||||
code << " step++;" << std::endl;
|
||||
code << std::endl;
|
||||
}
|
||||
}
|
||||
code << R"(
|
||||
int linear_index = (gid.z * params.dst_size.y + int(gid.y)) *
|
||||
params.dst_size.x + int(gid.x);
|
||||
$$2
|
||||
dst_tensor[linear_index] = value;
|
||||
})";
|
||||
return absl::Substitute(
|
||||
code.str(), attr.starts.w, attr.strides.w, attr.ends.w, attr.starts.h,
|
||||
attr.strides.h, attr.ends.h, attr.starts.c, attr.strides.c, attr.ends.c);
|
||||
c += " FLT4 value = result;\n";
|
||||
c += " args.dst_tensor.GetAddress(linear_index, X, Y, Z);\n";
|
||||
c += " $2\n";
|
||||
c += " args.dst_tensor.Write(value, X, Y, Z);\n";
|
||||
c += "}\n";
|
||||
return c;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
ComputeTaskDescriptor Slice(const OperationDef& definition,
|
||||
const SliceAttributes& attr) {
|
||||
ComputeTaskDescriptor desc(definition);
|
||||
desc.shader_source = GetSliceCode(attr);
|
||||
desc.tensors_as_args = true;
|
||||
desc.shader_source = GetSliceCode(definition, Is4Aligned(attr));
|
||||
|
||||
desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
|
||||
desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
|
||||
|
||||
desc.uniform_buffers = {
|
||||
{"constant uniforms& params",
|
||||
[](const std::vector<BHWC>& src_shapes,
|
||||
[attr](const std::vector<BHWC>& src_shapes,
|
||||
const std::vector<BHWC>& dst_shapes) {
|
||||
int4 offset = GetOffset(attr, src_shapes[0].w, src_shapes[0].h,
|
||||
src_shapes[0].c, src_shapes[0].b);
|
||||
std::vector<int> uniform_params{
|
||||
// int4 src_size
|
||||
src_shapes[0].w,
|
||||
src_shapes[0].h,
|
||||
src_shapes[0].c,
|
||||
DivideRoundUp(src_shapes[0].c, 4),
|
||||
// int4 dst_size
|
||||
dst_shapes[0].w,
|
||||
dst_shapes[0].h,
|
||||
dst_shapes[0].c,
|
||||
DivideRoundUp(dst_shapes[0].c, 4),
|
||||
// int4 offset
|
||||
offset.x,
|
||||
offset.y,
|
||||
offset.z,
|
||||
offset.w,
|
||||
// int4 stride
|
||||
attr.strides.w,
|
||||
attr.strides.h,
|
||||
attr.strides.c,
|
||||
attr.strides.b,
|
||||
};
|
||||
return GetByteBuffer(uniform_params);
|
||||
}},
|
||||
@ -159,10 +182,10 @@ ComputeTaskDescriptor Slice(const OperationDef& definition,
|
||||
|
||||
desc.resize_function = [attr](const std::vector<BHWC>& src_shapes,
|
||||
const std::vector<BHWC>& dst_shapes) {
|
||||
const uint3 groups_size{16, 16, 1};
|
||||
const uint3 groups_size{8, 4, 1};
|
||||
const int dst_layers = DivideRoundUp(dst_shapes[0].c, 4);
|
||||
int groups_x = DivideRoundUp(dst_shapes[0].w, groups_size.x);
|
||||
int groups_y = DivideRoundUp(dst_shapes[0].h, groups_size.y);
|
||||
const int dst_layers = DivideRoundUp(dst_shapes[0].c, 4);
|
||||
int groups_z = DivideRoundUp(dst_layers, groups_size.z);
|
||||
return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
|
||||
};
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
||||
@ -40,7 +41,6 @@ std::string GetSoftmax1x1Code(const GpuInfo& gpu_info) {
|
||||
using namespace metal;
|
||||
|
||||
struct uniforms {
|
||||
int4 size;
|
||||
float4 mask;
|
||||
};
|
||||
|
||||
@ -51,11 +51,11 @@ kernel void ComputeFunction($1
|
||||
uint3 ugid[[thread_position_in_grid]])
|
||||
{
|
||||
|
||||
float4 maxx4 = float4(src_tensor[0].x);
|
||||
for (int s = int(tid); s < params.size.x; s += 32) {
|
||||
float4 mask_a = s == params.size.x - 1 ? params.mask : float4(1.0f);
|
||||
float4 maxx4 = float4(args.src_tensor.Read(0, 0, 0).x);
|
||||
for (int s = int(tid); s < args.src_tensor.Slices(); s += 32) {
|
||||
float4 mask_a = s == args.src_tensor.Slices() - 1 ? params.mask : float4(1.0f);
|
||||
float4 mask_b = float4(1.0f) - mask_a;
|
||||
float4 src = float4(src_tensor[s]);
|
||||
float4 src = float4(args.src_tensor.Read(0, 0, s));
|
||||
src = src * mask_a + mask_b * src.x;
|
||||
maxx4 = max(maxx4, src);
|
||||
}
|
||||
@ -89,9 +89,9 @@ kernel void ComputeFunction($1
|
||||
maximum = tmpx1[0];
|
||||
|
||||
float sum = 0.0f;
|
||||
for (int s = int(tid); s < params.size.x; s += 32) {
|
||||
float4 mask_temp = s == params.size.x - 1 ? params.mask : float4(1.0f);
|
||||
float4 src = float4(src_tensor[s]) - float4(maximum);
|
||||
for (int s = int(tid); s < args.src_tensor.Slices(); s += 32) {
|
||||
float4 mask_temp = s == args.src_tensor.Slices() - 1 ? params.mask : float4(1.0f);
|
||||
float4 src = float4(args.src_tensor.Read(0, 0, s)) - float4(maximum);
|
||||
sum += dot(mask_temp, exp(src));
|
||||
}
|
||||
|
||||
@ -120,13 +120,13 @@ kernel void ComputeFunction($1
|
||||
sum = tmpx1[0];
|
||||
|
||||
int dst_s = int(ugid.x);
|
||||
if (dst_s < params.size.x) {
|
||||
int linear_index = dst_s;
|
||||
float4 src = float4(src_tensor[linear_index]) - float4(maximum);
|
||||
if (dst_s < args.src_tensor.Slices()) {
|
||||
float4 src = float4(args.src_tensor.Read(0, 0, dst_s)) - float4(maximum);
|
||||
FLT4 value = FLT4(exp(src) * sum);
|
||||
uint3 gid = uint3(0, 0, linear_index);
|
||||
uint3 gid = uint3(0, 0, dst_s);
|
||||
args.dst_tensor.GetAddress(linear_index, 0, 0, dst_s);
|
||||
$2
|
||||
dst_tensor[linear_index] = value;
|
||||
args.dst_tensor.Write(value, 0, 0, dst_s);
|
||||
}
|
||||
})";
|
||||
return code;
|
||||
@ -135,28 +135,27 @@ kernel void ComputeFunction($1
|
||||
|
||||
ComputeTaskDescriptor Softmax(const OperationDef& definition) {
|
||||
ComputeTaskDescriptor desc(definition);
|
||||
desc.tensors_as_args = true;
|
||||
desc.shader_source = R"(
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
struct uniforms {
|
||||
int4 size;
|
||||
float4 mask;
|
||||
};
|
||||
$0
|
||||
kernel void ComputeFunction(
|
||||
$1
|
||||
uint3 gid[[thread_position_in_grid]]) {
|
||||
if (int(gid.x) >= params.size.x || int(gid.y) >= params.size.y) {
|
||||
if (int(gid.x) >= args.dst_tensor.Width() || int(gid.y) >= args.dst_tensor.Height()) {
|
||||
return;
|
||||
}
|
||||
|
||||
float maximum = src_tensor[gid.y * params.size.x + gid.x].x;
|
||||
for (int d = 0; d < params.size.z; ++d) {
|
||||
int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
|
||||
float4 mask_a = d == params.size.z - 1 ? params.mask : float4(1.0f);
|
||||
float maximum = args.src_tensor.Read(gid.x, gid.y, 0).x;
|
||||
for (int d = 0; d < args.dst_tensor.Slices(); ++d) {
|
||||
float4 mask_a = d == args.dst_tensor.Slices() - 1 ? params.mask : float4(1.0f);
|
||||
float4 mask_b = float4(1.0f) - mask_a;
|
||||
float4 src = float4(src_tensor[buffer_index]);
|
||||
float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d));
|
||||
src = src * mask_a + mask_b * src.x;
|
||||
maximum = max(maximum, src.x);
|
||||
maximum = max(maximum, src.y);
|
||||
@ -165,19 +164,18 @@ kernel void ComputeFunction(
|
||||
}
|
||||
|
||||
float sum = 0.0f;
|
||||
for (int d = 0; d < params.size.z; ++d) {
|
||||
int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
|
||||
float4 mask_temp = d == params.size.z - 1 ? params.mask : float4(1.0f);
|
||||
float4 src = float4(src_tensor[buffer_index]) - float4(maximum);
|
||||
for (int d = 0; d < args.dst_tensor.Slices(); ++d) {
|
||||
float4 mask_temp = d == args.dst_tensor.Slices() - 1 ? params.mask : float4(1.0f);
|
||||
float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d)) - float4(maximum);
|
||||
sum += dot(mask_temp, exp(src));
|
||||
}
|
||||
|
||||
for (int d = 0; d < params.size.z; ++d) {
|
||||
const int linear_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
|
||||
float4 src = float4(src_tensor[linear_index]) - float4(maximum);
|
||||
for (int d = 0; d < args.dst_tensor.Slices(); ++d) {
|
||||
float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d)) - float4(maximum);
|
||||
FLT4 value = FLT4(exp(src) / sum);
|
||||
args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, d);
|
||||
$2
|
||||
dst_tensor[linear_index] = value;
|
||||
args.dst_tensor.Write(value, gid.x, gid.y, d);
|
||||
}
|
||||
}
|
||||
)";
|
||||
@ -189,20 +187,9 @@ kernel void ComputeFunction(
|
||||
{"constant uniforms& params",
|
||||
[](const std::vector<BHWC>& src_shapes,
|
||||
const std::vector<BHWC>& dst_shapes) {
|
||||
const int dst_depth = DivideRoundUp(dst_shapes[0].c, 4);
|
||||
struct uniforms {
|
||||
int4 size;
|
||||
float4 mask;
|
||||
};
|
||||
uniforms params;
|
||||
params.size = {dst_shapes[0].w, dst_shapes[0].h, dst_depth, 1};
|
||||
params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||
int reminder = dst_shapes[0].c % 4 == 0 ? 4 : dst_shapes[0].c % 4;
|
||||
for (int i = 0; i < reminder; ++i) {
|
||||
params.mask[i] = 1.0f;
|
||||
}
|
||||
const uint8_t* ptr = reinterpret_cast<const uint8_t*>(¶ms);
|
||||
return std::vector<uint8_t>(ptr, ptr + sizeof(uniforms));
|
||||
float4 mask = GetMaskForLastPlane(dst_shapes[0].c);
|
||||
const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&mask);
|
||||
return std::vector<uint8_t>(ptr, ptr + sizeof(float4));
|
||||
}},
|
||||
};
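The remainder-based mask construction removed above is collapsed into a single `GetMaskForLastPlane(dst_shapes[0].c)` call. As a point of reference only, here is a minimal Python sketch of the mask this helper is assumed to produce, mirroring the removed remainder loop; `mask_for_last_plane` is a hypothetical name, not part of the change:

```python
# Hypothetical sketch of the channel mask assumed to be returned by
# GetMaskForLastPlane: 1.0 for the valid channels in the last 4-wide slice,
# 0.0 for the padded lanes (same logic as the removed remainder loop).
def mask_for_last_plane(channels: int):
    remainder = 4 if channels % 4 == 0 else channels % 4
    return [1.0 if i < remainder else 0.0 for i in range(4)]

assert mask_for_last_plane(6) == [1.0, 1.0, 0.0, 0.0]
assert mask_for_last_plane(8) == [1.0, 1.0, 1.0, 1.0]
```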
|
||||
|
||||
@ -220,6 +207,7 @@ kernel void ComputeFunction(
|
||||
ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
|
||||
const GpuInfo& gpu_info) {
|
||||
ComputeTaskDescriptor desc(definition);
|
||||
desc.tensors_as_args = true;
|
||||
desc.shader_source = GetSoftmax1x1Code(gpu_info);
|
||||
|
||||
desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
|
||||
@ -229,20 +217,9 @@ ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
|
||||
{"constant uniforms& params",
|
||||
[](const std::vector<BHWC>& src_shapes,
|
||||
const std::vector<BHWC>& dst_shapes) {
|
||||
const int src_depth = DivideRoundUp(dst_shapes[0].c, 4);
|
||||
struct uniforms {
|
||||
int4 size;
|
||||
float4 mask;
|
||||
};
|
||||
uniforms params;
|
||||
params.size = {src_depth, DivideRoundUp(src_depth, 32), 1, 1};
|
||||
params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||
int reminder = dst_shapes[0].c % 4 == 0 ? 4 : dst_shapes[0].c % 4;
|
||||
for (int i = 0; i < reminder; ++i) {
|
||||
params.mask[i] = 1.0f;
|
||||
}
|
||||
const uint8_t* ptr = reinterpret_cast<const uint8_t*>(¶ms);
|
||||
return std::vector<uint8_t>(ptr, ptr + sizeof(uniforms));
|
||||
float4 mask = GetMaskForLastPlane(dst_shapes[0].c);
|
||||
const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&mask);
|
||||
return std::vector<uint8_t>(ptr, ptr + sizeof(float4));
|
||||
}},
|
||||
};
|
||||
|
||||
|
@ -124,6 +124,8 @@ absl::Status MetalSpatialTensor::GetGPUResources(
|
||||
if (!tensor_desc) {
|
||||
return absl::InvalidArgumentError("Expected TensorDescriptor on input.");
|
||||
}
|
||||
resources->ints.push_back(
|
||||
{"slice_stride", tensor_desc->GetSliceStrideSize(shape_)});
|
||||
if (descriptor_.HasAxis(Axis::WIDTH)) {
|
||||
resources->ints.push_back({"width", Width()});
|
||||
resources->ints.push_back({"width_div2", Width() / 2});
|
||||
|
@ -310,7 +310,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading data from https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media\u0026token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8\n",
"Downloading data from https://dl.fbaipublicfiles.com/glue/data/SST-2.zip\n",
"7446528/7439277 [==============================] - 0s 0us/step\n"
]
}
@ -318,7 +318,7 @@
"source": [
"data_dir = tf.keras.utils.get_file(\n",
" fname='SST-2.zip',\n",
" origin='https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media\u0026token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',\n",
" origin='https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',\n",
" extract=True)\n",
"data_dir = os.path.join(os.path.dirname(data_dir), 'SST-2')"
]
|
||||
|
@ -587,11 +587,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
|
||||
status = kTfLiteError;
|
||||
}
|
||||
|
||||
size_t dims_signature_rank = 0;
|
||||
const int* dims_signature_data = nullptr;
|
||||
std::vector<int> dims_signature = {};
|
||||
if (tensor->shape_signature()) {
|
||||
dims_signature_rank = tensor->shape_signature()->size();
|
||||
dims_signature_data = tensor->shape_signature()->data();
|
||||
dims_signature = FlatBufferIntArrayToVector(tensor->shape_signature());
|
||||
}
|
||||
|
||||
bool is_variable = tensor->is_variable();
|
||||
@ -623,7 +621,7 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
|
||||
} else {
|
||||
if (subgraph->SetTensorParametersReadWrite(
|
||||
i, type, get_name(tensor), dims, quantization, is_variable,
|
||||
dims_signature_rank, dims_signature_data) != kTfLiteOk) {
|
||||
dims_signature) != kTfLiteOk) {
|
||||
error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
|
||||
i);
|
||||
status = kTfLiteError;
|
||||
|
@ -80,6 +80,9 @@ TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
|
||||
case kTfLiteInt32:
|
||||
copyCast(in, out->data.i32, num_elements);
|
||||
break;
|
||||
case kTfLiteInt16:
|
||||
copyCast(in, out->data.i16, num_elements);
|
||||
break;
|
||||
case kTfLiteUInt8:
|
||||
copyCast(in, out->data.uint8, num_elements);
|
||||
break;
|
||||
@ -113,6 +116,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return copyToTensor(context, input->data.i64, output, num_elements);
|
||||
case kTfLiteInt32:
|
||||
return copyToTensor(context, input->data.i32, output, num_elements);
|
||||
case kTfLiteInt16:
|
||||
return copyToTensor(context, input->data.i16, output, num_elements);
|
||||
case kTfLiteUInt8:
|
||||
return copyToTensor(context, input->data.uint8, output, num_elements);
|
||||
case kTfLiteFloat32:
|
||||
|
@ -46,6 +46,22 @@ class CastOpModel : public SingleOpModel {
|
||||
int output_;
|
||||
};
|
||||
|
||||
TEST(CastOpModel, CastInt16ToFloat) {
|
||||
CastOpModel m({TensorType_INT16, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
|
||||
m.PopulateTensor<int16_t>(m.input(), {100, 200, 300, 400, 500, 600});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.ExtractVector<float>(m.output()),
|
||||
ElementsAreArray({100.f, 200.f, 300.f, 400.f, 500.f, 600.f}));
|
||||
}
|
||||
|
||||
TEST(CastOpModel, CastInt16ToInt32) {
|
||||
CastOpModel m({TensorType_INT16, {2, 3}}, {TensorType_INT32, {2, 3}});
|
||||
m.PopulateTensor<int16_t>(m.input(), {100, 200, 300, 400, 500, 600});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.ExtractVector<int32_t>(m.output()),
|
||||
ElementsAreArray({100, 200, 300, 400, 500, 600}));
|
||||
}
|
||||
|
||||
TEST(CastOpModel, CastInt32ToFloat) {
|
||||
CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
|
||||
m.PopulateTensor<int32_t>(m.input(), {100, 200, 300, 400, 500, 600});
|
||||
@ -62,6 +78,14 @@ TEST(CastOpModel, CastFloatToInt32) {
|
||||
ElementsAreArray({100, 20, 3, 0, 0, 1}));
|
||||
}
|
||||
|
||||
TEST(CastOpModel, CastFloatToInt16) {
|
||||
CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT16, {3, 2}});
|
||||
m.PopulateTensor<float>(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.ExtractVector<int16_t>(m.output()),
|
||||
ElementsAreArray({100, 20, 3, 0, 0, 1}));
|
||||
}
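As a quick cross-check of the expectation in the new test above, a hedged TensorFlow-Python sketch of the same float-to-int16 cast (truncation toward zero) would be:

```python
import tensorflow as tf

# Sketch of the behavior CastFloatToInt16 exercises: casting float32 to int16
# truncates toward zero, so 0.4 -> 0, 0.999 -> 0 and 1.1 -> 1, matching the
# expected {100, 20, 3, 0, 0, 1} output in the test.
x = tf.constant([100.0, 20.0, 3.0, 0.4, 0.999, 1.1], dtype=tf.float32)
y = tf.cast(x, tf.int16)
assert y.numpy().tolist() == [100, 20, 3, 0, 0, 1]
```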
|
||||
|
||||
TEST(CastOpModel, CastInt64ToFloat) {
|
||||
CastOpModel m({TensorType_INT64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
|
||||
m.PopulateTensor<int64_t>(m.input(), {100, 200, 300, 400, 500, 600});
|
||||
|
@ -121,7 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
// TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should
|
||||
// allocate and populate these during Prepare().
|
||||
// TODO(ycling): Activation function parameter is ignored. For now we dont have
|
||||
// TODO(ycling): Activation function parameter is ignored. For now we don't have
|
||||
// a model with a Concatenation with fused activation function.
|
||||
#define TF_LITE_CONCATENATION(scalar) \
|
||||
{ \
|
||||
|
@ -494,6 +494,7 @@ cc_library(
|
||||
"reference/resize_nearest_neighbor.h",
|
||||
"reference/round.h",
|
||||
"reference/softmax.h",
|
||||
"reference/space_to_depth.h",
|
||||
"reference/strided_slice.h",
|
||||
"reference/sub.h",
|
||||
"reference/svdf.h",
|
||||
@ -511,13 +512,14 @@ cc_library(
|
||||
}),
|
||||
compatible_with = get_compatible_with_portable(),
|
||||
copts = tflite_copts(),
|
||||
# We are disabling parse_headers for the tf_lite_static_memory build to
|
||||
# allow it to be consistent with the OSS bazel build. See b/175817116
|
||||
# for more details.
|
||||
features = select({
|
||||
":tf_lite_static_memory": ["-parse_headers"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
# We are disabling parse_headers for this header-only target so that the
|
||||
# external and internal builds are consistent. The primary issue here is
|
||||
# that parse_headers is not supported with bazel and the TFLM team would
|
||||
# really like to have all build errors in shared Micro/Lite code be
|
||||
# reproducible from the OSS build as well.
|
||||
#
|
||||
# See b/175817116 for more details.
|
||||
features = ["-parse_headers"],
|
||||
deps = [
|
||||
":common",
|
||||
":compatibility",
|
||||
@ -588,6 +590,7 @@ cc_library(
|
||||
"reference/resize_nearest_neighbor.h",
|
||||
"reference/round.h",
|
||||
"reference/softmax.h",
|
||||
"reference/space_to_depth.h",
|
||||
"reference/strided_slice.h",
|
||||
"reference/string_comparisons.h",
|
||||
"reference/sub.h",
|
||||
|
@ -15,16 +15,13 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV_H_
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace reference_ops {
|
||||
|
||||
|
||||
inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
|
||||
const float* input_data, const RuntimeShape& filter_shape,
|
||||
const float* filter_data, const RuntimeShape& bias_shape,
|
||||
|
@ -61,6 +61,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/round.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/space_to_depth.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/string_comparisons.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/sub.h"
|
||||
@ -126,58 +127,6 @@ inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
|
||||
const RuntimeShape& unextended_input_shape,
|
||||
const T* input_data,
|
||||
const RuntimeShape& unextended_output_shape,
|
||||
T* output_data) {
|
||||
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
|
||||
const RuntimeShape input_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_input_shape);
|
||||
const RuntimeShape output_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_output_shape);
|
||||
|
||||
const int input_depth = input_shape.Dims(3);
|
||||
const int input_width = input_shape.Dims(2);
|
||||
const int input_height = input_shape.Dims(1);
|
||||
const int input_batch = input_shape.Dims(0);
|
||||
|
||||
const int output_depth = output_shape.Dims(3);
|
||||
const int output_width = output_shape.Dims(2);
|
||||
const int output_height = output_shape.Dims(1);
|
||||
const int output_batch = output_shape.Dims(0);
|
||||
|
||||
const int32 block_size = op_params.block_size;
|
||||
|
||||
TFLITE_DCHECK_EQ(input_width, output_width * block_size);
|
||||
TFLITE_DCHECK_EQ(input_height, output_height * block_size);
|
||||
TFLITE_DCHECK_EQ(input_depth * block_size * block_size, output_depth);
|
||||
TFLITE_DCHECK_EQ(input_batch, output_batch);
|
||||
|
||||
for (int in_b = 0; in_b < input_batch; ++in_b) {
|
||||
for (int in_h = 0; in_h < input_height; ++in_h) {
|
||||
for (int in_w = 0; in_w < input_width; ++in_w) {
|
||||
for (int in_d = 0; in_d < input_depth; ++in_d) {
|
||||
const int out_d =
|
||||
in_d + ((in_h % block_size) * block_size + in_w % block_size) *
|
||||
input_depth;
|
||||
const int out_w = in_w / block_size;
|
||||
const int out_h = in_h / block_size;
|
||||
const int out_b = in_b;
|
||||
|
||||
const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d);
|
||||
const int output_index =
|
||||
Offset(output_shape, out_b, out_h, out_w, out_d);
|
||||
|
||||
output_data[output_index] = input_data[input_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void Elu(const RuntimeShape& input_shape, const float* input_data,
|
||||
const RuntimeShape& output_shape, float* output_data) {
|
||||
const int flat_size = MatchingFlatSize(input_shape, output_shape);
|
||||
|
78  tensorflow/lite/kernels/internal/reference/space_to_depth.h  Normal file
@ -0,0 +1,78 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPACE_TO_DEPTH_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPACE_TO_DEPTH_H_
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace reference_ops {
|
||||
|
||||
template <typename T>
|
||||
inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
|
||||
const RuntimeShape& unextended_input_shape,
|
||||
const T* input_data,
|
||||
const RuntimeShape& unextended_output_shape,
|
||||
T* output_data) {
|
||||
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
|
||||
const RuntimeShape input_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_input_shape);
|
||||
const RuntimeShape output_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_output_shape);
|
||||
|
||||
const int input_depth = input_shape.Dims(3);
|
||||
const int input_width = input_shape.Dims(2);
|
||||
const int input_height = input_shape.Dims(1);
|
||||
const int input_batch = input_shape.Dims(0);
|
||||
|
||||
const int output_depth = output_shape.Dims(3);
|
||||
const int output_width = output_shape.Dims(2);
|
||||
const int output_height = output_shape.Dims(1);
|
||||
const int output_batch = output_shape.Dims(0);
|
||||
|
||||
const int32 block_size = op_params.block_size;
|
||||
|
||||
TFLITE_DCHECK_EQ(input_width, output_width * block_size);
|
||||
TFLITE_DCHECK_EQ(input_height, output_height * block_size);
|
||||
TFLITE_DCHECK_EQ(input_depth * block_size * block_size, output_depth);
|
||||
TFLITE_DCHECK_EQ(input_batch, output_batch);
|
||||
|
||||
for (int in_b = 0; in_b < input_batch; ++in_b) {
|
||||
for (int in_h = 0; in_h < input_height; ++in_h) {
|
||||
for (int in_w = 0; in_w < input_width; ++in_w) {
|
||||
for (int in_d = 0; in_d < input_depth; ++in_d) {
|
||||
const int out_d =
|
||||
in_d + ((in_h % block_size) * block_size + in_w % block_size) *
|
||||
input_depth;
|
||||
const int out_w = in_w / block_size;
|
||||
const int out_h = in_h / block_size;
|
||||
const int out_b = in_b;
|
||||
|
||||
const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d);
|
||||
const int output_index =
|
||||
Offset(output_shape, out_b, out_h, out_w, out_d);
|
||||
|
||||
output_data[output_index] = input_data[input_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace reference_ops
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPACE_TO_DEPTH_H_
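For readers comparing the moved code against its behavior, here is a minimal NumPy sketch (not part of the change, NHWC layout assumed) of the same SpaceToDepth index mapping used by the reference kernel above:

```python
import numpy as np

# Mirrors the reference kernel: for each input element, the output depth is
# in_d + ((in_h % block) * block + in_w % block) * input_depth, and the
# spatial coordinates are divided by the block size.
def space_to_depth(x: np.ndarray, block_size: int) -> np.ndarray:
    b, h, w, d = x.shape
    out = np.empty((b, h // block_size, w // block_size,
                    d * block_size * block_size), dtype=x.dtype)
    for in_h in range(h):
        for in_w in range(w):
            for in_d in range(d):
                out_d = in_d + ((in_h % block_size) * block_size +
                                in_w % block_size) * d
                out[:, in_h // block_size, in_w // block_size, out_d] = \
                    x[:, in_h, in_w, in_d]
    return out
```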
|
@ -170,7 +170,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
intermediate_zp.push_back(0);
}
}
// In the absense of projection, hidden becomes otuput and this intermediate
// In the absence of projection, hidden becomes output and this intermediate
// is ignored.
TfLiteTensor* hidden;
TF_LITE_ENSURE_OK(context, GetIntermediatesSafe(context, node, 4, &hidden));
|
||||
|
@ -204,7 +204,7 @@ to determine if the requested feature aligns with the TFLM roadmap.
1. Run all the tests for x86, and any other platform that you are modifying.

```
tensorflow/lite/micro/tools/make/tools/ci_build/test_x86.sh
tensorflow/lite/micro/tools/ci_build/test_x86.sh
```

Please check the READMEs in the optimized kernel directories for specific
|
||||
|
@ -1,2 +1,2 @@
numpy==1.16.2
tensorflow==2.0.0-beta1
tensorflow==2.4.0
|
||||
|
@ -33,7 +33,7 @@ set +e
|
||||
# The pigweed scripts only work from a git repository and the Tensorflow CI
|
||||
# infrastructure does not always guarantee that. As an ugly workaround, we
|
||||
# create our own git repo when running on the CI servers.
|
||||
pushd tensorflow/lite/micro/
|
||||
pushd tensorflow/lite/
|
||||
if [[ ${1} == "PRESUBMIT" ]]; then
|
||||
git init .
|
||||
git config user.email "tflm@google.com"
|
||||
@ -43,9 +43,12 @@ if [[ ${1} == "PRESUBMIT" ]]; then
|
||||
fi
|
||||
|
||||
# Check for license with the necessary exclusions.
|
||||
tools/make/downloads/pigweed/pw_presubmit/py/pw_presubmit/pigweed_presubmit.py \
|
||||
. \
|
||||
micro/tools/make/downloads/pigweed/pw_presubmit/py/pw_presubmit/pigweed_presubmit.py \
|
||||
kernels/internal/reference/ \
|
||||
micro/ \
|
||||
-p copyright_notice \
|
||||
-e kernels/internal/reference/integer_ops/ \
|
||||
-e kernels/internal/reference/reference_ops.h \
|
||||
-e tools/make/downloads \
|
||||
-e tools/make/targets/ecm3531 \
|
||||
-e BUILD\
|
||||
@ -66,8 +69,11 @@ LICENSE_CHECK_RESULT=$?
|
||||
# Python files (with yapf as the formatter) because that needs additional setup.
|
||||
# We are also ignoring the markdown files to allow for a more gradual rollout of
|
||||
# this presubmit check.
|
||||
tools/make/downloads/pigweed/pw_presubmit/py/pw_presubmit/format_code.py \
|
||||
. \
|
||||
micro/tools/make/downloads/pigweed/pw_presubmit/py/pw_presubmit/format_code.py \
|
||||
kernels/internal/reference/ \
|
||||
micro/ \
|
||||
-e kernels/internal/reference/integer_ops/ \
|
||||
-e kernels/internal/reference/reference_ops.h \
|
||||
-e "\.inc" \
|
||||
-e "\.md" \
|
||||
-e "\.py"
|
||||
@ -76,7 +82,7 @@ CLANG_FORMAT_RESULT=$?
|
||||
|
||||
popd
|
||||
if [[ ${1} == "PRESUBMIT" ]]; then
|
||||
rm -rf tensorflow/lite/micro/.git
|
||||
rm -rf tensorflow/lite/.git
|
||||
fi
|
||||
|
||||
# Re-enable exit on error now that we are done with the temporary git repo.
|
||||
|
@ -116,7 +116,7 @@ download_and_extract() {
|
||||
local tempdir=$(mktemp -d)
|
||||
local tempdir2=$(mktemp -d)
|
||||
local tempfile=${tempdir}/temp_file
|
||||
local curl_retries=3
|
||||
local curl_retries=5
|
||||
|
||||
# Destination already downloaded.
|
||||
if [ -d ${dir} ]; then
|
||||
@ -131,24 +131,21 @@ download_and_extract() {
|
||||
mkdir -p "${dir}"
|
||||
# We've been seeing occasional 56 errors from valid URLs, so set up a retry
|
||||
# loop to attempt to recover from them.
|
||||
for (( i=1; i<=$curl_retries; ++i ))
|
||||
do
|
||||
for (( i=1; i<=$curl_retries; ++i )); do
|
||||
# We have to use this approach because we normally halt the script when
|
||||
# there's an error, and instead we want to catch errors so we can retry.
|
||||
set +e
|
||||
curl -Ls --fail --retry 5 "${url}" > ${tempfile}
|
||||
set +ex
|
||||
curl -LsS --fail --retry 5 "${url}" > ${tempfile}
|
||||
CURL_RESULT=$?
|
||||
set -e
|
||||
set -ex
|
||||
|
||||
# Was the command successful? If so, continue.
|
||||
if [[ $CURL_RESULT -eq 0 ]]
|
||||
then
|
||||
if [[ $CURL_RESULT -eq 0 ]]; then
|
||||
break
|
||||
fi
|
||||
|
||||
# Keep trying if we see the '56' error code.
|
||||
if [[ ( $CURL_RESULT -ne 56 ) || ( $i -eq $curl_retries ) ]]
|
||||
then
|
||||
if [[ ( $CURL_RESULT -ne 56 ) || ( $i -eq $curl_retries ) ]]; then
|
||||
echo "Error $CURL_RESULT downloading '${url}'"
|
||||
exit 1
|
||||
fi
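The retry policy in this loop (keep retrying only on curl's transient error code 56, up to `curl_retries` attempts) can be summarized with a hedged Python sketch; the helper below is illustrative only and not part of the script:

```python
import subprocess

CURL_RETRIES = 5  # same count the script now uses

def download(url: str, dest: str) -> None:
    """Retry a curl download, but only when it fails with error code 56."""
    for attempt in range(1, CURL_RETRIES + 1):
        result = subprocess.run(
            ["curl", "-LsS", "--fail", "--retry", "5", "-o", dest, url])
        if result.returncode == 0:
            return
        if result.returncode != 56 or attempt == CURL_RETRIES:
            raise RuntimeError(
                f"Error {result.returncode} downloading '{url}'")
```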
|
||||
|
@ -110,7 +110,7 @@ class OpsSet(enum.Enum):
|
||||
"EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
return str(self.value)
|
||||
|
||||
@staticmethod
|
||||
def get_options():
|
||||
@ -394,22 +394,22 @@ def build_toco_convert_protos(input_tensors,
|
||||
input_tensors: List of input tensors. Type and shape are computed using
|
||||
`foo.shape` and `foo.dtype`.
|
||||
output_tensors: List of output tensors (only .name is used from this).
|
||||
inference_type: Target data type of real-number arrays in the output file.
|
||||
Must be `{tf.float32, tf.uint8, tf.int8}`. (default tf.float32)
|
||||
inference_input_type: Target data type of real-number input arrays. Allows
|
||||
for a different type for input arrays in the case of quantization. Must be
|
||||
`{tf.float32, tf.uint8, tf.int8}`. (default `inference_type`)
|
||||
input_format: Type of data to read Currently must be
|
||||
`{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
|
||||
input_shapes: Input array shape. It needs to be a list of the same length as
|
||||
`input_tensors`, or None. (default None)
|
||||
output_format: Output file format. Currently must be `{TFLITE,
|
||||
GRAPHVIZ_DOT}`. (default TFLITE)
|
||||
quantized_input_stats: List of tuples of floats representing the mean and
|
||||
standard deviation. Each tuple maps to the corresponding input tensor.
|
||||
Only need if `inference_input_type` is `QUANTIZED_UINT8` or `INT8`.
|
||||
real_input_value = (quantized_input_value - mean_value) / std_dev_value.
|
||||
(default None)
|
||||
inference_type: Data type of numeric arrays, excluding the input layer.
|
||||
(default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
|
||||
inference_input_type: Data type of the numeric arrays in the input layer. If
|
||||
`inference_input_type` is in {tf.int8, tf.uint8}, then
|
||||
`quantized_input_stats` must be provided. (default is the value assigned
|
||||
to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8})
|
||||
input_format: Type of data to read.
|
||||
(default TENSORFLOW_GRAPHDEF, must be in {TENSORFLOW_GRAPHDEF})
|
||||
input_shapes: Input array shape. (default None, must be None or a list of
|
||||
the same length as `input_tensors`.)
|
||||
output_format: Output file format. (default TFLITE, must be in
|
||||
{TFLITE, GRAPHVIZ_DOT})
|
||||
quantized_input_stats: Map of input tensor names to a tuple of floats
|
||||
representing the mean and standard deviation of the training data.
|
||||
(e.g., {"foo" : (0., 1.)}). Required if `inference_input_type` is tf.int8
|
||||
or tf.uint8. (default None)
|
||||
default_ranges_stats: Tuple of integers representing (min, max) range values
|
||||
for all arrays without a specified range. Intended for experimenting with
|
||||
quantization via "dummy quantization". (default None)
|
||||
@ -574,8 +574,10 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
|
||||
if _requires_input_stats(toco_flags):
|
||||
if (("quantized_input_stats" not in kwargs) or
|
||||
(not kwargs["quantized_input_stats"])):
|
||||
raise ValueError("std_dev and mean must be defined when inference_type "
|
||||
"or inference_input_type is QUANTIZED_UINT8 or INT8.")
|
||||
raise ValueError(
|
||||
"The `quantized_input_stats` flag must be defined when either "
|
||||
"`inference_type` flag or `inference_input_type` flag is set to "
|
||||
"tf.int8 or tf.uint8.")
|
||||
input_array.mean_value, input_array.std_value = kwargs[
|
||||
"quantized_input_stats"][idx]
|
||||
input_array.name = name
|
||||
@ -661,7 +663,7 @@ def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
|
||||
Typically this function is used to convert from TensorFlow GraphDef to TFLite.
|
||||
Conversion can be customized by providing arguments that are forwarded to
|
||||
`build_toco_convert_protos` (see documentation for details). This function has
|
||||
been deprecated. Please use `lite.TFLiteConverter` instead.
|
||||
been deprecated. Please use `tf.lite.TFLiteConverter` instead.
|
||||
|
||||
Args:
|
||||
input_data: Input data (i.e. often `sess.graph_def`),
|
||||
|
@ -137,7 +137,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual("output", output_details[0]["name"])
|
||||
self.assertEqual(np.uint8, output_details[0]["dtype"])
|
||||
self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all())
|
||||
self.assertTrue(output_details[0]["quantization"][0] > 0) # scale
|
||||
self.assertGreater(output_details[0]["quantization"][0], 0) # scale
|
||||
|
||||
def testGraphDefQuantizationInvalid(self):
|
||||
with ops.Graph().as_default():
|
||||
@ -159,9 +159,9 @@ class ConvertTest(test_util.TensorFlowTestCase):
|
||||
enable_mlir_converter=False,
|
||||
inference_type=dtypes.uint8)
|
||||
self.assertEqual(
|
||||
"std_dev and mean must be defined when inference_type or "
|
||||
"inference_input_type is QUANTIZED_UINT8 or INT8.",
|
||||
str(error.exception))
|
||||
"The `quantized_input_stats` flag must be defined when either "
|
||||
"`inference_type` flag or `inference_input_type` flag is set to "
|
||||
"tf.int8 or tf.uint8.", str(error.exception))
|
||||
|
||||
|
||||
class ConvertTestOpHint(test_util.TensorFlowTestCase):
|
||||
|
@ -61,6 +61,7 @@ from tensorflow.lite.python.util import get_debug_info as _get_debug_info
|
||||
from tensorflow.lite.python.util import get_grappler_config as _get_grappler_config
|
||||
from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
|
||||
from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
|
||||
from tensorflow.lite.python.util import get_tf_type_name as _get_tf_type_name
|
||||
from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
|
||||
from tensorflow.lite.python.util import model_input_signature as _model_input_signature
|
||||
from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type
|
||||
@ -89,19 +90,14 @@ from tensorflow.python.util.tf_export import tf_export as _tf_export
|
||||
|
||||
@_tf_export("lite.Optimize")
|
||||
class Optimize(enum.Enum):
|
||||
"""Enum defining the optimizations to apply when generating tflite graphs.
|
||||
|
||||
Some optimizations may come at the cost of accuracy.
|
||||
"""Enum defining the optimizations to apply when generating a tflite model.
|
||||
|
||||
DEFAULT
|
||||
Default optimization strategy.
|
||||
|
||||
Converter will do its best to improve size and latency based on the
|
||||
information provided.
|
||||
Enhanced optimizations are gained by providing a representative_dataset.
|
||||
This is recommended, and is currently equivalent to the modes below.
|
||||
Currently, weights will be quantized and if representative_dataset is
|
||||
provided, activations for quantizable operations will also be quantized.
|
||||
Default optimization strategy that quantizes model weights. Enhanced
|
||||
optimizations are gained by providing a representative dataset that
|
||||
quantizes biases and activations as well.
|
||||
Converter will do its best to reduce size and latency, while minimizing
|
||||
the loss in accuracy.
|
||||
|
||||
OPTIMIZE_FOR_SIZE
|
||||
Deprecated. Does the same as DEFAULT.
|
||||
@ -110,14 +106,11 @@ class Optimize(enum.Enum):
|
||||
Deprecated. Does the same as DEFAULT.
|
||||
"""
|
||||
|
||||
# Default optimization strategy.
|
||||
#
|
||||
# Converter will do its best to improve size and latency based on the
|
||||
# information provided.
|
||||
# Enhanced optimizations can be gained by providing a representative_dataset.
|
||||
# This is recommended, and is currently equivalent to the modes below.
|
||||
# Currently, weights will be quantized and if representative_dataset is
|
||||
# provided, activations for quantizable operations will also be quantized.
|
||||
# Default optimization strategy that quantizes model weights. Enhanced
|
||||
# optimizations are gained by providing a representative dataset that
|
||||
# quantizes biases and activations as well.
|
||||
# Converter will do its best to reduce size and latency, while minimizing
|
||||
# the loss in accuracy.
|
||||
DEFAULT = "DEFAULT"
|
||||
|
||||
# Deprecated. Does the same as DEFAULT.
|
||||
@ -132,48 +125,47 @@ class Optimize(enum.Enum):
|
||||
|
||||
@_tf_export("lite.RepresentativeDataset")
|
||||
class RepresentativeDataset(object):
|
||||
"""Representative dataset to evaluate optimizations.
|
||||
"""Representative dataset used to optimize the model.
|
||||
|
||||
A representative dataset that can be used to evaluate optimizations by the
|
||||
converter. E.g. converter can use these examples to estimate (min, max) ranges
|
||||
by calibrating the model on inputs. This can allow converter to quantize a
|
||||
converted floating point model.
|
||||
This is a generator function that provides a small dataset to calibrate or
|
||||
estimate the range, i.e, (min, max) of all floating-point arrays in the model
|
||||
(such as model input, activation outputs of intermediate layers, and model
|
||||
output) for quantization. Usually, this is a small subset of a few hundred
|
||||
samples randomly chosen, in no particular order, from the training or
|
||||
evaluation dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, input_gen):
|
||||
"""Creates a representative dataset.
|
||||
|
||||
Args:
|
||||
input_gen: an input generator that can be used to generate input samples
|
||||
for the model. This must be a callable object that returns an object
|
||||
that supports the `iter()` protocol (e.g. a generator function). The
|
||||
elements generated must have same type and shape as inputs to the model.
|
||||
input_gen: A generator function that generates input samples for the
|
||||
model and has the same order, type and shape as the inputs to the model.
|
||||
Usually, this is a small subset of a few hundred samples randomly
|
||||
chosen, in no particular order, from the training or evaluation dataset.
|
||||
"""
|
||||
self.input_gen = input_gen
|
||||
|
||||
|
||||
@_tf_export("lite.TargetSpec")
|
||||
class TargetSpec(object):
|
||||
"""Specification of target device.
|
||||
|
||||
Details about target device. Converter optimizes the generated model for
|
||||
specific device.
|
||||
"""Specification of target device used to optimize the model.
|
||||
|
||||
Attributes:
|
||||
supported_ops: Experimental flag, subject to change. Set of OpsSet options
|
||||
supported by the device. (default set([OpsSet.TFLITE_BUILTINS]))
|
||||
supported_types: List of types for constant values on the target device.
|
||||
Frequently, an optimization choice is driven by the most compact
|
||||
(i.e. smallest) type in this list (default [tf.float32])
|
||||
supported_ops: Experimental flag, subject to change. Set of `tf.lite.OpsSet`
|
||||
options, where each option represents a set of operators supported by the
|
||||
target device. (default {tf.lite.OpsSet.TFLITE_BUILTINS}))
|
||||
supported_types: Set of `tf.dtypes.DType` data types supported on the target
|
||||
device. If initialized, optimization might be driven by the smallest type
|
||||
in this set. (default set())
|
||||
experimental_select_user_tf_ops: Experimental flag, subject to change. Set
|
||||
of user's TensorFlow operators' names that are required in the TensorFlow
|
||||
Lite runtime. These ops will be exported as select TensorFlow ops in the
|
||||
model (in conjunction with the OpsSet.SELECT_TF_OPS flag). This is an
|
||||
advanced feature that should only be used if the client is using TF ops
|
||||
model (in conjunction with the tf.lite.OpsSet.SELECT_TF_OPS flag). This is
|
||||
an advanced feature that should only be used if the client is using TF ops
|
||||
that may not be linked in by default with the TF ops that are provided
|
||||
when using the SELECT_TF_OPS path. The client is responsible for linking
|
||||
these ops into the target runtime.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -181,17 +173,17 @@ class TargetSpec(object):
|
||||
supported_types=None,
|
||||
experimental_select_user_tf_ops=None):
|
||||
if supported_ops is None:
|
||||
supported_ops = set([OpsSet.TFLITE_BUILTINS])
|
||||
supported_ops = {OpsSet.TFLITE_BUILTINS}
|
||||
self.supported_ops = supported_ops
|
||||
if supported_types is None:
|
||||
supported_types = []
|
||||
supported_types = set()
|
||||
self.supported_types = supported_types
|
||||
if experimental_select_user_tf_ops is None:
|
||||
self.experimental_select_user_tf_ops = []
|
||||
self.experimental_select_user_tf_ops = set()
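A short, hedged sketch of how these `TargetSpec` attributes are typically set on a converter now that sets are the expected container type; the one-layer Keras model is only a stand-in so the snippet is self-contained:

```python
import tensorflow as tf

# Stand-in model so the example is runnable; any converter source works.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
converter = tf.lite.TFLiteConverter.from_keras_model(model)
# supported_ops is a set of tf.lite.OpsSet values, per the attribute docs.
converter.target_spec.supported_ops = {tf.lite.OpsSet.TFLITE_BUILTINS,
                                       tf.lite.OpsSet.SELECT_TF_OPS}
tflite_model = converter.convert()
```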
|
||||
|
||||
|
||||
class QuantizationMode(object):
|
||||
"""QuantizationMode determines the quantized conversion from user options."""
|
||||
"""QuantizationMode determines the quantization type from user options."""
|
||||
|
||||
def __init__(self, optimizations, target_spec, representative_dataset,
|
||||
graph_def):
|
||||
@ -205,7 +197,6 @@ class QuantizationMode(object):
|
||||
# TODO(b/162537905): Refactor the following quantization functions -
|
||||
# re-organize and refactor for better readability.
|
||||
def post_training_int8_no_float(self):
|
||||
"""Post training int8 quantize, disallow float fallback."""
|
||||
return (self._any_optimization_enabled() and
|
||||
self._is_int8_target_required() and
|
||||
not self._is_int16x8_target_required() and
|
||||
@ -213,19 +204,16 @@ class QuantizationMode(object):
|
||||
self._representative_dataset is not None)
|
||||
|
||||
def post_training_int8_allow_float(self):
|
||||
"""Post training int8 quantize, allow float fallback."""
|
||||
return (self._any_optimization_enabled() and
|
||||
not self._is_int16x8_target_required() and
|
||||
self._representative_dataset is not None and
|
||||
self._smallest_supported_type() == _dtypes.int8)
|
||||
|
||||
def is_post_training_integer_quantize_8(self):
|
||||
"""Post training integer 8 quantization."""
|
||||
return (self.post_training_int8_no_float() or
|
||||
self.post_training_int8_allow_float())
|
||||
|
||||
def is_post_training_integer_quantize_16x8(self):
|
||||
"""Post training integer 16x8 quantization."""
|
||||
return (self.post_training_int16x8_no_float() or
|
||||
self.post_training_int16x8_allow_float())
|
||||
|
||||
@ -239,7 +227,6 @@ class QuantizationMode(object):
|
||||
self.contains_training_quant_op())
|
||||
|
||||
def post_training_int16x8_no_float(self):
|
||||
"""Post training int16x8 quantize, disallow float fallback."""
|
||||
return (self._any_optimization_enabled() and
|
||||
not self._is_int8_target_required() and
|
||||
self._is_int16x8_target_required() and
|
||||
@ -247,13 +234,11 @@ class QuantizationMode(object):
|
||||
self._representative_dataset is not None)
|
||||
|
||||
def post_training_int16x8_allow_float(self):
|
||||
"""Post training int16x8 quantize, allow float fallback."""
|
||||
return (self._any_optimization_enabled() and
|
||||
self._is_int16x8_target_required() and
|
||||
self._is_allow_float())
|
||||
|
||||
def post_training_dynamic_range_int8(self):
|
||||
"""Post training int8 const, on-the-fly int8 quantize of dynamic tensors."""
|
||||
# Post-training dynamic range quantization is only enabled if post-training
|
||||
# int8 quantization and training time quantization was not done.
|
||||
return (self._any_optimization_enabled() and
|
||||
@ -262,7 +247,6 @@ class QuantizationMode(object):
|
||||
self._smallest_supported_type() == _dtypes.int8)
|
||||
|
||||
def post_training_fp16(self):
|
||||
"""Post training fp16 quantize."""
|
||||
return (self._any_optimization_enabled() and
|
||||
self._smallest_supported_type() == _dtypes.float16)
|
||||
|
||||
@ -416,21 +400,20 @@ class TFLiteConverterBase(object):
|
||||
"""Converter subclass to share functionality between V1 and V2 converters."""
|
||||
|
||||
def __init__(self):
|
||||
self.allow_custom_ops = False
|
||||
self.target_spec = TargetSpec()
|
||||
self.optimizations = []
|
||||
self.optimizations = set()
|
||||
self.representative_dataset = None
|
||||
self.target_spec = TargetSpec()
|
||||
self.allow_custom_ops = False
|
||||
self.experimental_new_converter = True
|
||||
self._experimental_new_quantizer = False
|
||||
self._experimental_calibrate_only = False
|
||||
# The 'GraphDebugInfo' contains the stack traces of all the original nodes
|
||||
# in the `GraphDef` to the converter.
|
||||
self._debug_info = None
|
||||
self._experimental_sparsify_model = False
|
||||
self._debug_info = None # contains the stack traces of all the original
|
||||
# nodes in the `GraphDef` to the converter.
|
||||
self.saved_model_dir = None
|
||||
self._saved_model_tags = None
|
||||
self._saved_model_version = 0
|
||||
self._saved_model_exported_names = []
|
||||
self._experimental_sparsify_model = False
|
||||
|
||||
def _grappler_config(self, optimizers=None):
|
||||
"""Creates a tf.compat.v1.ConfigProto for configuring Grappler.
|
||||
@ -684,9 +667,9 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
|
||||
saved_model_dir: Directory of the SavedModel.
|
||||
saved_model_tags: Set of tags identifying the MetaGraphDef within the
|
||||
SavedModel to analyze. All tags in the tag set must be present. (default
|
||||
set(SERVING)).
|
||||
saved_model_exported_names: Names to be exported (default: export all)
|
||||
when the saved model import path is on.
|
||||
{tf.saved_model.SERVING}).
|
||||
saved_model_exported_names: Names to be exported when the saved model
|
||||
import path is on.
|
||||
trackable_obj: tf.AutoTrackable object associated with `funcs`. A
|
||||
reference to this object needs to be maintained so that Variables do not
|
||||
get garbage collected since functions have a weak reference to
|
||||
@ -972,20 +955,21 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
|
||||
"""Converts a TensorFlow model into TensorFlow Lite model.
|
||||
|
||||
Attributes:
|
||||
allow_custom_ops: Boolean indicating whether to allow custom operations.
|
||||
When False, any unknown operation is an error. When True, custom ops are
|
||||
created for any op that is unknown. The developer needs to provide these
|
||||
to the TensorFlow Lite runtime with a custom resolver. (default False)
|
||||
optimizations: Experimental flag, subject to change. A list of optimizations
|
||||
to apply when converting the model. E.g. `[Optimize.DEFAULT]`
|
||||
representative_dataset: A representative dataset that can be used to
|
||||
generate input and output samples for the model. The converter can use the
|
||||
dataset to evaluate different optimizations. Note that this is an optional
|
||||
attribute but it is necessary if INT8 is the only support builtin ops in
|
||||
target ops.
|
||||
optimizations: Experimental flag, subject to change. Set of optimizations
|
||||
to apply. e.g {tf.lite.Optimize.DEFAULT}. (default None, must be None or a
|
||||
set of values of type `tf.lite.Optimize`)
|
||||
representative_dataset: A generator function used for integer quantization
|
||||
where each generated sample has the same order, type and shape as the
|
||||
inputs to the model. Usually, this is a small subset of a few hundred
|
||||
samples randomly chosen, in no particular order, from the training or
|
||||
evaluation dataset. This is an optional attribute, but required for full
|
||||
integer quantization, i.e, if `tf.int8` is the only supported type in
|
||||
`target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`.
|
||||
(default None)
|
||||
target_spec: Experimental flag, subject to change. Specifications of target
|
||||
device, including supported ops set, supported types and a set of user's
|
||||
defined TensorFlow operators required in the TensorFlow Lite runtime.
|
||||
Refer to `tf.lite.TargetSpec`.
|
||||
inference_input_type: Data type of the input layer. Note that integer types
|
||||
(tf.int8 and tf.uint8) are currently only supported for post training
|
||||
integer quantization and quantization aware training. (default tf.float32,
|
||||
@ -994,8 +978,13 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
|
||||
types (tf.int8 and tf.uint8) are currently only supported for post
|
||||
training integer quantization and quantization aware training. (default
|
||||
tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
|
||||
allow_custom_ops: Boolean indicating whether to allow custom operations.
|
||||
When False, any unknown operation is an error. When True, custom ops are
|
||||
created for any op that is unknown. The developer needs to provide these
|
||||
to the TensorFlow Lite runtime with a custom resolver. (default False)
|
||||
experimental_new_converter: Experimental flag, subject to change. Enables
|
||||
MLIR-based conversion instead of TOCO conversion. (default True)
|
||||
|
||||
Example usage:
|
||||
|
||||
```python
|
||||
@ -1063,7 +1052,8 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
|
||||
`signatures` attribute of the MetaGraphdef is used. (default
|
||||
saved_model.signatures)
|
||||
tags: Set of tags identifying the MetaGraphDef within the SavedModel to
|
||||
analyze. All tags in the tag set must be present. (default set(SERVING))
|
||||
analyze. All tags in the tag set must be present. (default
|
||||
{tf.saved_model.SERVING} or {'serve'})
|
||||
|
||||
Returns:
|
||||
TFLiteConverter object.
|
||||
@ -1209,9 +1199,13 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
|
||||
|
||||
if (requires_quantized_input_stats and
|
||||
not converter_kwargs["quantized_input_stats"]):
|
||||
raise ValueError("The `quantized_input_stats` flag must be defined when "
|
||||
"either `inference_type` flag or `inference_input_type` "
|
||||
"flag is set to tf.uint8 or tf.int8.")
|
||||
raise ValueError(
|
||||
"The `quantized_input_stats` flag must be defined when either "
|
||||
"`inference_type` flag or `inference_input_type` flag is set to "
|
||||
"tf.int8 or tf.uint8. Currently, `inference_type={}` and "
|
||||
"`inference_input_type={}`.".format(
|
||||
_get_tf_type_name(converter_kwargs["inference_type"]),
|
||||
_get_tf_type_name(converter_kwargs["inference_input_type"])))
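For context, a hedged sketch of the V1 converter configuration that satisfies this check; the frozen-graph path and array names are placeholders:

```python
import tensorflow as tf

# When inference_type (or inference_input_type) is tf.uint8 or tf.int8,
# quantized_input_stats must map each input name to (mean, std_dev).
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    "frozen_graph.pb", input_arrays=["input"], output_arrays=["output"])
converter.inference_type = tf.uint8
converter.quantized_input_stats = {"input": (127.5, 127.5)}  # (mean, std_dev)
tflite_model = converter.convert()
```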
|
||||
|
||||
def convert(self):
|
||||
"""Converts a TensorFlow GraphDef based on instance variables.
|
||||
@ -1424,9 +1418,9 @@ class TFLiteSavedModelConverter(TFLiteConverterBaseV1):
|
||||
saved_model_dir: Directory of the SavedModel.
|
||||
saved_model_tags: Set of tags identifying the MetaGraphDef within the
|
||||
SavedModel to analyze. All tags in the tag set must be present. (default
|
||||
set(SERVING)).
|
||||
saved_model_exported_names: Names to be exported (default: export all)
|
||||
when the saved model import path is on.
|
||||
{tf.saved_model.SERVING}).
|
||||
saved_model_exported_names: Names to be exported when the saved model
|
||||
import path is on.
|
||||
experimental_debug_info_func: An experimental function to retrieve the
|
||||
graph debug info for a set of nodes from the `graph_def`.
|
||||
|
||||
@ -1645,33 +1639,42 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
|
||||
model into either a TFLite FlatBuffer or graph visualization.
|
||||
|
||||
Attributes:
|
||||
inference_type: Target data type of real-number arrays in the output file.
|
||||
Must be `{tf.float32, tf.uint8}`. If `optimzations` are provided, this
|
||||
parameter is ignored. (default tf.float32)
|
||||
inference_input_type: Target data type of real-number input arrays. Allows
|
||||
for a different type for input arrays. If an integer type is provided and
|
||||
`optimizations` are not used, `quantized_input_stats` must be provided. If
|
||||
`inference_type` is tf.uint8, signaling conversion to a fully quantized
|
||||
model from a quantization-aware trained input model, then
|
||||
`inference_input_type` defaults to tf.uint8. In all other cases,
|
||||
`inference_input_type` defaults to tf.float32. Must be `{tf.float32,
|
||||
tf.uint8, tf.int8}`
|
||||
inference_output_type: Target data type of real-number output arrays. Allows
|
||||
for a different type for output arrays. If `inference_type` is tf.uint8,
|
||||
signaling conversion to a fully quantized model from a quantization-aware
|
||||
trained output model, then `inference_output_type` defaults to tf.uint8.
|
||||
In all other cases, `inference_output_type` must be tf.float32, an error
|
||||
will be thrown otherwise. Must be `{tf.float32, tf.uint8, tf.int8}`
|
||||
output_format: Output file format. Currently must be `{TFLITE,
|
||||
GRAPHVIZ_DOT}`. (default TFLITE)
|
||||
quantized_input_stats: Dict of strings representing input tensor names
|
||||
mapped to tuple of floats representing the mean and standard deviation
|
||||
of the training data (e.g., {"foo" : (0., 1.)}). Only need if
|
||||
`inference_input_type` is `QUANTIZED_UINT8`. real_input_value =
|
||||
(quantized_input_value - mean_value) / std_dev_value. (default {})
|
||||
default_ranges_stats: Tuple of integers representing (min, max) range values
|
||||
for all arrays without a specified range. Intended for experimenting with
|
||||
quantization via "dummy quantization". (default None)
|
||||
optimizations: Experimental flag, subject to change. Set of optimizations to
|
||||
apply. e.g {tf.lite.Optimize.DEFAULT}. (default None, must be None or a
|
||||
set of values of type `tf.lite.Optimize`)
|
||||
representative_dataset: A generator function used for integer quantization
|
||||
where each generated sample has the same order, type and shape as the
|
||||
inputs to the model. Usually, this is a small subset of a few hundred
|
||||
samples randomly chosen, in no particular order, from the training or
|
||||
evaluation dataset. This is an optional attribute, but required for full
|
||||
integer quantization, i.e, if `tf.int8` is the only supported type in
|
||||
`target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`.
|
||||
(default None)
|
||||
target_spec: Experimental flag, subject to change. Specifications of target
|
||||
device, including supported ops set, supported types and a set of user's
|
||||
defined TensorFlow operators required in the TensorFlow Lite runtime.
|
||||
Refer to `tf.lite.TargetSpec`.
|
||||
inference_type: Data type of numeric arrays, excluding the input layer.
|
||||
(default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
|
||||
inference_input_type: Data type of the numeric arrays in the input layer. If
|
||||
`inference_input_type` is in {tf.int8, tf.uint8}, then
|
||||
`quantized_input_stats` must be provided. (default is the value assigned
|
||||
to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8})
|
||||
inference_output_type: Data type of the numeric arrays in the output layer.
|
||||
(default is the value assigned to `inference_type`, must be in
|
||||
{tf.float32, tf.int8, tf.uint8})
|
||||
quantized_input_stats: Map of input tensor names to a tuple of floats
|
||||
representing the mean and standard deviation of the training data.
|
||||
(e.g., {"foo" : (0., 1.)}). Required if `inference_input_type` is tf.int8
|
||||
or tf.uint8. (default None)
|
||||
default_ranges_stats: Tuple of integers (min, max) representing range values
|
||||
for all numeric arrays without a specified range. Intended for
|
||||
experimenting with quantization via "dummy quantization". (default None)
|
||||
allow_custom_ops: Boolean indicating whether to allow custom operations.
|
||||
When False any unknown operation is an error. When True, custom ops are
|
||||
created for any op that is unknown. The developer will need to provide
|
||||
these to the TensorFlow Lite runtime with a custom resolver. (default
|
||||
False)
|
||||
drop_control_dependency: Boolean indicating whether to drop control
|
||||
dependencies silently. This is due to TFLite not supporting control
|
||||
dependencies. (default True)
|
||||
@ -1683,37 +1686,25 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
|
||||
change_concat_input_ranges: Boolean to change behavior of min/max ranges for
|
||||
inputs and outputs of the concat operator for quantized models. Changes
|
||||
the ranges of concat operator overlap when true. (default False)
|
||||
allow_custom_ops: Boolean indicating whether to allow custom operations.
|
||||
When false any unknown operation is an error. When true, custom ops are
|
||||
created for any op that is unknown. The developer will need to provide
|
||||
these to the TensorFlow Lite runtime with a custom resolver. (default
|
||||
False)
|
||||
post_training_quantize: Deprecated. Please specify `[Optimize.DEFAULT]` for
|
||||
`optimizations` instead. Boolean indicating whether to quantize the
|
||||
weights of the converted float model. Model size will be reduced and
|
||||
there will be latency improvements (at the cost of accuracy). (default
|
||||
False)
|
||||
output_format: Output file format. (default
|
||||
tf.compat.v1.lite.constants.TFLITE, must be in
|
||||
{tf.compat.v1.lite.constants.TFLITE,
|
||||
tf.compat.v1.lite.constants.GRAPHVIZ_DOT})
|
||||
dump_graphviz_dir: Full filepath of folder to dump the graphs at various
|
||||
stages of processing GraphViz .dot files. Preferred over
|
||||
--output_format=GRAPHVIZ_DOT in order to keep the requirements of the
|
||||
output file. (default None)
|
||||
dump_graphviz_video: Boolean indicating whether to dump the graph after
|
||||
every graph transformation. (default False)
|
||||
conversion_summary_dir: A string indicating the path to the generated
|
||||
conversion logs.
|
||||
target_ops: Deprecated. Please specify `target_spec.supported_ops` instead.
|
||||
Set of OpsSet options indicating which converter to use. (default
|
||||
set([OpsSet.TFLITE_BUILTINS]))
|
||||
target_spec: Experimental flag, subject to change. Specifications of target
|
||||
device, including supported ops set, supported types and a set of user's
|
||||
defined TensorFlow operators required in the TensorFlow Lite runtime.
|
||||
optimizations: Experimental flag, subject to change. A list of optimizations
|
||||
to apply when converting the model. E.g. `[Optimize.DEFAULT]`
|
||||
representative_dataset: A representative dataset that can be used to
|
||||
generate input and output samples for the model. The converter can use the
|
||||
dataset to evaluate different optimizations.
|
||||
`output_format=tf.compat.v1.lite.constants.GRAPHVIZ_DOT` in order to keep
|
||||
the requirements of the output file. (default None)
|
||||
dump_graphviz_video: Boolean indicating whether to dump the GraphViz .dot
|
||||
files after every graph transformation. Requires the `dump_graphviz_dir`
|
||||
flag to be specified. (default False)
|
||||
conversion_summary_dir: Full path of the directory to store conversion logs.
|
||||
(default None)
|
||||
target_ops: Deprecated. Please use `target_spec.supported_ops` instead.
|
||||
post_training_quantize: Deprecated. Please use `optimizations` instead and
|
||||
set it to `{tf.lite.Optimize.DEFAULT}`. (default False)
|
||||
experimental_new_converter: Experimental flag, subject to change. Enables
|
||||
MLIR-based conversion instead of TOCO conversion. (default True)
|
||||
|
||||
Example usage:
|
||||
|
||||
```python
|
||||
@ -1911,9 +1902,10 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
|
||||
output_arrays: List of output tensors to freeze graph with. Uses output
|
||||
arrays from SignatureDef when none are provided. (default None)
|
||||
tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
|
||||
analyze. All tags in the tag set must be present. (default set("serve"))
|
||||
analyze. All tags in the tag set must be present. (default
|
||||
{tf.saved_model.SERVING})
|
||||
signature_key: Key identifying SignatureDef containing inputs and outputs.
|
||||
(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)
|
||||
(default tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
|
||||
|
||||
Returns:
|
||||
TFLiteConverter class.
|
||||
|
@ -1239,18 +1239,22 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
||||
quantized_converter.inference_type = quantized_type
|
||||
quantized_converter.convert()
|
||||
self.assertEqual(
|
||||
'The `quantized_input_stats` flag must be defined when '
|
||||
'either `inference_type` flag or `inference_input_type` '
|
||||
'flag is set to tf.uint8 or tf.int8.', str(error.exception))
|
||||
'The `quantized_input_stats` flag must be defined when either '
|
||||
'`inference_type` flag or `inference_input_type` flag is set to '
|
||||
'tf.int8 or tf.uint8. Currently, `inference_type=tf.{}` and '
|
||||
'`inference_input_type=None`.'.format(quantized_type.name),
|
||||
str(error.exception))
|
||||
|
||||
with self.assertRaises(ValueError) as error:
|
||||
quantized_converter.inference_type = dtypes.float32
|
||||
quantized_converter.inference_input_type = quantized_type
|
||||
quantized_converter.convert()
|
||||
self.assertEqual(
|
||||
'The `quantized_input_stats` flag must be defined when '
|
||||
'either `inference_type` flag or `inference_input_type` '
|
||||
'flag is set to tf.uint8 or tf.int8.', str(error.exception))
|
||||
'The `quantized_input_stats` flag must be defined when either '
|
||||
'`inference_type` flag or `inference_input_type` flag is set to '
|
||||
'tf.int8 or tf.uint8. Currently, `inference_type=tf.float32` and '
|
||||
'`inference_input_type=tf.{}`.'.format(quantized_type.name),
|
||||
str(error.exception))
|
||||
|
||||
quantized_converter.inference_type = quantized_type
|
||||
quantized_converter.inference_input_type = quantized_type
|
||||
|
@ -127,9 +127,9 @@ def _convert_tflite_enum_type_to_tf_type(tflite_enum_type):
|
||||
return tf_type
|
||||
|
||||
|
||||
def _get_tf_type_name(tf_type):
|
||||
def get_tf_type_name(tf_type):
|
||||
"""Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32")."""
|
||||
return "tf." + tf_type.name
|
||||
return "tf." + tf_type.name if tf_type else None
|
||||
|
||||
|
||||
def get_tensor_name(tensor):
|
||||
@ -674,7 +674,7 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
|
||||
raise ValueError(
|
||||
"Initial model input type must be tf.float32. Expected type for "
|
||||
"tensor with name '{}' is tf.float32, instead type is {}".format(
|
||||
float_tensor.name, _get_tf_type_name(float_type)))
|
||||
float_tensor.name, get_tf_type_name(float_type)))
|
||||
# If found, validate that the operator output is quantized and compatible
|
||||
# with the final model input type
|
||||
quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
|
||||
@ -683,17 +683,17 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
|
||||
"Initial model input is not quantized. Expected type for "
|
||||
"tensor with name '{}' should be in {}, instead type is {}".format(
|
||||
quant_tensor.name,
|
||||
tuple(_get_tf_type_name(t) for t in
|
||||
tuple(get_tf_type_name(t) for t in
|
||||
_MAP_QUANT_TO_IO_TYPES.keys()),
|
||||
_get_tf_type_name(quant_type)))
|
||||
get_tf_type_name(quant_type)))
|
||||
else:
|
||||
inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
|
||||
if inference_input_type not in inference_io_types:
|
||||
raise ValueError(
|
||||
"Unsupported `inference_input_type` value. Expected to be in "
|
||||
"{}, instead got {}.".format(
|
||||
tuple(_get_tf_type_name(t) for t in inference_io_types),
|
||||
_get_tf_type_name(inference_input_type)))
|
||||
tuple(get_tf_type_name(t) for t in inference_io_types),
|
||||
get_tf_type_name(inference_input_type)))
|
||||
input_quant_ops.append(op)
|
||||
|
||||
if len(subgraph.inputs) != len(input_quant_ops):
|
||||
@ -725,7 +725,7 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported `inference_input_type` value {}.".format(
|
||||
_get_tf_type_name(inference_input_type)))
|
||||
get_tf_type_name(inference_input_type)))
|
||||
|
||||
|
||||
def _modify_model_output_type(model, inference_output_type=dtypes.float32):
|
||||
@ -768,7 +768,7 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
|
||||
raise ValueError(
|
||||
"Initial model output type must be tf.float32. Expected type for "
|
||||
"tensor with name '{}' is tf.float32, instead type is {}".format(
|
||||
float_tensor.name, _get_tf_type_name(float_type)))
|
||||
float_tensor.name, get_tf_type_name(float_type)))
|
||||
# If found, validate that the operator input is quantized and compatible
|
||||
# with the final model output type
|
||||
quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
|
||||
@ -777,17 +777,17 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
|
||||
"Initial model output is not dequantized. Expected type for "
|
||||
"tensor with name '{}' should be in {}, instead type is {}".format(
|
||||
quant_tensor.name,
|
||||
tuple(_get_tf_type_name(t) for t in
|
||||
tuple(get_tf_type_name(t) for t in
|
||||
_MAP_QUANT_TO_IO_TYPES.keys()),
|
||||
_get_tf_type_name(quant_type)))
|
||||
get_tf_type_name(quant_type)))
|
||||
else:
|
||||
inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
|
||||
if inference_output_type not in inference_io_types:
|
||||
raise ValueError(
|
||||
"Unsupported `inference_output_type` value. Expected to be in "
|
||||
"{}, instead got {}.".format(
|
||||
tuple(_get_tf_type_name(t) for t in inference_io_types),
|
||||
_get_tf_type_name(inference_output_type)))
|
||||
tuple(get_tf_type_name(t) for t in inference_io_types),
|
||||
get_tf_type_name(inference_output_type)))
|
||||
output_dequant_ops.append(op)
|
||||
|
||||
if len(subgraph.outputs) != len(output_dequant_ops):
|
||||
@ -834,7 +834,7 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported `inference_output_type` value {}.".format(
|
||||
_get_tf_type_name(inference_output_type)))
|
||||
get_tf_type_name(inference_output_type)))
|
||||
|
||||
|
||||
def modify_model_io_type(
|
||||
|
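A brief sketch (my own, not from the diff) of the behaviour the renamed helper is expected to have after this change, assuming it remains importable from `tensorflow.lite.python.util`:

    from tensorflow.lite.python import util
    from tensorflow.python.framework import dtypes

    # The rename drops the leading underscore and adds a None-safe path.
    print(util.get_tf_type_name(dtypes.float32))  # "tf.float32"
    print(util.get_tf_type_name(None))            # None instead of an AttributeError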
@@ -26,7 +26,26 @@ from tensorflow.lite.testing.zip_test_utils import register_make_test_function
@register_make_test_function()
def make_cast_tests(options):
"""Generate examples for cast."""
test_parameters = [{
if options.use_experimental_converter:
test_parameters = [
{
"input_dtype": [tf.float32],
"output_dtype": [tf.int16],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.int16],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.int32],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}]
else:
test_parameters = [
{
"input_dtype": [tf.int32],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
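The new parameters exercise int16 casts in both directions under the experimental converter; a minimal sketch of the op being generated (plain TensorFlow, independent of the test harness):

    import tensorflow as tf

    x = tf.constant([1.25, -2.5], dtype=tf.float32)
    y = tf.cast(x, tf.int16)    # float32 -> int16 (truncates toward zero)
    z = tf.cast(y, tf.float32)  # int16 -> float32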
@@ -763,6 +763,7 @@ py_library(
"//tensorflow/python/util:_pywrap_tfprof",
"//tensorflow/python/util:_pywrap_transform_graph",
"//tensorflow/python/util:_pywrap_util_port",
"//tensorflow/python/platform:_pywrap_tf2",
":_pywrap_utils",
":composite_tensor",
":config",
@@ -5266,7 +5267,10 @@ pywrap_tensorflow_macro(
tf_additional_plugin_deps() +
tf_additional_profiler_deps()) + if_xla_available([
"//tensorflow/compiler/aot:tfcompile_lib",
]) + if_static(extra_deps = ["//tensorflow/core/platform:tensor_float_32_utils"]),
]) + if_static(extra_deps = [
"//tensorflow/core/platform:tensor_float_32_utils",
"//tensorflow/core/platform:enable_tf2_utils",
]),
)

# ** Targets for Windows build (start) **
@@ -6779,6 +6783,8 @@ py_test(
":client_testlib",
":framework_combinations",
":tf2",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/data/kernel_tests:test_base",
],
)

@@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export
# This value changes every day with an automatic CL. It can be modified in code
# via `forward_compatibility_horizon()` or with the environment variable
# TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2021, 1, 2)
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2021, 1, 5)
_FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
_FORWARD_COMPATIBILITY_DATE_NUMBER = None
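As a reminder of how the horizon is consumed, a sketch using the public `tf.compat` API (my own example, not code from this diff; the dates are illustrative):

    import tensorflow as tf

    # Graph-building code guards newer ops/attrs on the compatibility window.
    if tf.compat.forward_compatible(2021, 1, 5):
        pass  # emit the newer op variant
    else:
        pass  # keep the older construction

    # Tests can temporarily push the horizon forward:
    with tf.compat.forward_compatibility_horizon(2021, 1, 6):
        assert tf.compat.forward_compatible(2021, 1, 5)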
@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python.compat import v2_compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.platform import _pywrap_tf2
from tensorflow.python.platform import test


@@ -29,9 +30,13 @@ class DisableV2BehaviorTest(test.TestCase):
def test_basic(self):
t = constant_op.constant([1, 2, 3]) # creates a hidden context
self.assertTrue(isinstance(t, ops.EagerTensor))
t = _pywrap_tf2.is_enabled()
self.assertTrue(t)
v2_compat.disable_v2_behavior()
t = constant_op.constant([1, 2, 3])
self.assertFalse(isinstance(t, ops.EagerTensor))
t = _pywrap_tf2.is_enabled()
self.assertFalse(t)


if __name__ == '__main__':

@@ -36,7 +36,10 @@ py_library(

py_library(
name = "trt_convert_py",
srcs = ["trt_convert.py"],
srcs = [
"trt_convert.py",
"utils.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/compiler/tf2tensorrt:_pywrap_py_utils",

@@ -35,7 +35,9 @@ from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import get_linked_tensorrt
from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.compiler.tensorrt import utils as trt_utils
from tensorflow.python.eager import def_function
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import ops
@@ -331,17 +333,21 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
"""Get config proto based on specific settings."""
conversion_params = self.GetConversionParams(run_params)
max_batch_size = self.GetMaxBatchSize(run_params)

if graph_state == GraphState.INFERENCE and run_params.convert_online:
rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
conversion_params,
is_dynamic_op=run_params.dynamic_engine,
max_batch_size=max_batch_size)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
max_batch_size=max_batch_size,
disable_non_trt_optimizers=self._disable_non_trt_optimizers)
else:
graph_options = config_pb2.GraphOptions()
rewriter_cfg = rewriter_config_pb2.RewriterConfig()
if self._disable_non_trt_optimizers:
trt_utils.disable_non_trt_optimizers_in_rewriter_config(rewriter_cfg)

config = config_pb2.ConfigProto(
gpu_options=self._GetGPUOptions(), graph_options=graph_options)
gpu_options=self._GetGPUOptions(),
graph_options=config_pb2.GraphOptions(rewrite_options=rewriter_cfg))
return config

def _GetFeedNames(self):

@@ -30,6 +30,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.compiler.tensorrt import utils as trt_utils
from tensorflow.python.eager import context
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import convert_to_constants
@@ -271,9 +272,12 @@ def _get_tensorrt_rewriter_config(conversion_params,
# need to run constant folding again.
rewriter_config_with_trt.optimizers.extend(
["constfold", "layout", "constfold"])

rewriter_config_with_trt.meta_optimizer_iterations = (
rewriter_config_pb2.RewriterConfig.ONE)
optimizer = rewriter_config_with_trt.custom_optimizers.add()

if not disable_non_trt_optimizers:
# Add a constfold optimizer to cleanup the unused Const nodes.
rewriter_config_with_trt.custom_optimizers.add().name = "constfold"

@@ -295,25 +299,11 @@ def _get_tensorrt_rewriter_config(conversion_params,
optimizer.parameter_map["max_batch_size"].i = max_batch_size
optimizer.parameter_map["use_implicit_batch"].b = use_implicit_batch

# Disabling optimizers should happen after CopyFrom the template
# Disabling optimizers should happen after defining the TF-TRT grappler pass
# otherwise the template can overwrite the disablement.
if disable_non_trt_optimizers:
off = rewriter_config_pb2.RewriterConfig.OFF
rewriter_config_with_trt.layout_optimizer = off
rewriter_config_with_trt.constant_folding = off
rewriter_config_with_trt.shape_optimization = off
rewriter_config_with_trt.remapping = off
rewriter_config_with_trt.arithmetic_optimization = off
rewriter_config_with_trt.dependency_optimization = off
rewriter_config_with_trt.loop_optimization = off
rewriter_config_with_trt.function_optimization = off
rewriter_config_with_trt.debug_stripper = off
rewriter_config_with_trt.disable_model_pruning = True
rewriter_config_with_trt.scoped_allocator_optimization = off
rewriter_config_with_trt.memory_optimization = (
rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
rewriter_config_with_trt.pin_to_host_optimization = off
rewriter_config_with_trt.auto_parallel.enable = False
trt_utils.disable_non_trt_optimizers_in_rewriter_config(
rewriter_config_with_trt)

return rewriter_config_with_trt

@@ -652,10 +642,19 @@ class TrtGraphConverter(object):
return_elements=fetch_names,
name="")

calibrate_rewriter_cfg = rewriter_config_pb2.RewriterConfig()
if self._test_only_disable_non_trt_optimizers:
trt_utils.disable_non_trt_optimizers_in_rewriter_config(
calibrate_rewriter_cfg)

# Set allow_soft_placement=True to run the graph for calibration so that
# OPs supported by TensorRT but don't have a GPU implementation are allowed
# to execute on CPU.
calibrate_config = config_pb2.ConfigProto(allow_soft_placement=True)
calibrate_config = config_pb2.ConfigProto(
allow_soft_placement=True,
graph_options=config_pb2.GraphOptions(
rewrite_options=calibrate_rewriter_cfg))

with session.Session(
graph=self._calibration_graph,
config=calibrate_config) as calibration_sess:

tensorflow/python/compiler/tensorrt/utils.py (new file, 47 lines)
@@ -0,0 +1,47 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Exposes the Python wrapper conversion to trt_graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import rewriter_config_pb2


def disable_non_trt_optimizers_in_rewriter_config(rewriter_config):
"""Modifies rewriter_config to disable all non-TRT optimizations."""
off = rewriter_config_pb2.RewriterConfig.OFF

rewriter_config.arithmetic_optimization = off
rewriter_config.auto_mixed_precision = off
rewriter_config.auto_parallel.enable = False
rewriter_config.constant_folding = off
rewriter_config.debug_stripper = off
rewriter_config.dependency_optimization = off
# This one needs to be ON to allow TF-TRT
rewriter_config.disable_meta_optimizer = False
rewriter_config.disable_model_pruning = True
rewriter_config.function_optimization = off
rewriter_config.implementation_selector = off
rewriter_config.layout_optimizer = off
rewriter_config.loop_optimization = off
rewriter_config.memory_optimization = (
rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
rewriter_config.min_graph_nodes = -1
rewriter_config.pin_to_host_optimization = off
rewriter_config.remapping = off
rewriter_config.scoped_allocator_optimization = off
rewriter_config.shape_optimization = off
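A short sketch of how the new helper can be applied outside the converter; it mirrors the calls added in the diff, while the surrounding ConfigProto wiring here is my own:

    from tensorflow.core.protobuf import config_pb2
    from tensorflow.core.protobuf import rewriter_config_pb2
    from tensorflow.python.compiler.tensorrt import utils as trt_utils

    rewriter_cfg = rewriter_config_pb2.RewriterConfig()
    trt_utils.disable_non_trt_optimizers_in_rewriter_config(rewriter_cfg)

    # Embed the rewriter config in a ConfigProto, e.g. for a calibration session.
    config = config_pb2.ConfigProto(
        allow_soft_placement=True,
        graph_options=config_pb2.GraphOptions(rewrite_options=rewriter_cfg))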
@@ -442,7 +442,8 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
results = {}
for _ in range(elements_to_read):
val = next(it).numpy()
if val not in results: results[val] = 0
if val not in results:
results[val] = 0
results[val] += 1
for i in range(num_elements):
self.assertGreater(results[i], elements_to_read / num_elements / 2)
@@ -527,6 +528,37 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
data_service_ops.distribute(
processing_mode="invalid", service="grpc://localhost:5000"))

@combinations.generate(test_base.eager_only_combinations())
def testZipDifferentProcessingModesDatasets(self):
cluster = self.create_cluster(num_workers=1)
num_elements = 100
ds1 = dataset_ops.Dataset.range(num_elements)
ds1 = self.make_distributed_dataset(
ds1, cluster, processing_mode="distributed_epoch")
ds2 = dataset_ops.Dataset.range(num_elements)
ds2 = self.make_distributed_dataset(
ds2, cluster, processing_mode="parallel_epochs")
ds = dataset_ops.Dataset.zip((ds1, ds2))
self.assertDatasetProduces(
ds,
list(zip(range(num_elements), range(num_elements))),
assert_items_equal=True)

@combinations.generate(test_base.eager_only_combinations())
def testZipDifferentProcessingModesDatasetsSharedJobName(self):
cluster = self.create_cluster(num_workers=1)
num_elements = 100
ds1 = dataset_ops.Dataset.range(num_elements)
ds1 = self.make_distributed_dataset(
ds1, cluster, processing_mode="distributed_epoch", job_name="job_name")
ds2 = dataset_ops.Dataset.range(num_elements)
ds2 = self.make_distributed_dataset(
ds2, cluster, processing_mode="parallel_epochs", job_name="job_name")
ds = dataset_ops.Dataset.zip((ds1, ds2))
with self.assertRaisesRegex(errors.FailedPreconditionError,
"but there is already an existing job"):
self.getDatasetOutput(ds)

@combinations.generate(test_base.eager_only_combinations())
def testFromDatasetId(self):
cluster = self.create_cluster(num_workers=1)
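For context, a sketch of the same scenario through the public tf.data service API; the dispatcher address is a placeholder, and the test above uses its own in-process cluster helpers instead:

    import tensorflow as tf

    service = "grpc://localhost:5000"  # placeholder dispatcher address
    ds1 = tf.data.Dataset.range(100).apply(
        tf.data.experimental.service.distribute(
            processing_mode="distributed_epoch", service=service))
    ds2 = tf.data.Dataset.range(100).apply(
        tf.data.experimental.service.distribute(
            processing_mode="parallel_epochs", service=service))
    zipped = tf.data.Dataset.zip((ds1, ds2))  # mixing modes is what the test exercises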
@@ -88,7 +88,8 @@ class DispatchServer(object):

>>> dispatcher = tf.data.experimental.service.DispatchServer()
>>> dispatcher_address = dispatcher.target.split("://")[1]
>>> worker = tf.data.experimental.service.WorkerServer(WorkerConfig(
>>> worker = tf.data.experimental.service.WorkerServer(
...     tf.data.experimental.service.WorkerConfig(
...     dispatcher_address=dispatcher_address))
>>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(

@@ -169,7 +169,7 @@ class _WorkerContext(object):
def _get_master_target(self):
"""Return the master target for a task."""
# If cluster_spec is None or empty, we use local master.
if not self._cluster_spec:
if not self._cluster_spec or self._task_type == _TaskType.EVALUATOR:
return ""

# If task_type is None, then it is in-graph replicated training. In this
@@ -842,7 +842,8 @@ def run_distribute_coordinator(worker_fn,
session_config, cluster_spec,
task_type, task_id)

if not getattr(strategy.extended, "_std_server_started", False):
if (task_type != _TaskType.EVALUATOR and
not getattr(strategy.extended, "_std_server_started", False)):
# Right now, with eager mode, context is configured with a std server at
# the very beginning while with graph mode the std server is started when
# distribute coordinator is called. We should consolidate these two paths.

@@ -589,8 +589,7 @@ class DistributeCoordinatorTestStandaloneMode(DistributeCoordinatorTestBase):
# and distributed_mode.
self.assertEqual(self._worker_context["None"][0], (_strip_protocol(
_bytes_to_str(self._workers[0].target)), 3, True, True))
self.assertEqual(self._worker_context[EVALUATOR][0],
("fake_evaluator", 3, True, False))
self.assertEqual(self._worker_context[EVALUATOR][0], ("", 3, True, False))


class DistributeCoordinatorTestIndependentWorkerMode(
@@ -755,19 +754,15 @@ class DistributeCoordinatorTestIndependentWorkerMode(
# and distributed_mode.
self.assertEqual(self._worker_context["None"][0],
(_bytes_to_str(cluster_spec[WORKER][0]), 3, True, True))
self.assertEqual(self._worker_context[EVALUATOR][0],
(cluster_spec[EVALUATOR][0], 3, True, False))
self.assertEqual(self._worker_context[EVALUATOR][0], ("", 3, True, False))

# Make sure each worker runs a std server.
self.assertEqual(len(self._std_servers), 2)
self.assertEqual(len(self._std_servers), 1)
self.assertTrue(WORKER in self._std_servers)
self.assertTrue(EVALUATOR in self._std_servers)
self.assertEqual(len(self._std_servers[WORKER]), 3)
self.assertEqual(len(self._std_servers[EVALUATOR]), 1)
self.assertFalse(self._std_servers[WORKER][0].joined)
self.assertTrue(self._std_servers[WORKER][1].joined)
self.assertTrue(self._std_servers[WORKER][2].joined)
self.assertFalse(self._std_servers[EVALUATOR][0].joined)

def testRunStdServerInGoogleEnvironment(self):
cluster_spec = {"worker": ["fake_worker"], "ps": ["localhost:0"]}

@@ -1007,6 +1007,7 @@ class GradientTape(object):
Raises:
RuntimeError: If called on a used, non-persistent tape.
RuntimeError: If called inside the context of the tape.
TypeError: If the target is a None object.
ValueError: If the target is a variable or if unconnected gradients is
called with an unknown value.
"""
@@ -1028,6 +1029,11 @@ class GradientTape(object):
"gradient in order to compute higher order "
"derivatives.", 1)

if target is None:
raise TypeError("Target should be a list or nested structure"
" of Tensors or Variables to be differentiated,"
" but recieved %r" % (target))

num_ndarrays = 0
flat_targets = []
for t in nest.flatten(target):
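A minimal sketch of the new guard in action (my own example; the persistent tape just lets gradient() be called more than once):

    import tensorflow as tf

    x = tf.Variable(3.0)
    with tf.GradientTape(persistent=True) as tape:
        y = x * x

    print(tape.gradient(y, x))    # tf.Tensor(6.0, ...)
    try:
        tape.gradient(None, x)    # now fails fast with the TypeError added above
    except TypeError as e:
        print(e)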
@@ -1000,38 +1000,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3)
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)

def test_experimental_get_tracing_count_function(self):

@def_function.function
def double(a):
return a + a

double(constant_op.constant(1))
double(constant_op.constant(2))
self.assertAllEqual(double.experimental_get_tracing_count(), 1)
double(constant_op.constant('a'))
self.assertAllEqual(double.experimental_get_tracing_count(), 2)

def test_experimental_get_tracing_count_method(self):

class TestClass():

@def_function.function
def testDouble(self, a):
return a + a

obj1 = TestClass()
obj1.testDouble(constant_op.constant(1))
obj1.testDouble(constant_op.constant(2))
obj1.testDouble(constant_op.constant(1.1))
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
obj2 = TestClass()
obj2.testDouble(constant_op.constant(1))
obj2.testDouble(constant_op.constant(1.1))
obj2.testDouble(constant_op.constant('a'))
self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3)
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)


if __name__ == '__main__':
ops.enable_eager_execution()
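The hunk header (-1000,38 +1000,6) indicates these duplicated tracing-count tests are being removed; the API they exercised is summarized in this standalone sketch (my own, using the public `tf.function` API):

    import tensorflow as tf

    @tf.function
    def double(a):
        return a + a

    double(tf.constant(1))
    double(tf.constant(2))                          # same signature, no retrace
    print(double.experimental_get_tracing_count())  # 1
    double(tf.constant('a'))                        # new dtype forces a retrace
    print(double.experimental_get_tracing_count())  # 2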
@@ -2047,6 +2047,10 @@ class TensorFlowTestCase(googletest.TestCase):
self._tempdir = None
self._cached_session = None
self._test_start_time = None
# This flag provides the ability to control whether the graph mode gets
# initialized for TF1 or not. Initializing for TF1, which is what was
# happening earlier, was preventing enablement of 'eager mode' in the test.
self._set_default_seed = True

def setUp(self):
super(TensorFlowTestCase, self).setUp()
@@ -2061,6 +2065,7 @@ class TensorFlowTestCase(googletest.TestCase):
# cleared first.
ops._default_graph_stack.reset() # pylint: disable=protected-access
ops.reset_default_graph()
if self._set_default_seed:
random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
# Reset summary writer in case another test used set_as_default() with their
# summary writer.

@@ -18,68 +18,73 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl.testing import parameterized

from tensorflow.python import tf2
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import combinations
from tensorflow.python.platform import _pywrap_tf2
from tensorflow.python.platform import test


def set_environ():
os.environ['TF2_BEHAVIOR'] = '1'


def unset_environ():
os.environ['TF2_BEHAVIOR'] = '0'


class EnablingTF2Behavior(test.TestCase, parameterized.TestCase):

def setUp(self):
super(EnablingTF2Behavior, self).setUp()
tf2._force_enable = None
if 'TF2_BEHAVIOR' in os.environ:
del os.environ['TF2_BEHAVIOR']
def __init__(self, methodName):
super().__init__(methodName)
self._set_default_seed = False

actions = [tf2.enable, tf2.disable, set_environ, unset_environ]
@combinations.generate(test_base.v1_only_combinations())
def test_tf1_enable_tf2_behaviour(self):
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())

@combinations.generate(
combinations.combine(
action_0=actions, action_1=actions,
action_2=actions, action_3=actions))
def test_scenarios(self, action_0, action_1, action_2, action_3):
v2_compat.enable_v2_behavior()
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())

def state(action, enabled, disabled):
"""Returns bool tuple (tf2_enabled, force_enabled, force_disabled)."""
if action is tf2.enable:
return True, True, False
elif action is tf2.disable:
return False, False, True
elif action is set_environ:
return not disabled, enabled, disabled
elif action is unset_environ:
return enabled, enabled, disabled
else:
raise ValueError('Unexpected action {}. {} are supported'.format(
action, EnablingTF2Behavior.actions))
v2_compat.disable_v2_behavior()
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())

action_0()
expected, enabled, disabled = state(action_0, False, False)
self.assertEqual(tf2.enabled(), expected)
@combinations.generate(test_base.v1_only_combinations())
def test_tf1_disable_tf2_behaviour(self):
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())

action_1()
expected, enabled, disabled = state(action_1, enabled, disabled)
self.assertEqual(tf2.enabled(), expected)
v2_compat.disable_v2_behavior()
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())

action_2()
expected, enabled, disabled = state(action_2, enabled, disabled)
self.assertEqual(tf2.enabled(), expected)
v2_compat.enable_v2_behavior()
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())

action_3()
expected, enabled, disabled = state(action_3, enabled, disabled)
self.assertEqual(tf2.enabled(), expected)
@combinations.generate(test_base.v2_only_combinations())
def test_tf2_enable_tf2_behaviour(self):
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())

v2_compat.enable_v2_behavior()
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())

v2_compat.disable_v2_behavior()
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())

@combinations.generate(test_base.v2_only_combinations())
def test_tf2_disable_tf2_behaviour(self):
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())

v2_compat.disable_v2_behavior()
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())

v2_compat.enable_v2_behavior()
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())


if __name__ == '__main__':
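For readers of the rewritten test, a sketch of the corresponding switches from the user-facing side (process-wide, normally called once at program start; `tensorflow.python.tf2` is the internal module the test consults):

    import tensorflow.compat.v1 as tf1
    from tensorflow.python import tf2  # internal state queried by the test

    tf1.disable_v2_behavior()
    print(tf2.enabled())   # False
    tf1.enable_v2_behavior()
    print(tf2.enabled())   # True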
@@ -41,6 +41,7 @@ def index_directory(directory,
directory: The target directory (string).
labels: Either "inferred"
(labels are generated from the directory structure),
None (no labels),
or a list/tuple of integer labels of the same size as the number of
valid files found in the directory. Labels should be sorted according
to the alphanumeric order of the image file paths
@@ -61,19 +62,24 @@ def index_directory(directory,
labels: list of matching integer labels (same length as file_paths)
class_names: names of the classes corresponding to these labels, in order.
"""
inferred_class_names = []
if labels is None:
# in the no-label case, index from the parent directory down.
subdirs = ['']
class_names = subdirs
else:
subdirs = []
for subdir in sorted(os.listdir(directory)):
if os.path.isdir(os.path.join(directory, subdir)):
inferred_class_names.append(subdir)
subdirs.append(subdir)
if not class_names:
class_names = inferred_class_names
class_names = subdirs
else:
if set(class_names) != set(inferred_class_names):
if set(class_names) != set(subdirs):
raise ValueError(
'The `class_names` passed did not match the '
'names of the subdirectories of the target directory. '
'Expected: %s, but received: %s' %
(inferred_class_names, class_names))
(subdirs, class_names))
class_indices = dict(zip(class_names, range(len(class_names))))

# Build an index of the files
@@ -81,7 +87,8 @@ def index_directory(directory,
pool = multiprocessing.pool.ThreadPool()
results = []
filenames = []
for dirpath in (os.path.join(directory, subdir) for subdir in class_names):

for dirpath in (os.path.join(directory, subdir) for subdir in subdirs):
results.append(
pool.apply_async(index_subdirectory,
(dirpath, class_indices, follow_links, formats)))
@@ -90,7 +97,7 @@ def index_directory(directory,
partial_filenames, partial_labels = res.get()
labels_list.append(partial_labels)
filenames += partial_filenames
if labels != 'inferred':
if labels not in ('inferred', None):
if len(labels) != len(filenames):
raise ValueError('Expected the lengths of `labels` to match the number '
'of files in the target directory. len(labels) is %s '
@@ -103,6 +110,9 @@ def index_directory(directory,
labels[i:i + len(partial_labels)] = partial_labels
i += len(partial_labels)

if labels is None:
print('Found %d files.' % (len(filenames),))
else:
print('Found %d files belonging to %d classes.' %
(len(filenames), len(class_names)))
pool.close()

@@ -74,6 +74,7 @@ def image_dataset_from_directory(directory,
Otherwise, the directory structure is ignored.
labels: Either "inferred"
(labels are generated from the directory structure),
None (no labels),
or a list/tuple of integer labels of the same size as the number of
image files found in the directory. Labels should be sorted according
to the alphanumeric order of the image file paths
@@ -139,7 +140,7 @@ def image_dataset_from_directory(directory,
- if `color_mode` is `rgba`,
there are 4 channel in the image tensors.
"""
if labels != 'inferred':
if labels not in ('inferred', None):
if not isinstance(labels, (list, tuple)):
raise ValueError(
'`labels` argument should be a list/tuple of integer labels, of '
@@ -156,6 +157,9 @@ def image_dataset_from_directory(directory,
raise ValueError(
'`label_mode` argument must be one of "int", "categorical", "binary", '
'or None. Received: %s' % (label_mode,))
if labels is None or label_mode is None:
labels = None
label_mode = None
if color_mode == 'rgb':
num_channels = 3
elif color_mode == 'rgba':
@@ -188,6 +192,8 @@ def image_dataset_from_directory(directory,

image_paths, labels = dataset_utils.get_training_or_validation_split(
image_paths, labels, validation_split, subset)
if not image_paths:
raise ValueError('No images found.')

dataset = paths_and_labels_to_dataset(
image_paths=image_paths,
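A short sketch of the new label-free path from the caller's side (the directory path is a placeholder):

    import tensorflow as tf

    ds = tf.keras.preprocessing.image_dataset_from_directory(
        "images/",            # placeholder path; subdirectories are scanned too
        labels=None,          # new: yield image batches without label tensors
        image_size=(18, 18),
        batch_size=5)
    for batch in ds:
        print(batch.shape)    # (batch_size, 18, 18, 3) for full batches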
@@ -82,7 +82,7 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
# Save images to the paths
i = 0
for img in self._get_images(color_mode=color_mode, count=count):
path = paths[count % len(paths)]
path = paths[i % len(paths)]
if color_mode == 'rgb':
ext = 'jpg'
else:
@@ -92,6 +92,32 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
i += 1
return temp_dir

def test_image_dataset_from_directory_standalone(self):
# Test retrieving images without labels from a directory and its subdirs.
if PIL is None:
return # Skip test if PIL is not available.

# Save a few extra images in the parent directory.
directory = self._prepare_directory(count=7, num_classes=2)
for i, img in enumerate(self._get_images(3)):
filename = 'image_%s.jpg' % (i,)
img.save(os.path.join(directory, filename))

dataset = image_dataset.image_dataset_from_directory(
directory, batch_size=5, image_size=(18, 18), labels=None)
batch = next(iter(dataset))
# We return plain images
self.assertEqual(batch.shape, (5, 18, 18, 3))
self.assertEqual(batch.dtype.name, 'float32')
# Count samples
batch_count = 0
sample_count = 0
for batch in dataset:
batch_count += 1
sample_count += batch.shape[0]
self.assertEqual(batch_count, 2)
self.assertEqual(sample_count, 10)

def test_image_dataset_from_directory_binary(self):
if PIL is None:
return # Skip test if PIL is not available.
@@ -253,6 +279,11 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
sample_count += batch.shape[0]
self.assertEqual(sample_count, 25)

def test_image_dataset_from_directory_no_images(self):
directory = self._prepare_directory(num_classes=2, count=0)
with self.assertRaisesRegex(ValueError, 'No images found.'):
_ = image_dataset.image_dataset_from_directory(directory)

def test_image_dataset_from_directory_errors(self):
if PIL is None:
return # Skip test if PIL is not available.
@@ -261,7 +292,7 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):

with self.assertRaisesRegex(ValueError, '`labels` argument should be'):
_ = image_dataset.image_dataset_from_directory(
directory, labels=None)
directory, labels='other')

with self.assertRaisesRegex(ValueError, '`label_mode` argument must be'):
_ = image_dataset.image_dataset_from_directory(

@@ -66,6 +66,7 @@ def text_dataset_from_directory(directory,
Otherwise, the directory structure is ignored.
labels: Either "inferred"
(labels are generated from the directory structure),
None (no labels),
or a list/tuple of integer labels of the same size as the number of
text files found in the directory. Labels should be sorted according
to the alphanumeric order of the text file paths
@@ -114,7 +115,7 @@ def text_dataset_from_directory(directory,
of shape `(batch_size, num_classes)`, representing a one-hot
encoding of the class index.
"""
if labels != 'inferred':
if labels not in ('inferred', None):
if not isinstance(labels, (list, tuple)):
raise ValueError(
'`labels` argument should be a list/tuple of integer labels, of '
@@ -131,6 +132,9 @@ def text_dataset_from_directory(directory,
raise ValueError(
'`label_mode` argument must be one of "int", "categorical", "binary", '
'or None. Received: %s' % (label_mode,))
if labels is None or label_mode is None:
labels = None
label_mode = None
dataset_utils.check_validation_split_arg(
validation_split, subset, shuffle, seed)

@@ -152,6 +156,8 @@ def text_dataset_from_directory(directory,

file_paths, labels = dataset_utils.get_training_or_validation_split(
file_paths, labels, validation_split, subset)
if not file_paths:
raise ValueError('No text files found.')

dataset = paths_and_labels_to_dataset(
file_paths=file_paths,
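The analogous label-free call for text data (again a sketch; the directory is a placeholder):

    import tensorflow as tf

    ds = tf.keras.preprocessing.text_dataset_from_directory(
        "corpus/",           # placeholder path
        labels=None,         # or label_mode=None: batches of raw strings only
        batch_size=5)
    for batch in ds:
        print(batch.shape)   # (batch_size,), dtype=string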
@@ -58,7 +58,7 @@ class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):
paths += class_paths

for i in range(count):
path = paths[count % len(paths)]
path = paths[i % len(paths)]
filename = os.path.join(path, 'text_%s.txt' % (i,))
f = open(os.path.join(temp_dir, filename), 'w')
text = ''.join([random.choice(string.printable) for _ in range(length)])
@@ -66,6 +66,32 @@ class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):
f.close()
return temp_dir

def test_text_dataset_from_directory_standalone(self):
# Test retrieving txt files without labels from a directory and its subdirs.
# Save a few extra files in the parent directory.
directory = self._prepare_directory(count=7, num_classes=2)
for i in range(3):
filename = 'text_%s.txt' % (i,)
f = open(os.path.join(directory, filename), 'w')
text = ''.join([random.choice(string.printable) for _ in range(20)])
f.write(text)
f.close()

dataset = text_dataset.text_dataset_from_directory(
directory, batch_size=5, label_mode=None, max_length=10)
batch = next(iter(dataset))
# We just return the texts, no labels
self.assertEqual(batch.shape, (5,))
self.assertEqual(batch.dtype.name, 'string')
# Count samples
batch_count = 0
sample_count = 0
for batch in dataset:
batch_count += 1
sample_count += batch.shape[0]
self.assertEqual(batch_count, 2)
self.assertEqual(sample_count, 10)

def test_text_dataset_from_directory_binary(self):
directory = self._prepare_directory(num_classes=2)
dataset = text_dataset.text_dataset_from_directory(
@@ -172,12 +198,17 @@ class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):
sample_count += batch.shape[0]
self.assertEqual(sample_count, 25)

def test_text_dataset_from_directory_no_files(self):
directory = self._prepare_directory(num_classes=2, count=0)
with self.assertRaisesRegex(ValueError, 'No text files found.'):
_ = text_dataset.text_dataset_from_directory(directory)

def test_text_dataset_from_directory_errors(self):
directory = self._prepare_directory(num_classes=3, count=5)

with self.assertRaisesRegex(ValueError, '`labels` argument should be'):
_ = text_dataset.text_dataset_from_directory(
directory, labels=None)
directory, labels='other')

with self.assertRaisesRegex(ValueError, '`label_mode` argument must be'):
_ = text_dataset.text_dataset_from_directory(

@@ -3246,10 +3246,12 @@ cuda_py_test(
srcs = ["extract_volume_patches_grad_test.py"],
shard_count = 50,
tags = [
"no_gpu", # b/171837334
"no_oss", # Test times out on oss-nightly cpu builds
"no_pip",
"nogpu", # http://b/171837334
"nomac", # http://b/139946976
"notap", # http://b/31080670
"nogpu", # b/171837334
"nomac", # b/139946976
"notap", # b/31080670
],
deps = [
"//tensorflow/python:array_ops",

@@ -1557,6 +1557,14 @@ class AssertTypeTest(test.TestCase):
with self.assertRaisesRegexp(TypeError, "must be of type.*float32"):
check_ops.assert_type(sparse_float16, dtypes.float32)

def test_raise_when_tf_type_is_not_dtype(self):
# Test case for GitHub issue:
# https://github.com/tensorflow/tensorflow/issues/45975
value = constant_op.constant(0.0)
with self.assertRaisesRegexp(TypeError,
"Cannot convert.*to a TensorFlow DType"):
check_ops.assert_type(value, (dtypes.float32,))


class AssertShapesTest(test.TestCase):
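A sketch of the behaviour the new regression test pins down, through the public `tf.debugging.assert_type` wrapper (my own example):

    import tensorflow as tf

    value = tf.constant(0.0)
    tf.debugging.assert_type(value, tf.float32)         # passes
    try:
        tf.debugging.assert_type(value, (tf.float32,))  # a tuple is not a DType
    except TypeError as e:
        print(e)  # "Cannot convert ... to a TensorFlow DType"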