Merge branch 'master'

Merge branch 'master' of https://github.com/tensorflow/tensorflow
into feature-micro-add-op-depth-to-space-pr1
Ryan Kuester 2021-01-05 15:08:34 -06:00
commit c20ac67cb1
127 changed files with 2354 additions and 1370 deletions

View File

@ -114,6 +114,143 @@ This release contains contributions from many people at Google, as well as:
<INSERT>, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
# Release 2.3.2
## Bug Fixes and Other Changes
* Fixes an access to uninitialized memory in Eigen code
([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266))
* Fixes a security vulnerability caused by lack of validation in
`tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap`
([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
* Fixes a vulnerability caused by attempting to write to an immutable memory region in
`tf.raw_ops.ImmutableConst`
([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
* Fixes a `CHECK`-fail in LSTM with zero-length input
([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270))
* Fixes a security vulnerability caused by accessing heap data outside of bounds
when loading a specially crafted `SavedModel`
([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271))
* Solves an OOM issue on TPUs when XLA contexts use fused average updates
* Updates `libjpeg-turbo` to `2.0.5` to handle
[CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790).
* Updates `junit` to `4.13.1` to handle
[CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250).
* Updates `PCRE` to `8.44` to handle
[CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838)
and
[CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155).
* Updates `sqlite3` to `3.44.0` to keep in sync with master branch.
# Release 2.2.2
## Bug Fixes and Other Changes
* Fixes an access to uninitialized memory in Eigen code
([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266))
* Fixes a security vulnerability caused by lack of validation in
`tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap`
([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
* Fixes a vulnerability caused by attempting to write to an immutable memory region in
`tf.raw_ops.ImmutableConst`
([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
* Fixes a `CHECK`-fail in LSTM with zero-length input
([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270))
* Fixes a security vulnerability caused by accessing heap data outside of bounds
when loading a specially crafted `SavedModel`
([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271))
* Prevents memory leaks in loading `SavedModel`s that import functions
* Updates `libjpeg-turbo` to `2.0.5` to handle
[CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790).
* Updates `junit` to `4.13.1` to handle
[CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250).
* Updates `PCRE` to `8.44` to handle
[CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838)
and
[CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155).
* Updates `sqlite3` to `3.44.0` to keep in sync with master branch.
# Release 2.1.3
## Bug Fixes and Other Changes
* Fixes an access to uninitialized memory in Eigen code
([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266))
* Fixes a security vulnerability caused by lack of validation in
`tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap`
([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
* Fixes a vulnerability caused by attempting to write to an immutable memory region in
`tf.raw_ops.ImmutableConst`
([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
* Fixes a `CHECK`-fail in LSTM with zero-length input
([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270))
* Fixes a security vulnerability caused by accessing heap data outside of bounds
when loading a specially crafted `SavedModel`
([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271))
* Updates `libjpeg-turbo` to `2.0.5` to handle
[CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790).
* Updates `junit` to `4.13.1` to handle
[CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250).
* Updates `PCRE` to `8.44` to handle
[CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838)
and
[CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155).
* Updates `sqlite3` to `3.44.0` to keep in sync with master branch.
* Newer ROCm versions are supported on the 2.1 branch.
# Release 2.0.4
Note that this is the last patch release for the TensorFlow 2.0.x series.
## Bug Fixes and Other Changes
* Fixes an access to uninitialized memory in Eigen code
([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266))
* Fixes a security vulnerability caused by lack of validation in
`tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap`
([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
* Fixes a vulnerability caused by attempting to write to an immutable memory region in
`tf.raw_ops.ImmutableConst`
([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
* Fixes a `CHECK`-fail in LSTM with zero-length input
([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270))
* Fixes a security vulnerability caused by accessing heap data outside of bounds
when loading a specially crafted `SavedModel`
([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271))
* Updates `libjpeg-turbo` to `2.0.5` to handle
[CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790).
* Updates `junit` to `4.13.1` to handle
[CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250).
* Updates `PCRE` to `8.44` to handle
[CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838)
and
[CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155).
* Updates `sqlite3` to `3.44.0` to keep in sync with master branch.
# Release 1.15.5
Note that this is the last patch release for the TensorFlow 1.x series.
## Bug Fixes and Other Changes
* Fixes an access to uninitialized memory in Eigen code
([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266))
* Fixes a security vulnerability caused by lack of validation in
`tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap`
([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
* Fixes a vulnerability caused by attempting to write to an immutable memory region in
`tf.raw_ops.ImmutableConst`
([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
* Fixes a `CHECK`-fail in LSTM with zero-length input
([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270))
* Fixes a security vulnerability caused by accessing heap data outside of bounds
when loading a specially crafted `SavedModel`
([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271))
* Updates `libjpeg-turbo` to `2.0.5` to handle
[CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790).
* Updates `junit` to `4.13.1` to handle
[CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250).
* Updates `PCRE` to `8.44` to handle
[CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838)
and
[CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155).
* Updates `sqlite3` to `3.44.0` to keep in sync with master branch.
# Release 2.4.0
## Major Features and Improvements
@ -163,7 +300,7 @@ This release contains contributions from many people at Google, as well as:
## Breaking Changes
* TF Core:
  * Certain float32 ops run in lower precision on Ampere based GPUs, including
    matmuls and convolutions, due to the use of
    [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/).
    Specifically, inputs to such ops are rounded from 23 bits of precision to 10
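For code that needs full float32 accuracy, TensorFloat-32 can be switched off. The sketch below is a minimal C++ illustration; it assumes the `tensor_float_32_utils` helper that backs the Python toggle (`tf.config.experimental.enable_tensor_float_32_execution`), so treat the exact header path as an assumption rather than part of this commit.

```cpp
// Minimal sketch (header path and function names assumed): opt out of
// TensorFloat-32 so float32 matmuls and convolutions keep full precision
// on Ampere GPUs.
#include "tensorflow/core/util/tensor_float_32_utils.h"

int main() {
  // Disable TF32 execution process-wide before any kernels run.
  tensorflow::enable_tensor_float_32_execution(false);

  // ... build and run the model as usual ...

  // Returns 0 when TF32 is confirmed off.
  return tensorflow::tensor_float_32_execution_enabled() ? 1 : 0;
}
```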

View File

@ -1282,7 +1282,8 @@ class DynamicReshapeOpNotActuallyDynamic
void DynamicReshapeOp::getCanonicalizationPatterns( void DynamicReshapeOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) { OwningRewritePatternList& results, MLIRContext* context) {
results.insert<DynamicReshapeOpNotActuallyDynamic, results.insert<DynamicReshapeOpNotActuallyDynamic,
RemoveRedundantDynamicReshape, ShapeOfDynamicReshape>(context); RemoveRedundantDynamicBroadcast, RemoveRedundantDynamicReshape,
ShapeOfDynamicReshape>(context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -33,3 +33,10 @@ def UnaryEinsumToEinsum : Pat<
def RemoveRedundantDynamicReshape : Pat< def RemoveRedundantDynamicReshape : Pat<
(HLO_DynamicReshapeOp (HLO_DynamicReshapeOp $operand, $shape1), $shape2), (HLO_DynamicReshapeOp (HLO_DynamicReshapeOp $operand, $shape1), $shape2),
(HLO_DynamicReshapeOp $operand, $shape2)>; (HLO_DynamicReshapeOp $operand, $shape2)>;
// A dynamic broadcast of a dynamic reshape with the same shape operand
// is a dynamic reshape.
def RemoveRedundantDynamicBroadcast : Pat<
(HLO_DynamicBroadcastInDimOp
(HLO_DynamicReshapeOp $operand, $shape), $shape, $dims),
(HLO_DynamicReshapeOp $operand, $shape)>;
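For readers more familiar with hand-written canonicalization patterns than with DRR, the following is a rough C++ counterpart of the TableGen rule above. It is a sketch only: the include path and the generated accessor names (`operand()`, `output_shape()`, `output_dimensions()`) are assumptions, and the pattern that actually ships is generated from the DRR entry, not from this class.

```cpp
// Illustrative hand-written counterpart of RemoveRedundantDynamicBroadcast
// (accessor names and include path assumed).
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"  // assumed include path
#include "mlir/IR/PatternMatch.h"

struct RemoveRedundantDynamicBroadcastPattern
    : public mlir::OpRewritePattern<mlir::mhlo::DynamicBroadcastInDimOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::mhlo::DynamicBroadcastInDimOp broadcast,
      mlir::PatternRewriter &rewriter) const override {
    // Match a broadcast whose input is a dynamic reshape that already uses
    // the same shape operand; the broadcast then cannot change the shape.
    auto reshape =
        broadcast.operand().getDefiningOp<mlir::mhlo::DynamicReshapeOp>();
    if (!reshape || reshape.output_shape() != broadcast.output_dimensions())
      return mlir::failure();
    rewriter.replaceOpWithNewOp<mlir::mhlo::DynamicReshapeOp>(
        broadcast, broadcast.getType(), reshape.operand(),
        reshape.output_shape());
    return mlir::success();
  }
};
```

The DRR form expresses the shape-equality constraint simply by binding the same `$shape` capture twice, which is why the TableGen rule is preferred here.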

View File

@ -1540,3 +1540,14 @@ func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf3
return %1 : tensor<128xf32> return %1 : tensor<128xf32>
// CHECK: return %arg0 : tensor<128xf32> // CHECK: return %arg0 : tensor<128xf32>
} }
// CHECK-LABEL: @broadcast_of_reshape
func @broadcast_of_reshape(%arg: tensor<?xf32>, %shape: tensor<2xindex>) -> tensor<?x?xf32> {
%0 = "mhlo.dynamic_reshape"(%arg, %shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
%1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) {
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK: [[RESHAPE:%.*]] = "mhlo.dynamic_reshape"
// CHECK: return [[RESHAPE]]

View File

@ -133,7 +133,7 @@ class TFL_OperandsHaveSameShapesOrBroadcastableShape<
TFL_RuntimePredOpTrait<"operands do not have the same shape or " TFL_RuntimePredOpTrait<"operands do not have the same shape or "
"broadcastable shapes within the rank " # max_bcast_rank, "broadcastable shapes within the rank " # max_bcast_rank,
CPred<"TFL::VerifyOperandsHaveSameShapesOrBroadcastableShape(" CPred<"TFL::VerifyOperandsHaveSameShapesOrBroadcastableShape("
"$_op, llvm::ArrayRef<unsigned>({" # StrJoinInt<indices>.result # "$_op, llvm::ArrayRef<unsigned>({" # !interleave(indices, ", ") #
"}), " # max_bcast_rank # ")">>; "}), " # max_bcast_rank # ")">>;
// These additional types/type constraints here are used to decouple the ops // These additional types/type constraints here are used to decouple the ops
@ -3453,10 +3453,10 @@ def TFL_CastOp : TFL_Op<"cast", [
}]; }];
let arguments = (ins let arguments = (ins
TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$input TFL_TensorOf<[F32, I1, I16, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$input
); );
let results = (outs TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$output); let results = (outs TFL_TensorOf<[F32, I1, I16, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$output);
// TFLite's cast op does not utilize CastOptions, instead derives types // TFLite's cast op does not utilize CastOptions, instead derives types
// from the TfLiteTensors. // from the TfLiteTensors.

View File

@ -34,7 +34,7 @@ class QuantizedType<string n, list<int> params, bit signed>
"Q" # !if (signed, "I", "UI") # !head(params) # " type"> { "Q" # !if (signed, "I", "UI") # !head(params) # " type"> {
string name = n; string name = n;
string asTraitArgsStr = string asTraitArgsStr =
StrJoinInt<params>.result # !if(signed, ", true", ", false"); !interleave(params, ", ") # !if(signed, ", true", ", false");
} }
// Uniform quantized types. Two integers "smantissa" and "sexp" are used to // Uniform quantized types. Two integers "smantissa" and "sexp" are used to
@ -134,7 +134,7 @@ class FixedResultScale<QuantizedType qt> : NativeOpTrait<!strconcat(
// needs a scale based on the scales of op1 and op2. // needs a scale based on the scales of op1 and op2.
class AccumulatorUniformScale<int bias, int op1, int op2> : NativeOpTrait< class AccumulatorUniformScale<int bias, int op1, int op2> : NativeOpTrait<
!strconcat("quant::AccumulatorUniformScale<", !strconcat("quant::AccumulatorUniformScale<",
StrJoinInt<[bias, op1, op2]>.result, !interleave([bias, op1, op2], ", "),
">::Impl")>; ">::Impl")>;
// Specify the operand index of the coefficient operand for an affine op // Specify the operand index of the coefficient operand for an affine op
@ -142,7 +142,7 @@ class AccumulatorUniformScale<int bias, int op1, int op2> : NativeOpTrait<
// If the quantization dimension is -1, per-axis quantization isn't supported. // If the quantization dimension is -1, per-axis quantization isn't supported.
class AffineOpCoefficient<int dim, int index> : NativeOpTrait< class AffineOpCoefficient<int dim, int index> : NativeOpTrait<
!strconcat("quant::AffineOpCoefficient<", !strconcat("quant::AffineOpCoefficient<",
StrJoinInt<[dim, index]>.result, !interleave([dim, index], ", "),
">::Impl")>; ">::Impl")>;
// Specify this trait if the op doesn't have quantizable output. We shouldn't // Specify this trait if the op doesn't have quantizable output. We shouldn't

View File

@ -1320,6 +1320,22 @@ func @cast(%arg0: tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32> {
// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32> // CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32>
} }
func @castFloat32ToI16(%arg0: tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xi16> {
%0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xi16>
return %0 : tensor<1x2x2x5xi16>
// CHECK-LABEL: castFloat32ToI16
// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xi16>
}
func @castI16ToFloat32(%arg0: tensor<1x2x2x5xi16>) -> tensor<1x2x2x5xf32> {
%0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xi16>) -> tensor<1x2x2x5xf32>
return %0 : tensor<1x2x2x5xf32>
// CHECK-LABEL: castI16ToFloat32
// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi16>) -> tensor<1x2x2x5xf32>
}
func @castComplex(%arg0: tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xcomplex<f32>> { func @castComplex(%arg0: tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xcomplex<f32>> {
%0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xcomplex<f32>> %0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xcomplex<f32>>
return %0 : tensor<1x2x2x5xcomplex<f32>> return %0 : tensor<1x2x2x5xcomplex<f32>>

View File

@ -98,11 +98,13 @@ class TF_OpIsBroadcastableToRes<int opId, int resId> : And<[
class TF_AllTypesMatchPred<list<string> values> : class TF_AllTypesMatchPred<list<string> values> :
CPred<"TF::AreCastCompatible(llvm::makeArrayRef({"# StrJoin<values>.result #"}))">; CPred<"TF::AreCastCompatible(llvm::makeArrayRef({" #
!interleave(values, ", ") # "}))">;
class TF_AllTypesMatch<list<string> names> : class TF_AllTypesMatch<list<string> names> :
PredOpTrait< PredOpTrait<
"all of {" # StrJoin<names>.result # "} have dynamically equal types ", "all of {" # !interleave(names, ", ") #
"} have dynamically equal types ",
TF_AllTypesMatchPred< TF_AllTypesMatchPred<
!foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>; !foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;

View File

@ -147,6 +147,7 @@ cc_library(
"//tensorflow/compiler/mlir/hlo:lhlo", "//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/mlir/hlo:lhlo_gpu", "//tensorflow/compiler/mlir/hlo:lhlo_gpu",
"//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:backend",

View File

@ -384,7 +384,7 @@ ENTRY main {
HloModule BatchNormForwardInference HloModule BatchNormForwardInference
// CHECK: func @main // CHECK: func @main
// CHECK: lmhlo_gpu.batch_norm_inference" // CHECK: "lmhlo_gpu.batch_norm_inference"
// CHECK-SAME: epsilon = 1.000000e-03 : f32 // CHECK-SAME: epsilon = 1.000000e-03 : f32
// CHECK-SAME: feature_index = 0 : i64 // CHECK-SAME: feature_index = 0 : i64
// CHECK-SAME: (memref<2x2x2x2xf32>, memref<2xf32>, memref<2xf32>, memref<2xf32>, memref<2xf32>, memref<2x2x2x2xf32>) -> () // CHECK-SAME: (memref<2x2x2x2xf32>, memref<2xf32>, memref<2xf32>, memref<2xf32>, memref<2xf32>, memref<2x2x2x2xf32>) -> ()
@ -399,4 +399,16 @@ ENTRY main {
ROOT %custom-call = f32[2,2,2,2]{3,2,1,0} ROOT %custom-call = f32[2,2,2,2]{3,2,1,0}
custom-call(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[] %constant, s64[] %constant_1), custom-call(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[] %constant, s64[] %constant_1),
custom_call_target="__cudnn$batchNormalizationForwardInference" custom_call_target="__cudnn$batchNormalizationForwardInference"
} }
// -----
HloModule Infeed
// CHECK: func @main
// CHECK: "lmhlo.infeed"
// CHECK-SAME: (memref<3xf32>) -> ()
ENTRY main {
%tok = token[] parameter(0)
ROOT %infeed = (f32[3]{0}, token[]) infeed(token[] %tok)
}

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "mlir/IR/AffineMap.h" // from @llvm-project #include "mlir/IR/AffineMap.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project
@ -60,6 +61,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/util.h"
@ -67,6 +69,7 @@ using xla::BufferAllocation;
using xla::BufferAssignment; using xla::BufferAssignment;
using xla::HloComputation; using xla::HloComputation;
using xla::HloCustomCallInstruction; using xla::HloCustomCallInstruction;
using xla::HloInfeedInstruction;
using xla::HloInstruction; using xla::HloInstruction;
using xla::HloModule; using xla::HloModule;
using xla::HloModuleProto; using xla::HloModuleProto;
@ -199,14 +202,16 @@ class XlaHloToLhloPass
} // namespace } // namespace
// Creates MLIR operands corresponding to operands and results of the XLA HLO // Creates MLIR operands corresponding to operands and results of the XLA HLO
// instruction. If `num_operands` is not -1, then only the first `num_operands` // instruction. If `num_operands` is valid, then only the first `num_operands`
// operands of the HLO instruction will be considered. // operands of the HLO instruction will be considered.
Status LhloDialectEmitter::CreateOperands( Status LhloDialectEmitter::CreateOperands(
HloInstruction* instr, llvm::SmallVectorImpl<Value>& operands, HloInstruction* instr, absl::optional<xla::int64> num_operands,
size_t& num_arguments, size_t& num_results, llvm::SmallVectorImpl<Value>& operands, size_t& num_arguments,
absl::optional<xla::int64> num_operands) { size_t& num_results) {
if (num_operands.value_or(0) > instr->operand_count())
return xla::InvalidArgument("num_operands must be <= operand count");
for (xla::int64 i = 0; i < num_operands.value_or(instr->operand_count()); for (xla::int64 i = 0; i < num_operands.value_or(instr->operand_count());
i++) { ++i) {
TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(i), &operands)); TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(i), &operands));
} }
num_arguments = operands.size(); num_arguments = operands.size();
@ -215,19 +220,23 @@ Status LhloDialectEmitter::CreateOperands(
return Status::OK(); return Status::OK();
} }
template <typename OpType>
OpType LhloDialectEmitter::CreateOpWithoutAttrs(HloInstruction* instr,
ValueRange operands) {
Location loc = getLocation(instr);
NamedAttribute attrs[] = {{Identifier::get("name", builder_.getContext()),
builder_.getStringAttr(instr->name())}};
return builder_.create<OpType>(loc, llvm::None, operands, attrs);
}
template <typename OpType> template <typename OpType>
StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs( StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(
HloInstruction* instr, size_t& num_arguments, size_t& num_results, HloInstruction* instr, size_t& num_arguments, size_t& num_results,
absl::optional<xla::int64> num_operands) { absl::optional<xla::int64> num_operands) {
Location loc = getLocation(instr);
std::pair<Identifier, Attribute> attrs[] = {
{Identifier::get("name", builder_.getContext()),
builder_.getStringAttr(instr->name())},
};
llvm::SmallVector<Value, 4> operands; llvm::SmallVector<Value, 4> operands;
TF_RETURN_IF_ERROR(CreateOperands(instr, operands, num_arguments, num_results, TF_RETURN_IF_ERROR(CreateOperands(instr, num_operands, operands,
num_operands)); num_arguments, num_results));
return builder_.create<OpType>(loc, llvm::None, operands, attrs); return CreateOpWithoutAttrs<OpType>(instr, operands);
} }
StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) { StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
@ -273,6 +282,8 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
return CreateOpWithoutAttrs<lmhlo::FloorOp>(instr); return CreateOpWithoutAttrs<lmhlo::FloorOp>(instr);
case HloOpcode::kImag: case HloOpcode::kImag:
return CreateOpWithoutAttrs<lmhlo::ImagOp>(instr); return CreateOpWithoutAttrs<lmhlo::ImagOp>(instr);
case HloOpcode::kInfeed:
return EmitInfeedOp(instr);
case HloOpcode::kIsFinite: case HloOpcode::kIsFinite:
return CreateOpWithoutAttrs<lmhlo::IsFiniteOp>(instr); return CreateOpWithoutAttrs<lmhlo::IsFiniteOp>(instr);
case HloOpcode::kLog: case HloOpcode::kLog:
@ -387,7 +398,7 @@ StatusOr<Value> LhloDialectEmitter::RewriteFusionOperand(
::xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) { ::xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) {
if (shape.IsTuple()) { if (shape.IsTuple()) {
llvm::SmallVector<Value, 4> values; llvm::SmallVector<Value, 4> values;
for (int i = 0; i < shape.tuple_shapes_size(); i++) { for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
shape_index->push_back(i); shape_index->push_back(i);
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index, auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index,
@ -423,7 +434,7 @@ StatusOr<lmhlo::FusionOp> LhloDialectEmitter::EmitFusionOp(
auto region_builder = OpBuilder::atBlockBegin(&fusion.region().front()); auto region_builder = OpBuilder::atBlockBegin(&fusion.region().front());
llvm::SmallVector<Value, 8> arguments; llvm::SmallVector<Value, 8> arguments;
for (int i = 0; i < instr->operands().size(); i++) { for (int i = 0; i < instr->operands().size(); ++i) {
const HloInstruction* operand = instr->operand(i); const HloInstruction* operand = instr->operand(i);
xla::ShapeIndex shape_index; xla::ShapeIndex shape_index;
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
@ -982,6 +993,19 @@ StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
return all_reduce_op; return all_reduce_op;
} }
StatusOr<lmhlo::InfeedOp> LhloDialectEmitter::EmitInfeedOp(
HloInstruction* instr) {
HloInfeedInstruction* infeed = ::xla::Cast<HloInfeedInstruction>(instr);
// HLO Infeed instruction has a single operand of token type and a tuple
// with buffers and a token as its output. LMHLO Infeed operation does not
// need the token operand or result, so drop it.
SmallVector<Value, 2> operands;
TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{0}));
auto infeed_op = CreateOpWithoutAttrs<lmhlo::InfeedOp>(instr, operands);
infeed_op.configAttr(builder_.getStringAttr(infeed->infeed_config()));
return infeed_op;
}
StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView( StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape, const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
const ::xla::ShapeIndex& shape_index) { const ::xla::ShapeIndex& shape_index) {
@ -1055,7 +1079,7 @@ Status LhloDialectEmitter::GetOrCreateViewImpl(
const HloInstruction* instr, const Shape& current_shape, const HloInstruction* instr, const Shape& current_shape,
::xla::ShapeIndex* current_shape_index, SmallVectorImpl<Value>* values) { ::xla::ShapeIndex* current_shape_index, SmallVectorImpl<Value>* values) {
if (current_shape.IsTuple()) { if (current_shape.IsTuple()) {
for (int i = 0; i < current_shape.tuple_shapes().size(); i++) { for (int i = 0; i < current_shape.tuple_shapes().size(); ++i) {
current_shape_index->push_back(i); current_shape_index->push_back(i);
TF_RETURN_IF_ERROR(GetOrCreateViewImpl( TF_RETURN_IF_ERROR(GetOrCreateViewImpl(
instr, current_shape.tuple_shapes(i), current_shape_index, values)); instr, current_shape.tuple_shapes(i), current_shape_index, values));
@ -1063,19 +1087,26 @@ Status LhloDialectEmitter::GetOrCreateViewImpl(
} }
return Status::OK(); return Status::OK();
} }
TF_ASSIGN_OR_RETURN( if (current_shape.IsArray()) {
auto v, GetOrCreateArrayView(instr, current_shape, *current_shape_index)); TF_ASSIGN_OR_RETURN(auto v, GetOrCreateArrayView(instr, current_shape,
values->push_back(v); *current_shape_index));
return Status::OK(); values->push_back(v);
return Status::OK();
}
return xla::InternalError("Unexpected shape kind for %s and shape index %s",
instr->ToString(), current_shape_index->ToString());
} }
// Returns a view for the result of an instruction. // Returns a view for the result of an instruction.
// We first get a view for the slice in the allocation, and then may need to // We first get a view for the slice in the allocation, and then may need to
// create another view to adjust the slice for the shape of the instruction. // create another view to adjust the slice for the shape of the instruction.
Status LhloDialectEmitter::GetOrCreateView(const HloInstruction* instr, Status LhloDialectEmitter::GetOrCreateView(
SmallVectorImpl<Value>* values) { const HloInstruction* instr, SmallVectorImpl<Value>* values,
::xla::ShapeIndex shape_index; const xla::ShapeIndex& result_subset) {
return GetOrCreateViewImpl(instr, instr->shape(), &shape_index, values); ::xla::ShapeIndex shape_index = result_subset;
const Shape& sub_shape =
::xla::ShapeUtil::GetSubshape(instr->shape(), shape_index);
return GetOrCreateViewImpl(instr, sub_shape, &shape_index, values);
} }
Status LhloDialectEmitter::Initialize() { Status LhloDialectEmitter::Initialize() {

View File

@ -27,6 +27,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace mlir { namespace mlir {
@ -79,6 +81,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
::xla::StatusOr<lmhlo::CompareOp> EmitCompareOp(::xla::HloInstruction* instr); ::xla::StatusOr<lmhlo::CompareOp> EmitCompareOp(::xla::HloInstruction* instr);
::xla::StatusOr<lmhlo::InfeedOp> EmitInfeedOp(::xla::HloInstruction* instr);
::xla::StatusOr<lmhlo::MapOp> EmitMapOp(::xla::HloInstruction* instr); ::xla::StatusOr<lmhlo::MapOp> EmitMapOp(::xla::HloInstruction* instr);
::xla::StatusOr<lmhlo::ReducePrecisionOp> EmitReducePrecisionOp( ::xla::StatusOr<lmhlo::ReducePrecisionOp> EmitReducePrecisionOp(
@ -87,10 +90,16 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
::xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp( ::xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
::xla::HloInstruction* instr); ::xla::HloInstruction* instr);
::xla::Status CreateOperands( // Create LHLO operation operands given an XLA HLO instruction. By default,
::xla::HloInstruction* instr, SmallVectorImpl<Value>& operands, // all XLA HLO operands and results are converted to MLIR and appended to
size_t& num_arguments, size_t& num_results, // `operands`. If `num_operands` is specified, only the first `num_operands`
absl::optional<xla::int64> num_operands = absl::nullopt); // operands of the instruction are converted to MLIR. The function returns the
// actual number of operands and results generated for MLIR in `num_arguments`
// and `num_results`.
::xla::Status CreateOperands(::xla::HloInstruction* instr,
absl::optional<xla::int64> num_operands,
SmallVectorImpl<Value>& operands,
size_t& num_arguments, size_t& num_results);
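As a usage note for the reordered signature, the sketch below shows a hypothetical call site (not taken from this diff) that relies on the documented contract above: a concrete `num_operands` limits conversion to a prefix of the HLO operands, which is how ops with trailing token operands can skip them.

```cpp
// Hypothetical call site: convert only the first HLO operand, then append
// views for all results (sketch only, assuming the contract documented above).
llvm::SmallVector<mlir::Value, 4> operands;
size_t num_arguments = 0;
size_t num_results = 0;
TF_RETURN_IF_ERROR(CreateOperands(instr, /*num_operands=*/1, operands,
                                  num_arguments, num_results));
// operands[0 .. num_arguments) holds the converted operands,
// operands[num_arguments .. num_arguments + num_results) the result views.
```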
template <typename OpType> template <typename OpType>
::xla::StatusOr<OpType> CreateOpWithoutAttrs( ::xla::StatusOr<OpType> CreateOpWithoutAttrs(
@ -105,6 +114,10 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
::xla::HloInstruction* instr, size_t& num_arguments, size_t& num_results, ::xla::HloInstruction* instr, size_t& num_arguments, size_t& num_results,
absl::optional<xla::int64> num_operands = absl::nullopt); absl::optional<xla::int64> num_operands = absl::nullopt);
template <typename OpType>
OpType CreateOpWithoutAttrs(::xla::HloInstruction* instr,
ValueRange operands);
template <typename T> template <typename T>
DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) { DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) {
return builder_.getI64TensorAttr( return builder_.getI64TensorAttr(
@ -140,9 +153,14 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
SmallVectorImpl<Value>* values); SmallVectorImpl<Value>* values);
// Helper function to create view/tuple of views to a buffer for a given // Helper function to create view/tuple of views to a buffer for a given
// instruction result. // instruction result. `result_subset` can be used for instructions that
// have a tuple result and MLIR conversion needs to convert only one of the
// tuple elements. Note that if needed, this can be extended to take a list of
// ShapeIndex values in case we need finer control on what elements of the
// output tuple to be converted to MLIR.
tensorflow::Status GetOrCreateView(const ::xla::HloInstruction* instr, tensorflow::Status GetOrCreateView(const ::xla::HloInstruction* instr,
SmallVectorImpl<Value>* values); SmallVectorImpl<Value>* values,
const xla::ShapeIndex& result_subset = {});
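A short usage sketch for the new `result_subset` parameter, mirroring the infeed emitter shown earlier: selecting tuple element `{0}` materializes views for the data buffers while the trailing token element is never touched.

```cpp
// Sketch: create views only for tuple element 0 of `instr`'s
// (buffers, token) shaped result; the token at index 1 is skipped.
llvm::SmallVector<mlir::Value, 2> values;
TF_RETURN_IF_ERROR(GetOrCreateView(instr, &values, /*result_subset=*/{0}));
```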
::xla::StatusOr<Value> GetOrCreateArrayView( ::xla::StatusOr<Value> GetOrCreateArrayView(
const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape, const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,

View File

@ -51,10 +51,10 @@
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/xla/tutorials/compile\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/xla/tutorials/compile\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n", " \u003c/td\u003e\n",
" \u003ctd\u003e\n", " \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/jit_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n", " \u003c/td\u003e\n",
" \u003ctd\u003e\n", " \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/jit_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n", " \u003c/td\u003e\n",
"\u003c/table\u003e" "\u003c/table\u003e"
] ]

View File

@ -89,7 +89,7 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
} }
StatusOr<std::unique_ptr<PjRtBuffer>> PyClient::PjRtBufferFromPyval( StatusOr<std::unique_ptr<PjRtBuffer>> PyClient::PjRtBufferFromPyval(
const pybind11::object& argument, PjRtDevice* device, bool force_copy, pybind11::handle argument, PjRtDevice* device, bool force_copy,
PjRtClient::HostBufferSemantics host_buffer_semantics) { PjRtClient::HostBufferSemantics host_buffer_semantics) {
if (device == nullptr) { if (device == nullptr) {
TF_RET_CHECK(!pjrt_client_->local_devices().empty()); TF_RET_CHECK(!pjrt_client_->local_devices().empty());
@ -123,7 +123,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PyClient::PjRtBufferFromPyval(
return buffer; return buffer;
} }
StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval( StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
const pybind11::object& argument, PjRtDevice* device, bool force_copy, pybind11::handle argument, PjRtDevice* device, bool force_copy,
PjRtClient::HostBufferSemantics host_buffer_semantics) { PjRtClient::HostBufferSemantics host_buffer_semantics) {
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtBuffer> buffer, std::unique_ptr<PjRtBuffer> buffer,

View File

@ -124,10 +124,10 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
} }
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBufferFromPyval( StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBufferFromPyval(
const pybind11::object& argument, PjRtDevice* device, bool force_copy, pybind11::handle argument, PjRtDevice* device, bool force_copy,
PjRtClient::HostBufferSemantics host_buffer_semantics); PjRtClient::HostBufferSemantics host_buffer_semantics);
StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval( StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval(
const pybind11::object& argument, PjRtDevice* device, bool force_copy, pybind11::handle argument, PjRtDevice* device, bool force_copy,
PjRtClient::HostBufferSemantics host_buffer_semantics); PjRtClient::HostBufferSemantics host_buffer_semantics);
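The signature change in these two files swaps a `const pybind11::object&` for a `pybind11::handle`. The difference, sketched below under standard pybind11 semantics (a standalone example, not from this diff), is ownership: `handle` is a borrowed view that never touches the Python reference count, while `object` owns a reference.

```cpp
#include <pybind11/pybind11.h>
namespace py = pybind11;

// `py::handle` is a non-owning, borrowed view of a Python object: cheap to
// pass by value, and it performs no incref/decref on its own.
void inspect(py::handle h) { bool is_none = h.is_none(); (void)is_none; }

// `py::object` owns a reference: copies incref, destruction decrefs, so it
// keeps the underlying Python object alive at least for the call's duration.
void retain(py::object o) { py::object copy = o; /* refcount bumped */ }
```

Accepting a `handle` lets the buffer-creation entry points take any Python argument without a superfluous reference-count round trip; the caller is responsible for keeping the object alive across the call.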
StatusOr<std::shared_ptr<PyExecutable>> Compile( StatusOr<std::shared_ptr<PyExecutable>> Compile(

View File

@ -71,6 +71,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
#include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h" #include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
@ -3168,7 +3169,22 @@ Status IrEmitterUnnested::HandleAllToAll(HloInstruction* hlo) {
} }
Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) { Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) {
return ThunkEmitter(this).HandleInfeed(xla_infeed); TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(xla_infeed));
auto infeed_op = mlir::dyn_cast<mlir::lmhlo::InfeedOp>(input.op);
std::vector<InfeedThunk::ShapedSlice> dest_slices;
dest_slices.reserve(infeed_op.outputs().size());
for (mlir::Value output : infeed_op.outputs()) {
TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(output));
const Shape& shape = TypeToShape(output.getType());
dest_slices.push_back(InfeedThunk::ShapedSlice{slice, shape});
}
AddThunkToThunkSequence(
absl::make_unique<InfeedThunk>(input.thunk_info, std::move(dest_slices)));
return Status::OK();
} }
Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) { Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) {

View File

@ -800,8 +800,16 @@ Status AMDGPUTargetModuleLinker(llvm::Module* module, GpuVersion gpu_version,
std::unique_ptr<llvm::TargetMachine> AMDGPUGetTargetMachine( std::unique_ptr<llvm::TargetMachine> AMDGPUGetTargetMachine(
llvm::Triple target_triple, int amdgpu_version, llvm::Triple target_triple, int amdgpu_version,
const HloModuleConfig& hlo_module_config) { const HloModuleConfig& hlo_module_config) {
string feature_str = "+code-object-v3";
#if TF_ROCM_VERSION >= 30900
// code-object-v3 is the default, so there is no need to explicitly specify
// it in the feature string. Also, starting with ROCm 4.0, this feature
// string is deprecated and emits a warning, so it is omitted entirely.
feature_str = "";
#endif
return GetTargetMachine(target_triple, absl::StrCat("gfx", amdgpu_version), return GetTargetMachine(target_triple, absl::StrCat("gfx", amdgpu_version),
hlo_module_config, "+code-object-v3"); hlo_module_config, feature_str);
} }
void AMDGPUBackendInit(const HloModuleConfig& hlo_module_config) { void AMDGPUBackendInit(const HloModuleConfig& hlo_module_config) {

View File

@ -115,34 +115,6 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
/*implements_whole_instruction=*/true); /*implements_whole_instruction=*/true);
} }
std::unique_ptr<Thunk> ThunkEmitter::BuildInfeedThunk(
const HloInstruction* inst) {
CHECK_EQ(HloOpcode::kInfeed, inst->opcode());
std::vector<ShapeUtil::IndexedShape> leaf_shapes =
ShapeUtil::GetLeafShapes(inst->shape());
// For an infeed HLO, the output is a 2 element tuple where the first element
// of the tuple is all the infeed buffers and the second element is a token.
// The infeed thunk does not need to handle this token output, so just drop
// it.
leaf_shapes.pop_back();
std::vector<InfeedThunk::ShapedSlice> dest_slices;
dest_slices.reserve(leaf_shapes.size());
for (ShapeUtil::IndexedShape& indexed_shape : leaf_shapes) {
BufferAllocation::Slice slice =
GetAllocationSlice(*inst, indexed_shape.index);
const Shape& shape =
ShapeUtil::GetSubshape(inst->shape(), indexed_shape.index);
dest_slices.emplace_back(InfeedThunk::ShapedSlice{slice, shape});
}
return absl::make_unique<InfeedThunk>(context_->GetThunkInfo(inst),
std::move(dest_slices));
}
std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk( std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
const HloInstruction* inst) { const HloInstruction* inst) {
CHECK_EQ(HloOpcode::kOutfeed, inst->opcode()); CHECK_EQ(HloOpcode::kOutfeed, inst->opcode());
@ -258,11 +230,6 @@ Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) {
return Status::OK(); return Status::OK();
} }
Status ThunkEmitter::HandleInfeed(HloInstruction* infeed) {
AddThunkToThunkSequence(BuildInfeedThunk(infeed));
return Status::OK();
}
Status ThunkEmitter::HandleOutfeed(HloInstruction* outfeed) { Status ThunkEmitter::HandleOutfeed(HloInstruction* outfeed) {
AddThunkToThunkSequence(BuildOutfeedThunk(outfeed)); AddThunkToThunkSequence(BuildOutfeedThunk(outfeed));
return Status::OK(); return Status::OK();

View File

@ -46,7 +46,6 @@ class ThunkEmitter {
Status HandleCustomCall(HloInstruction* custom_call); Status HandleCustomCall(HloInstruction* custom_call);
Status HandleFft(HloInstruction* fft); Status HandleFft(HloInstruction* fft);
Status HandleTriangularSolve(HloInstruction* hlo); Status HandleTriangularSolve(HloInstruction* hlo);
Status HandleInfeed(HloInstruction* xla_infeed);
Status HandleOutfeed(HloInstruction* outfeed); Status HandleOutfeed(HloInstruction* outfeed);
private: private:

View File

@ -367,7 +367,6 @@ xla_test(
"conv_depthwise_test.cc", "conv_depthwise_test.cc",
], ],
shard_count = 50, shard_count = 50,
tags = ["no_rocm"], # ROCm 3.9 regression
deps = [ deps = [
":conv_depthwise_common", ":conv_depthwise_common",
":test_macros_header", ":test_macros_header",
@ -389,7 +388,6 @@ xla_test(
timeout = "long", timeout = "long",
srcs = ["conv_depthwise_backprop_filter_test.cc"], srcs = ["conv_depthwise_backprop_filter_test.cc"],
shard_count = 40, shard_count = 40,
tags = ["no_rocm"], # ROCm 3.9 regression
deps = [ deps = [
":test_macros_header", ":test_macros_header",
"//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:execution_options_util",
@ -414,7 +412,6 @@ xla_test(
"cpu", "cpu",
], ],
shard_count = 50, shard_count = 50,
tags = ["no_rocm"], # ROCm 3.9 regression
deps = [ deps = [
":client_library_test_base", ":client_library_test_base",
":hlo_test_base", ":hlo_test_base",
@ -924,7 +921,6 @@ xla_test(
srcs = ["dot_operation_test.cc"], srcs = ["dot_operation_test.cc"],
shard_count = 20, shard_count = 20,
tags = [ tags = [
"no_rocm", # ROCm 3.9 regression
"optonly", "optonly",
], ],
deps = [ deps = [
@ -958,7 +954,6 @@ xla_test(
backends = ["gpu"], backends = ["gpu"],
shard_count = 20, shard_count = 20,
tags = [ tags = [
"no_rocm", # ROCm 3.9 regression
"optonly", "optonly",
# TODO(b/151340488): Timed out on 2020-03-12. # TODO(b/151340488): Timed out on 2020-03-12.
"nozapfhahn", "nozapfhahn",
@ -1025,7 +1020,6 @@ xla_test(
}, },
shard_count = 20, shard_count = 20,
tags = [ tags = [
"no_rocm", # ROCm 3.9 regression
"optonly", "optonly",
], ],
deps = [ deps = [
@ -1253,7 +1247,6 @@ xla_test(
"cpu": ["nomsan"], "cpu": ["nomsan"],
}, },
shard_count = 30, shard_count = 30,
tags = ["no_rocm"], # ROCm 3.9 regression
deps = [ deps = [
":test_macros_header", ":test_macros_header",
"//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array3d",
@ -1278,7 +1271,6 @@ xla_test(
timeout = "long", timeout = "long",
srcs = ["convolution_dimension_numbers_test.cc"], srcs = ["convolution_dimension_numbers_test.cc"],
shard_count = 20, shard_count = 20,
tags = ["no_rocm"], # ROCm 3.9 regression
deps = [ deps = [
":test_macros_header", ":test_macros_header",
"//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:array4d",
@ -2322,7 +2314,6 @@ xla_test(
name = "multioutput_fusion_test", name = "multioutput_fusion_test",
srcs = ["multioutput_fusion_test.cc"], srcs = ["multioutput_fusion_test.cc"],
backends = ["gpu"], backends = ["gpu"],
tags = ["no_rocm"], # ROCm 3.9 regression
deps = [ deps = [
":test_macros_header", ":test_macros_header",
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",

View File

@ -480,13 +480,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
{csinfo_.fused_batch_norm_grad_v3, {csinfo_.fused_batch_norm_grad_v3,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3), mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()}); CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
#ifdef ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.fused_batch_norm_ex, rinfo_.push_back({csinfo_.fused_batch_norm_ex,
native_fmt ? csinfo_.mkl_native_fused_batch_norm_ex native_fmt ? csinfo_.mkl_native_fused_batch_norm_ex
: csinfo_.mkl_fused_batch_norm_ex, : csinfo_.mkl_fused_batch_norm_ex,
CopyAttrsAll, FusedBatchNormExRewrite, CopyAttrsAll, FusedBatchNormExRewrite,
GetRewriteCause()}); GetRewriteCause()});
#endif
rinfo_.push_back({csinfo_.fused_conv2d, rinfo_.push_back({csinfo_.fused_conv2d,
native_fmt ? csinfo_.mkl_native_fused_conv2d native_fmt ? csinfo_.mkl_native_fused_conv2d
: csinfo_.mkl_fused_conv2d, : csinfo_.mkl_fused_conv2d,
@ -672,14 +670,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.requantize, rinfo_.push_back({csinfo_.requantize,
mkl_op_registry::GetMklOpName(csinfo_.requantize), mkl_op_registry::GetMklOpName(csinfo_.requantize),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
#ifdef ENABLE_MKLDNN_V1
// Optimized TanhGrad support exists only in DNNL 1.x. // Optimized TanhGrad support exists only in DNNL 1.x.
rinfo_.push_back({csinfo_.tanh, mkl_op_registry::GetMklOpName(csinfo_.tanh), rinfo_.push_back({csinfo_.tanh, mkl_op_registry::GetMklOpName(csinfo_.tanh),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.tanh_grad, rinfo_.push_back({csinfo_.tanh_grad,
mkl_op_registry::GetMklOpName(csinfo_.tanh_grad), mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
#endif // ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.reshape, rinfo_.push_back({csinfo_.reshape,
mkl_op_registry::GetMklOpName(csinfo_.reshape), mkl_op_registry::GetMklOpName(csinfo_.reshape),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});

View File

@ -53,7 +53,6 @@ static void InitGraph(const string& s, Graph* graph,
GraphDef graph_def; GraphDef graph_def;
auto parser = protobuf::TextFormat::Parser(); auto parser = protobuf::TextFormat::Parser();
// parser.AllowRelaxedWhitespace(true);
CHECK(parser.MergeFromString(s, &graph_def)) << s; CHECK(parser.MergeFromString(s, &graph_def)) << s;
GraphConstructorOptions opts; GraphConstructorOptions opts;
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph)); TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
@ -66,7 +65,6 @@ static void InitGraph(const string& s, Graph* graph,
class MklLayoutPassTest : public ::testing::Test { class MklLayoutPassTest : public ::testing::Test {
public: public:
MklLayoutPassTest() : graph_(OpRegistry::Global()) {} MklLayoutPassTest() : graph_(OpRegistry::Global()) {}
// Ashraf added
Node* FindNode(const string& name) { Node* FindNode(const string& name) {
for (Node* node : graph_.nodes()) { for (Node* node : graph_.nodes()) {
if (node->name() == name) return node; if (node->name() == name) return node;
@ -3087,8 +3085,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_LeakyReluGrad_Negative);
REGISTER_TEST_ALL_TYPES(NodeRewrite_LeakyReluLeakyReluGrad_Positive); REGISTER_TEST_ALL_TYPES(NodeRewrite_LeakyReluLeakyReluGrad_Positive);
#undef REGISTER_TEST #undef REGISTER_TEST
#ifdef ENABLE_MKLDNN_V1
#define REGISTER_TEST(NAME, T, INPUT) \ #define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \ TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \ DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
@ -3146,7 +3142,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_TanhGrad_Positive);
} }
REGISTER_TEST_ALL_TYPES(NodeRewrite_TanhTanhGrad_Positive); REGISTER_TEST_ALL_TYPES(NodeRewrite_TanhTanhGrad_Positive);
#undef REGISTER_TEST #undef REGISTER_TEST
#endif // ENABLE_MKLDNN_V1
#define REGISTER_TEST(NAME, T, INPUT) \ #define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \ TEST_F(MklLayoutPassTest, NAME##_##T) { \
@ -3513,7 +3508,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormGradV3_5D_Negative_2);
#undef DATA_FORMAT #undef DATA_FORMAT
#undef REGISTER_TEST #undef REGISTER_TEST
#ifdef ENABLE_MKLDNN_V1
#define REGISTER_TEST(NAME, T, INPUT) \ #define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \ TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \ InitGraph("node { name: 'A' op: '" #INPUT "'}" \
@ -3603,7 +3597,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative1);
} }
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative2); REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative2);
#undef REGISTER_TEST #undef REGISTER_TEST
#endif // ENABLE_MKLDNN_V1
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizedDepthwiseConv2D_Positive) { TEST_F(MklLayoutPassTest, NodeRewrite_QuantizedDepthwiseConv2D_Positive) {
InitGraph( InitGraph(
@ -5184,8 +5177,8 @@ static void BM_MklLayoutRewritePass(int iters, int op_nodes) {
bool first = true; bool first = true;
while (iters > 0) { while (iters > 0) {
Graph* graph = new Graph(OpRegistry::Global()); std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
InitGraph(s, graph); InitGraph(s, graph.get());
int N = graph->num_node_ids(); int N = graph->num_node_ids();
if (first) { if (first) {
testing::SetLabel(strings::StrCat("Per graph node. Nodes: ", N)); testing::SetLabel(strings::StrCat("Per graph node. Nodes: ", N));
@ -5193,13 +5186,12 @@ static void BM_MklLayoutRewritePass(int iters, int op_nodes) {
} }
{ {
testing::StartTiming(); testing::StartTiming();
std::unique_ptr<Graph> ug(graph); std::unique_ptr<Graph> ug(graph.get());
RunMklLayoutRewritePass(&ug); RunMklLayoutRewritePass(&ug);
testing::StopTiming(); testing::StopTiming();
} }
iters -= N; // Our benchmark units are individual graph nodes, iters -= N; // Our benchmark units are individual graph nodes,
// not whole graphs // not whole graphs
// delete graph;
} }
} }
BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000); BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);

View File

@ -37,6 +37,7 @@ load(
package( package(
default_visibility = [ default_visibility = [
"//tensorflow/core:__subpackages__", "//tensorflow/core:__subpackages__",
"//tensorflow/security/fuzzing:__subpackages__",
], ],
licenses = ["notice"], # Apache 2.0 licenses = ["notice"], # Apache 2.0
) )
@ -622,7 +623,10 @@ cc_library(
name = "bfloat16", name = "bfloat16",
srcs = ["bfloat16.cc"], srcs = ["bfloat16.cc"],
hdrs = ["bfloat16.h"], hdrs = ["bfloat16.h"],
visibility = ["//tensorflow/core:__subpackages__"], visibility = [
"//tensorflow/core:__subpackages__",
"//tensorflow/security/fuzzing:__subpackages__",
],
deps = [ deps = [
":numeric_types", ":numeric_types",
"//tensorflow/core/platform:byte_order", "//tensorflow/core/platform:byte_order",

View File

@ -269,6 +269,9 @@ Status MetaOptimizer::InitializeOptimizers(
if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) { if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
optimizers->push_back(MakeUnique<PinToHostOptimizer>()); optimizers->push_back(MakeUnique<PinToHostOptimizer>());
} }
if (cfg_.remapping() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
}
if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) { if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
optimizers->push_back( optimizers->push_back(
MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization())); MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
@ -278,9 +281,6 @@ Status MetaOptimizer::InitializeOptimizers(
/*optimization level*/ cfg_.layout_optimizer(), /*optimization level*/ cfg_.layout_optimizer(),
/*CPU layout conversion*/ cfg_.cpu_layout_conversion())); /*CPU layout conversion*/ cfg_.cpu_layout_conversion()));
} }
if (cfg_.remapping() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
}
if (cfg_.loop_optimization() != RewriterConfig::OFF) { if (cfg_.loop_optimization() != RewriterConfig::OFF) {
optimizers->push_back( optimizers->push_back(
MakeUnique<LoopOptimizer>(cfg_.loop_optimization(), cpu_device_)); MakeUnique<LoopOptimizer>(cfg_.loop_optimization(), cpu_device_));

View File

@ -29,8 +29,14 @@ REGISTER5(BinaryOp, CPU, "DivNoNan", functor::div_no_nan, Eigen::half, float,
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// ROCM TODO: re-enable complex64 / complex128 after compiler fix // ROCM TODO: re-enable complex64 / complex128 after compiler fix
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8, REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
uint16, int16, int64, complex64, complex128); uint16, int16, int64, complex64, complex128);
#else
REGISTER4(BinaryOp, GPU, "Div", functor::div, uint8, uint16, complex64,
complex128);
#endif
REGISTER4(BinaryOp, GPU, "TruncateDiv", functor::div, uint8, uint16, int16, REGISTER4(BinaryOp, GPU, "TruncateDiv", functor::div, uint8, uint16, int16,
int64); int64);
REGISTER5(BinaryOp, GPU, "RealDiv", functor::div, float, Eigen::half, double, REGISTER5(BinaryOp, GPU, "RealDiv", functor::div, float, Eigen::half, double,

View File

@ -30,8 +30,13 @@ REGISTER(BinaryOp, CPU, "Sub", functor::sub, int32);
#endif // __ANDROID_TYPES_SLIM__ #endif // __ANDROID_TYPES_SLIM__
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
REGISTER7(BinaryOp, GPU, "Sub", functor::sub, float, Eigen::half, double, int64, REGISTER7(BinaryOp, GPU, "Sub", functor::sub, float, Eigen::half, double, int64,
complex64, complex128, uint32); complex64, complex128, uint32);
#else
REGISTER3(BinaryOp, GPU, "Sub", functor::sub, complex64, complex128, uint32);
#endif
// A special GPU kernel for int32. // A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel // TODO(b/25387198): Also enable int32 in device memory. This kernel

View File

@ -85,7 +85,7 @@ void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
// clang-format off // clang-format off
absl::flat_hash_map<string, uint64> live_experiments = { absl::flat_hash_map<string, uint64> live_experiments = {
{"enable_gradient_descent", 0}, {"enable_gradient_descent", 0},
{"map_parallelization", 0} {"map_parallelization", 1}
}; };
// clang-format on // clang-format on
auto hash_func = [](const string& str) { return Hash64(str); }; auto hash_func = [](const string& str) { return Hash64(str); };

View File

@@ -43,7 +43,6 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/util/mkl_types.h"
 #include "tensorflow/core/util/mkl_util.h"
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"
@@ -65,7 +64,7 @@ struct MklConvFwdParams {
   memory::dims dilations;
   memory::dims padding_left;
   memory::dims padding_right;
-  MKL_TENSOR_FORMAT tf_fmt;
+  MklTensorFormat tf_fmt;
   bool native_format;
   string dtypes = string("");
   struct PostOpParam {
@@ -80,7 +79,7 @@ struct MklConvFwdParams {
                    memory::dims bias_dims, memory::dims dst_dims,
                    memory::dims strides, memory::dims dilations,
                    memory::dims padding_left, memory::dims padding_right,
-                   MKL_TENSOR_FORMAT tf_fmt, bool native_format)
+                   MklTensorFormat tf_fmt, bool native_format)
       : src_dims(src_dims),
         filter_dims(filter_dims),
         bias_dims(bias_dims),
@@ -99,7 +98,7 @@ template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput>
 class MklConvFwdPrimitive : public MklPrimitive {
  public:
   explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
-      : MklPrimitive(engine(ENGINE_CPU, 0)) {
+      : MklPrimitive(engine(engine::kind::cpu, 0)) {
     // Create convolution primitive
     if (context_.conv_fwd == nullptr) {
       Setup(convFwdDims);
@@ -115,8 +114,8 @@ class MklConvFwdPrimitive : public MklPrimitive {
   void Execute(const Tinput* src_data, const Tfilter* filter_data,
                const Tbias* bias_data, const Toutput* dst_data,
                std::shared_ptr<stream> fwd_stream) {
-    // TODO: Create a common function and avoid the duplicate code
 #ifdef ENABLE_MKLDNN_THREADPOOL
+    // TODO: Create a common function and avoid the duplicate code
     context_.src_mem->set_data_handle(
         static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
     context_.filter_mem->set_data_handle(
@@ -139,16 +138,13 @@ class MklConvFwdPrimitive : public MklPrimitive {
     context_.dst_mem->set_data_handle(
         static_cast<void*>(const_cast<Toutput*>(dst_data)));
 #endif  // ENABLE_MKLDNN_THREADPOOL
-#ifdef ENABLE_MKLDNN_V1
     DCHECK_EQ(context_.fwd_primitives.size(),
               context_.fwd_primitives_args.size());
     for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
       context_.fwd_primitives.at(i).execute(*fwd_stream,
                                             context_.fwd_primitives_args.at(i));
     }
-#else
-    fwd_stream->submit(context_.fwd_primitives);
-#endif  // ENABLE_MKLDNN_V1
     // After execution, set data handle back
     context_.src_mem->set_data_handle(DummyData);
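Editor's note: the changes in this file consistently move to the oneDNN v1.x execution model, in which a primitive is paired with an explicit argument map and run on a stream rather than submitted with its memory objects baked in. A hedged, self-contained sketch of that pattern (illustrative names, same mkldnn compatibility API this file already uses):

// Sketch of the v1.x "primitive + argument map" loop mirrored by
// fwd_primitives / fwd_primitives_args above.
#include <unordered_map>
#include <vector>
#include "mkldnn.hpp"

void RunPrimitives(
    const std::vector<mkldnn::primitive>& prims,
    const std::vector<std::unordered_map<int, mkldnn::memory>>& args,
    mkldnn::stream& strm) {
  for (size_t i = 0; i < prims.size(); ++i) {
    prims[i].execute(strm, args[i]);  // one argument map per primitive
  }
  strm.wait();  // block until all queued primitives have finished
}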
@@ -168,13 +164,6 @@ class MklConvFwdPrimitive : public MklPrimitive {
     Execute(src_data, filter_data, nullptr, dst_data, fwd_stream);
   }
-#ifndef ENABLE_MKLDNN_V1
-  // In MKL-DNN v1.x, memory format tags only provide a partial description
-  // of the memory layout. Hence, these functions are disabled for v1.x.
-  memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
-  memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; }
-#endif  // !ENABLE_MKLDNN_V1
   std::shared_ptr<ConvFwdPd> GetPrimitiveDesc() const {
     return context_.fwd_pd;
   }
@@ -182,12 +171,6 @@ class MklConvFwdPrimitive : public MklPrimitive {
  private:
   // Primitive reuse context for Conv2D Fwd op
   struct ConvFwdContext {
-#ifndef ENABLE_MKLDNN_V1
-    // Expected memory format for this primitive instance
-    memory::format src_fmt;
-    memory::format filter_fmt;
-#endif  // !ENABLE_MKLDNN_V1
     // MKL-DNN memory
     std::shared_ptr<mkldnn::memory> src_mem;
     std::shared_ptr<mkldnn::memory> filter_mem;
@@ -208,18 +191,10 @@ class MklConvFwdPrimitive : public MklPrimitive {
     std::shared_ptr<mkldnn::primitive> conv_fwd;
     std::vector<mkldnn::primitive> fwd_primitives;
-#ifdef ENABLE_MKLDNN_V1
     std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
-#endif  // ENABLE_MKLDNN_V1
     ConvFwdContext()
-        :
-#ifndef ENABLE_MKLDNN_V1
-          src_fmt(memory::format::any),
-          filter_fmt(memory::format::any),
-#endif  // !ENABLE_MKLDNN_V1
-          src_mem(nullptr),
+        : src_mem(nullptr),
           filter_mem(nullptr),
           bias_mem(nullptr),
           dst_mem(nullptr),
@ -228,52 +203,45 @@ class MklConvFwdPrimitive : public MklPrimitive {
filter_md(nullptr), filter_md(nullptr),
bias_md(nullptr), bias_md(nullptr),
fwd_pd(nullptr), fwd_pd(nullptr),
conv_fwd(nullptr) { conv_fwd(nullptr) {}
}
}; };
void Setup(const MklConvFwdParams& convFwdDims) { void Setup(const MklConvFwdParams& convFwdDims) {
MEMORY_FORMAT user_data_fmt; memory::format_tag user_data_fmt;
if (convFwdDims.native_format) { if (convFwdDims.native_format) {
user_data_fmt = MklTensorFormatToMklDnnDataFormat(convFwdDims.tf_fmt); user_data_fmt = MklTensorFormatToMklDnnDataFormat(convFwdDims.tf_fmt);
} else { } else {
// Create memory descriptors for convolution data w/ no specified format // Create memory descriptors for convolution data w/ no specified format
user_data_fmt = MEMORY_FORMAT::any; user_data_fmt = memory::format_tag::any;
} }
context_.src_md.reset(new memory::desc( context_.src_md.reset(new memory::desc(
{convFwdDims.src_dims}, MklDnnType<Tinput>(), user_data_fmt)); {convFwdDims.src_dims}, MklDnnType<Tinput>(), user_data_fmt));
context_.filter_md.reset(new memory::desc( context_.filter_md.reset(new memory::desc({convFwdDims.filter_dims},
{convFwdDims.filter_dims}, MklDnnType<Tfilter>(), MEMORY_FORMAT::any)); MklDnnType<Tfilter>(),
memory::format_tag::any));
context_.dst_md.reset(new memory::desc( context_.dst_md.reset(new memory::desc(
{convFwdDims.dst_dims}, MklDnnType<Toutput>(), user_data_fmt)); {convFwdDims.dst_dims}, MklDnnType<Toutput>(), user_data_fmt));
if (!convFwdDims.bias_dims.empty()) if (!convFwdDims.bias_dims.empty())
context_.bias_md.reset(new memory::desc( context_.bias_md.reset(new memory::desc({convFwdDims.bias_dims},
{convFwdDims.bias_dims}, MklDnnType<Tbias>(), MEMORY_FORMAT::any)); MklDnnType<Tbias>(),
memory::format_tag::any));
// Create a convolution descriptor // Create a convolution descriptor
if (!convFwdDims.bias_dims.empty()) { if (!convFwdDims.bias_dims.empty()) {
context_.fwd_desc.reset(new convolution_forward::desc( context_.fwd_desc.reset(new convolution_forward::desc(
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md, prop_kind::forward, mkldnn::algorithm::convolution_direct,
*context_.filter_md, *context_.bias_md, *context_.dst_md, *context_.src_md, *context_.filter_md, *context_.bias_md,
convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left, *context_.dst_md, convFwdDims.strides, convFwdDims.dilations,
#ifndef ENABLE_MKLDNN_V1 convFwdDims.padding_left, convFwdDims.padding_right));
convFwdDims.padding_right, padding_kind::zero));
#else
convFwdDims.padding_right));
#endif // !ENABLE_MKLDNN_V1
} else { } else {
context_.fwd_desc.reset(new convolution_forward::desc( context_.fwd_desc.reset(new convolution_forward::desc(
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md, prop_kind::forward, mkldnn::algorithm::convolution_direct,
*context_.filter_md, *context_.dst_md, convFwdDims.strides, *context_.src_md, *context_.filter_md, *context_.dst_md,
convFwdDims.dilations, convFwdDims.padding_left, convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
#ifndef ENABLE_MKLDNN_V1
convFwdDims.padding_right, padding_kind::zero));
#else
convFwdDims.padding_right)); convFwdDims.padding_right));
#endif // !ENABLE_MKLDNN_V1
} }
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_)); context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
@ -314,54 +282,32 @@ class MklConvFwdPrimitive : public MklPrimitive {
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_)); context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
} }
#ifndef ENABLE_MKLDNN_V1
// Store the expected memory format
context_.src_fmt = static_cast<mkldnn::memory::format>(
context_.fwd_pd.get()->src_primitive_desc().desc().data.format);
context_.filter_fmt = static_cast<mkldnn::memory::format>(
context_.fwd_pd.get()->weights_primitive_desc().desc().data.format);
#endif // !ENABLE_MKLDNN_V1
// Create memory primitive based on dummy data // Create memory primitive based on dummy data
context_.src_mem.reset(new MEMORY_CONSTRUCTOR( context_.src_mem.reset(
context_.fwd_pd.get()->PRIMITIVE_DESC_SRC, cpu_engine_, DummyData)); new memory(context_.fwd_pd.get()->src_desc(), cpu_engine_, DummyData));
context_.filter_mem.reset(new MEMORY_CONSTRUCTOR( context_.filter_mem.reset(new memory(context_.fwd_pd.get()->weights_desc(),
context_.fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_, DummyData)); cpu_engine_, DummyData));
context_.dst_mem.reset(new MEMORY_CONSTRUCTOR( context_.dst_mem.reset(
context_.fwd_pd.get()->PRIMITIVE_DESC_DST, cpu_engine_, DummyData)); new memory(context_.fwd_pd.get()->dst_desc(), cpu_engine_, DummyData));
// Create convolution primitive and add it to net // Create convolution primitive and add it to net
if (!convFwdDims.bias_dims.empty()) { if (!convFwdDims.bias_dims.empty()) {
context_.bias_mem.reset(new MEMORY_CONSTRUCTOR_USING_MEM_PD( context_.bias_mem.reset(new memory(
convFwdDims.bias_dims, Tbias, MEMORY_FORMAT::x, cpu_engine_, {{convFwdDims.bias_dims}, MklDnnType<Tbias>(), memory::format_tag::x},
DummyData)); cpu_engine_, DummyData));
#ifdef ENABLE_MKLDNN_V1
context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd)); context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
context_.fwd_primitives_args.push_back( context_.fwd_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem}, {{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_WEIGHTS, *context_.filter_mem}, {MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
{MKLDNN_ARG_BIAS, *context_.bias_mem}, {MKLDNN_ARG_BIAS, *context_.bias_mem},
{ MKLDNN_ARG_DST, {MKLDNN_ARG_DST, *context_.dst_mem}});
*context_.dst_mem }});
} else { } else {
context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd)); context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
context_.fwd_primitives_args.push_back( context_.fwd_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem}, {{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_WEIGHTS, *context_.filter_mem}, {MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
{ MKLDNN_ARG_DST, {MKLDNN_ARG_DST, *context_.dst_mem}});
*context_.dst_mem }});
} }
#else
context_.conv_fwd.reset(new convolution_forward(
*context_.fwd_pd, *context_.src_mem, *context_.filter_mem,
*context_.bias_mem, *context_.dst_mem));
} else {
context_.conv_fwd.reset(
new convolution_forward(*context_.fwd_pd, *context_.src_mem,
*context_.filter_mem, *context_.dst_mem));
}
#endif // ENABLE_MKLDNN_V1
context_.fwd_primitives.push_back(*context_.conv_fwd); context_.fwd_primitives.push_back(*context_.conv_fwd);
} }
@@ -650,12 +596,10 @@ class MklConvOp : public OpKernel {
     auto tf_fmt = is_conv2d ? TFDataFormatToMklDnnDataFormat(data_format_)
                             : TFDataFormatToMklDnn3DDataFormat(data_format_);
-#ifdef ENABLE_MKLDNN_V1
     auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt);
     // NOTE: `mkl_fmt_tag` will be `format_tag::undef` for ReLU
     OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef,
                 errors::InvalidArgument("Invalid data format"));
-#endif  // ENABLE_MKLDNN_V1
     // If input is in MKL layout, then simply grab the layout; otherwise,
     // construct TF layout for input.
@@ -667,19 +611,15 @@ class MklConvOp : public OpKernel {
     auto src_md =
         src_mkl_shape.IsMklTensor()
             ? src_mkl_shape.GetMklLayout()
-#ifdef ENABLE_MKLDNN_V1
             : memory::desc(src_dims, MklDnnType<Tinput>(), mkl_fmt_tag);
-#else
-            : memory::desc(src_dims, MklDnnType<Tinput>(), tf_fmt);
-#endif  // ENABLE_MKLDNN_V1
     src.SetUsrMem(src_md, &src_tensor);
     // Although filter shape (filter_dims) required is in MKL-DNN order,
     // the layout is Tensorflow's layout (HWIO) and (HWIGO) for
     // depthwise/group convolutions.
-    auto filter_format = is_conv2d ? (is_depthwise ? MEMORY_FORMAT::hwigo
-                                                   : MEMORY_FORMAT::hwio)
-                                   : MEMORY_FORMAT::dhwio;
+    auto filter_format = is_conv2d ? (is_depthwise ? memory::format_tag::hwigo
+                                                   : memory::format_tag::hwio)
+                                   : memory::format_tag::dhwio;
     DCHECK(!filter_mkl_shape.IsMklTensor());
     auto filter_md =
@ -738,12 +678,9 @@ class MklConvOp : public OpKernel {
// Check whether src and filter need to be reordered. // Check whether src and filter need to be reordered.
Tinput* src_data = nullptr; Tinput* src_data = nullptr;
if (IS_SRC_REORDER_NEEDED(src_md, conv_fwd_pd, conv_fwd)) { if (src_md != conv_fwd_pd->src_desc()) {
src.SetUsrMem(src_md, &src_tensor); src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem( src.CheckReorderToOpMem(conv_fwd_pd->src_desc(), cpu_engine_, context);
MEMORY_PD_WITHOUT_DATA(GET_SRC_DESC_FROM_OP_PD(conv_fwd_pd),
cpu_engine_),
context);
src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle()); src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
} else { } else {
src_data = static_cast<Tinput*>( src_data = static_cast<Tinput*>(
@ -751,7 +688,7 @@ class MklConvOp : public OpKernel {
} }
Tfilter* filter_data = nullptr; Tfilter* filter_data = nullptr;
if (IS_FILTER_REORDER_NEEDED(filter_md, conv_fwd_pd, conv_fwd)) { if (filter_md != conv_fwd_pd->weights_desc()) {
bool is_filter_cached = false; bool is_filter_cached = false;
// If filter is a constant, we can avoid the conversion of filter from // If filter is a constant, we can avoid the conversion of filter from
// Tensorflow format to MKL format by caching the filter when it is // Tensorflow format to MKL format by caching the filter when it is
@ -761,28 +698,20 @@ class MklConvOp : public OpKernel {
if (IsFilterCacheEmpty(context)) { if (IsFilterCacheEmpty(context)) {
// Cache filter if it is not already cached. // Cache filter if it is not already cached.
CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor, CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor,
#ifdef ENABLE_MKLDNN_V1
filter, filter_md, filter_mkl_shape); filter, filter_md, filter_mkl_shape);
#else
filter, filter_md);
#endif // ENABLE_MKLDNN_V1
} }
filter_data = GetCachedFilter( filter_data = GetCachedFilter(context, conv_fwd_pd->weights_desc());
context, GET_WEIGHTS_FORMAT_FROM_OP_PD(conv_fwd_pd, conv_fwd));
is_filter_cached = (filter_data != nullptr); is_filter_cached = (filter_data != nullptr);
} }
if (!is_filter_cached) { if (!is_filter_cached) {
filter.SetUsrMem(filter_md, &filter_tensor); filter.SetUsrMem(filter_md, &filter_tensor);
if (filter_out_tensor == nullptr) { if (filter_out_tensor == nullptr) {
filter.CheckReorderToOpMem( filter.CheckReorderToOpMem(conv_fwd_pd->weights_desc(), cpu_engine_,
MEMORY_PD_WITHOUT_DATA(GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd), context);
cpu_engine_),
context);
} else { } else {
filter.CheckReorderToOpMem( filter.CheckReorderToOpMem(
GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd), conv_fwd_pd->weights_desc(),
DATA_WITH_ENGINE(filter.GetTensorBuffer(filter_out_tensor), filter.GetTensorBuffer(filter_out_tensor), cpu_engine_,
cpu_engine_),
context); context);
} }
filter_data = filter_data =
@@ -897,7 +826,8 @@ class MklConvOp : public OpKernel {
     // NOTE: Fusion of BiasAdd is handled directly inside MklConvOp by
     // checking `fuse_biasadd_` flag.
     if (fuse_add_) {
-      params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}, ""});
+      params.post_op_params.push_back(
+          {"sum", mkldnn::algorithm::undef, {1.0}, ""});
     }
     if (fuse_activation_) {
       params.post_op_params.push_back(
@ -918,35 +848,27 @@ class MklConvOp : public OpKernel {
virtual void AllocateOutputTensor(OpKernelContext* context, virtual void AllocateOutputTensor(OpKernelContext* context,
const ConvFwdPd& conv_prim_desc, const ConvFwdPd& conv_prim_desc,
const memory::dims& output_dims_mkl_order, const memory::dims& output_dims_mkl_order,
MKL_TENSOR_FORMAT output_tf_format, MklTensorFormat output_tf_format,
MklDnnShape* output_mkl_shape, MklDnnShape* output_mkl_shape,
Tensor** output_tensor) { Tensor** output_tensor) {
DCHECK(output_tensor); DCHECK(output_tensor);
#ifdef ENABLE_MKLDNN_V1
auto dst_md = conv_prim_desc.dst_desc(); auto dst_md = conv_prim_desc.dst_desc();
#else
auto dst_pd = conv_prim_desc.dst_primitive_desc();
auto dst_md = dst_pd.desc();
#endif // ENABLE_MKLDNN_V1
if (!std::is_same<Ttemp_output, Toutput>::value) { if (!std::is_same<Ttemp_output, Toutput>::value) {
dst_md.data.data_type = dst_md.data.data_type =
static_cast<mkldnn_data_type_t>(MklDnnType<Toutput>()); static_cast<mkldnn_data_type_t>(MklDnnType<Toutput>());
#ifndef ENABLE_MKLDNN_V1
dst_pd = memory::primitive_desc(dst_md, cpu_engine_);
#endif // !ENABLE_MKLDNN_V1
} }
// Allocate shape of MKL tensor // Allocate shape of MKL tensor
output_mkl_shape->SetMklTensor(true); output_mkl_shape->SetMklTensor(true);
output_mkl_shape->SetMklLayout(&DST_MD); output_mkl_shape->SetMklLayout(&dst_md);
output_mkl_shape->SetElemType(MklDnnType<Toutput>()); output_mkl_shape->SetElemType(MklDnnType<Toutput>());
output_mkl_shape->SetTfLayout(output_dims_mkl_order.size(), output_mkl_shape->SetTfLayout(output_dims_mkl_order.size(),
output_dims_mkl_order, output_tf_format); output_dims_mkl_order, output_tf_format);
// Allocate shape of TF tensor // Allocate shape of TF tensor
TensorShape output_tf_shape; TensorShape output_tf_shape;
output_tf_shape.AddDim((DST_MD.get_size() / sizeof(Toutput))); output_tf_shape.AddDim((dst_md.get_size() / sizeof(Toutput)));
if (native_format) { if (native_format) {
output_tf_shape = output_mkl_shape->GetTfShape(); output_tf_shape = output_mkl_shape->GetTfShape();
} }
@ -972,23 +894,16 @@ class MklConvOp : public OpKernel {
AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor, AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
output_tf_shape, *output_mkl_shape, output_tf_shape, *output_mkl_shape,
native_format); native_format);
#ifdef ENABLE_MKLDNN_V1
auto output_format_tag = MklTensorFormatToMklDnnDataFormat( auto output_format_tag = MklTensorFormatToMklDnnDataFormat(
output_mkl_shape->GetTfDataFormat()); output_mkl_shape->GetTfDataFormat());
OP_REQUIRES(context, output_format_tag != memory::format_tag::undef, OP_REQUIRES(context, output_format_tag != memory::format_tag::undef,
errors::InvalidArgument( errors::InvalidArgument(
"MklConvOp: AddN fusion: Invalid data format")); "MklConvOp: AddN fusion: Invalid data format"));
#endif // ENABLE_MKLDNN_V1
auto add_md = auto add_md =
add_mkl_shape.IsMklTensor() add_mkl_shape.IsMklTensor()
? add_mkl_shape.GetMklLayout() ? add_mkl_shape.GetMklLayout()
: memory::desc(output_dims_mkl_order, MklDnnType<Toutput>(), : memory::desc(output_dims_mkl_order, MklDnnType<Toutput>(),
#ifdef ENABLE_MKLDNN_V1
output_format_tag); output_format_tag);
#else
output_mkl_shape->GetTfDataFormat());
auto add_pd = memory::primitive_desc(add_md, this->cpu_engine_);
#endif // ENABLE_MKLDNN_V1
void* add_buf = static_cast<void*>( void* add_buf = static_cast<void*>(
const_cast<Toutput*>(add_tensor.flat<Toutput>().data())); const_cast<Toutput*>(add_tensor.flat<Toutput>().data()));
void* dst_buf = void* dst_buf =
@ -996,16 +911,14 @@ class MklConvOp : public OpKernel {
if (native_format) { if (native_format) {
// We are simply deep copying the add_tensor to output_tensor without // We are simply deep copying the add_tensor to output_tensor without
// changing memory layout, hence using same memory descriptor. // changing memory layout, hence using same memory descriptor.
ADD_MD = DST_MD = add_md = dst_md =
memory::desc({add_tensor.NumElements()}, MklDnnType<Toutput>(), memory::desc({add_tensor.NumElements()}, MklDnnType<Toutput>(),
mkldnn::memory::format_tag::x); mkldnn::memory::format_tag::x);
} }
fuse_add_src_.reset( fuse_add_src_.reset(new memory(add_md, this->cpu_engine_, add_buf));
new MEMORY_CONSTRUCTOR(ADD_MD, this->cpu_engine_, add_buf)); fuse_add_dst_.reset(new memory(dst_md, this->cpu_engine_, dst_buf));
fuse_add_dst_.reset(
new MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf));
auto reorder_desc = auto reorder_desc =
REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_); ReorderPd(this->cpu_engine_, add_md, this->cpu_engine_, dst_md);
CreateAndExecuteReorder(reorder_desc, *fuse_add_src_, *fuse_add_dst_, CreateAndExecuteReorder(reorder_desc, *fuse_add_src_, *fuse_add_dst_,
this->cpu_engine_, context); this->cpu_engine_, context);
@@ -1017,7 +930,7 @@ class MklConvOp : public OpKernel {
     }
   }
-  engine cpu_engine_ = engine(ENGINE_CPU, 0);
+  engine cpu_engine_ = engine(engine::kind::cpu, 0);
  private:
   std::shared_ptr<mkldnn::memory> fuse_add_src_;
@@ -1041,7 +954,7 @@ class MklConvOp : public OpKernel {
   // This variable is used for alpha in leakyrelu or upper bound in relu6
   // depending on the context
   float alpha_or_upbound_ = 0.0;
-  mkldnn::algorithm activation_alg_ = ALGORITHM_UNDEF;
+  mkldnn::algorithm activation_alg_ = mkldnn::algorithm::undef;
   int input_index_pad_ = 2;
@ -1050,15 +963,10 @@ class MklConvOp : public OpKernel {
const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1; const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1;
const int kDilationH = 0, kDilationW = 1; const int kDilationH = 0, kDilationW = 1;
MKL_TENSOR_FORMAT_IN_C GetFilterTfDataFormat( MklTensorFormat GetFilterTfDataFormat(const MklDnnShape* filter_mkl_shape,
const MklDnnShape* filter_mkl_shape, const ConvFwdPd& conv_prim_desc) const {
const ConvFwdPd& conv_prim_desc) const {
#ifdef ENABLE_MKLDNN_V1
DCHECK(filter_mkl_shape); DCHECK(filter_mkl_shape);
return filter_mkl_shape->GetTfDataFormat(); return filter_mkl_shape->GetTfDataFormat();
#else
return conv_prim_desc.weights_primitive_desc().desc().data.format;
#endif // ENABLE_MKLDNN_V1
} }
// Allocate persistent tensors for cached filter data and // Allocate persistent tensors for cached filter data and
@ -1070,23 +978,13 @@ class MklConvOp : public OpKernel {
DCHECK(filter_tensor); DCHECK(filter_tensor);
TensorShape filter_tf_shape; TensorShape filter_tf_shape;
filter_tf_shape.AddDim( filter_tf_shape.AddDim(
(conv_prim_desc.PRIMITIVE_DESC_WEIGHTS.get_size() / sizeof(Tfilter))); (conv_prim_desc.weights_desc().get_size() / sizeof(Tfilter)));
OP_REQUIRES_OK(context, context->allocate_persistent( OP_REQUIRES_OK(context, context->allocate_persistent(
DataTypeToEnum<Tfilter>::value, filter_tf_shape, DataTypeToEnum<Tfilter>::value, filter_tf_shape,
&cached_filter_data_ptensor_, filter_tensor)); &cached_filter_data_ptensor_, filter_tensor));
Tensor* second_tensor = nullptr; Tensor* second_tensor = nullptr;
#ifndef ENABLE_MKLDNN_V1
TensorShape filter_mkl_format;
filter_mkl_format.AddDim(
sizeof(GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc)) /
sizeof(DT_INT32));
OP_REQUIRES_OK(context, context->allocate_persistent(
DT_INT32, filter_mkl_format,
&cached_filter_md_ptensor_, &second_tensor));
second_tensor->scalar<int32>()() = static_cast<int32>(
GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc));
#else
// There is no tensor format in DNNL 1.x. So we cache the complete filter // There is no tensor format in DNNL 1.x. So we cache the complete filter
// descriptor as flat byte array. // descriptor as flat byte array.
TensorShape cached_filter_md_shape; TensorShape cached_filter_md_shape;
@ -1100,7 +998,6 @@ class MklConvOp : public OpKernel {
&cached_filter_md_ptensor_, &second_tensor)); &cached_filter_md_ptensor_, &second_tensor));
*reinterpret_cast<memory::desc*>(second_tensor->flat<uint8>().data()) = *reinterpret_cast<memory::desc*>(second_tensor->flat<uint8>().data()) =
weights_desc; weights_desc;
#endif // !ENABLE_MKLDNN_V1
} }
void AllocatePersistentTensor(OpKernelContext* context, void AllocatePersistentTensor(OpKernelContext* context,
@@ -1114,7 +1011,7 @@ class MklConvOp : public OpKernel {
                                 const memory::dims& filter_dims_tf_order,
                                 Tensor** filter_tensor) {
     DCHECK(filter_tensor);
-    auto filter_md = conv_prim_desc.PRIMITIVE_DESC_WEIGHTS;
+    auto filter_md = conv_prim_desc.weights_desc();
     // Allocate shape of MKL tensor
     MklDnnShape filter_mkl_shape;
@@ -1127,7 +1024,7 @@ class MklConvOp : public OpKernel {
     // is stored in the MKL data.
     filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(),
                                  filter_dims_tf_order,
-                                 MKL_TENSOR_FORMAT_BLOCKED);
+                                 MklTensorFormat::FORMAT_BLOCKED);
     // Allocate the data space for the filter to propagate as TF tensor.
     TensorShape filter_tf_shape;
@ -1150,17 +1047,15 @@ class MklConvOp : public OpKernel {
// Create reorders between user layout and MKL layout if it is needed and // Create reorders between user layout and MKL layout if it is needed and
// add it to the net before convolution. No need to check for output // add it to the net before convolution. No need to check for output
// reorder as we propagate output layout to the next layer. // reorder as we propagate output layout to the next layer.
src->CheckReorderToOpMem( src->CheckReorderToOpMem(conv_prim_desc.src_desc(), cpu_engine_);
MEMORY_PD_WITHOUT_DATA(conv_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_));
// Rather than re-ordering to a temp buffer, reorder directly to the // Rather than re-ordering to a temp buffer, reorder directly to the
// filter output tensor // filter output tensor
filter->CheckReorderToOpMem(conv_prim_desc.PRIMITIVE_DESC_WEIGHTS, filter->CheckReorderToOpMem(conv_prim_desc.weights_desc(),
filter->GetTensorBuffer(filter_out_tensor)); filter->GetTensorBuffer(filter_out_tensor));
// Create convolution primitive and add it to net. // Create convolution primitive and add it to net.
std::vector<primitive> net; std::vector<primitive> net;
#ifdef ENABLE_MKLDNN_V1
std::vector<std::unordered_map<int, memory>> net_args; std::vector<std::unordered_map<int, memory>> net_args;
if (bias) { if (bias) {
DCHECK(fuse_biasadd_); DCHECK(fuse_biasadd_);
@ -1168,31 +1063,15 @@ class MklConvOp : public OpKernel {
net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()}, net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()},
{MKLDNN_ARG_WEIGHTS, filter->GetOpMem()}, {MKLDNN_ARG_WEIGHTS, filter->GetOpMem()},
{MKLDNN_ARG_BIAS, bias->GetOpMem()}, {MKLDNN_ARG_BIAS, bias->GetOpMem()},
{ MKLDNN_ARG_DST, {MKLDNN_ARG_DST, output->GetOpMem()}});
output->GetOpMem() }});
} else { } else {
DCHECK(!fuse_biasadd_); DCHECK(!fuse_biasadd_);
net.push_back(convolution_forward(conv_prim_desc)); net.push_back(convolution_forward(conv_prim_desc));
net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()}, net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()},
{MKLDNN_ARG_WEIGHTS, filter->GetOpMem()}, {MKLDNN_ARG_WEIGHTS, filter->GetOpMem()},
{ MKLDNN_ARG_DST, {MKLDNN_ARG_DST, output->GetOpMem()}});
output->GetOpMem() }});
} }
ExecutePrimitive(net, &net_args, cpu_engine_); ExecutePrimitive(net, &net_args, cpu_engine_);
#else
if (bias) {
DCHECK(fuse_biasadd_);
net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
filter->GetOpMem(), bias->GetOpMem(),
output->GetOpMem()));
} else {
DCHECK(!fuse_biasadd_);
net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
filter->GetOpMem(),
output->GetOpMem()));
}
ExecutePrimitive(net, nullptr, cpu_engine_);
#endif // ENABLE_MKLDNN_V1
} }
// TF_LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot // TF_LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot
@ -1206,9 +1085,8 @@ class MklConvOp : public OpKernel {
return (cached_filter_data_tensor.NumElements() == 0); return (cached_filter_data_tensor.NumElements() == 0);
} }
// Cache the converted filter in a persistent tensor. // Cache the converted filter in a persistent tensor.
// Only one thread can execute this method at any given time. // Only one thread can execute this method at any given time.
#ifdef ENABLE_MKLDNN_V1
void CacheFilter(OpKernelContext* context, void CacheFilter(OpKernelContext* context,
const std::shared_ptr<ConvFwdPd>& conv_fwd_pd, const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
Tfilter* filter_data, const Tensor& filter_tensor, Tfilter* filter_data, const Tensor& filter_tensor,
@ -1254,37 +1132,8 @@ class MklConvOp : public OpKernel {
return true; return true;
} }
#else
void CacheFilter(OpKernelContext* context,
const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
Tfilter* filter_data, const Tensor& filter_tensor,
MklDnnData<Tfilter>& filter, const memory::desc& filter_md)
TF_LOCKS_EXCLUDED(mu_) {
mutex_lock lock(mu_);
const Tensor& cached_filter_data_tensor =
*cached_filter_data_ptensor_.AccessTensor(context);
// If filter is already cached, there's nothing to do.
if (cached_filter_data_tensor.NumElements() > 0) {
return;
}
// Otherwise, cache filter
filter.SetUsrMem(filter_md, &filter_tensor);
filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_primitive_desc());
filter_data = static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
Tensor* filter_tensor_ptr = nullptr;
AllocatePersistentTensor(context, *conv_fwd_pd, &filter_tensor_ptr);
void* cached_filter_data = filter.GetTensorBuffer(filter_tensor_ptr);
size_t cached_filter_data_size =
filter.GetOpMem().get_primitive_desc().get_size();
memcpy(cached_filter_data, filter_data, cached_filter_data_size);
}
#endif // ENABLE_MKLDNN_V1
Tfilter* GetCachedFilter(OpKernelContext* context, Tfilter* GetCachedFilter(OpKernelContext* context,
const MEMORY_DESC& filter_md) const memory::desc& filter_md)
TF_LOCKS_EXCLUDED(mu_) { TF_LOCKS_EXCLUDED(mu_) {
tf_shared_lock lock(mu_); tf_shared_lock lock(mu_);
const Tensor& cached_filter_data = const Tensor& cached_filter_data =
@ -1292,15 +1141,10 @@ class MklConvOp : public OpKernel {
const Tensor& cached_filter_md = const Tensor& cached_filter_md =
*cached_filter_md_ptensor_.AccessTensor(context); *cached_filter_md_ptensor_.AccessTensor(context);
// Check if the memory descriptor of the cached weights is same as // Check if the memory descriptor of the cached weights is the same as
// filter_md. If so, we can use the cached weights; otherwise // filter_md. If so, we can use the cached weights; otherwise
// return nullptr. // return nullptr.
#ifdef ENABLE_MKLDNN_V1
if (filter_md == *static_cast<memory::desc*>(cached_filter_md.data())) { if (filter_md == *static_cast<memory::desc*>(cached_filter_md.data())) {
#else
if (cached_filter_md.scalar<int32>().size() &&
cached_filter_md.scalar<int32>()() == filter_md) {
#endif // ENABLE_MKLDNN_V1
return static_cast<Tfilter*>( return static_cast<Tfilter*>(
const_cast<Tfilter*>(cached_filter_data.flat<Tfilter>().data())); const_cast<Tfilter*>(cached_filter_data.flat<Tfilter>().data()));
} }
@ -1336,31 +1180,34 @@ class MklFusedConvOp
errors::InvalidArgument( errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias.")); "Fused Conv2D must have one extra argument: bias."));
} else if (fused_ops == std::vector<string>{"Relu"}) { } else if (fused_ops == std::vector<string>{"Relu"}) {
this->set_fuse_activation(true, ALGORITHM::eltwise_relu); this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
} else if (fused_ops == std::vector<string>{"Relu6"}) { } else if (fused_ops == std::vector<string>{"Relu6"}) {
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0); this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
6.0);
} else if (fused_ops == std::vector<string>{"Elu"}) { } else if (fused_ops == std::vector<string>{"Elu"}) {
this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0); this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
} else if (fused_ops == std::vector<string>{"LeakyRelu"}) { } else if (fused_ops == std::vector<string>{"LeakyRelu"}) {
float leakyrelu_alpha; float leakyrelu_alpha;
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha)); context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
this->set_fuse_activation(true, ALGORITHM::eltwise_relu, leakyrelu_alpha); this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
leakyrelu_alpha);
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) { } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
this->set_fuse_biasadd(true); this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_relu); this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
OP_REQUIRES(context, num_args == 1, OP_REQUIRES(context, num_args == 1,
errors::InvalidArgument( errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias.")); "Fused Conv2D must have one extra argument: bias."));
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) { } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
this->set_fuse_biasadd(true); this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0); this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
6.0);
OP_REQUIRES(context, num_args == 1, OP_REQUIRES(context, num_args == 1,
errors::InvalidArgument( errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias.")); "Fused Conv2D must have one extra argument: bias."));
} else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) { } else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
this->set_fuse_biasadd(true); this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0); this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
OP_REQUIRES(context, num_args == 1, OP_REQUIRES(context, num_args == 1,
errors::InvalidArgument( errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias.")); "Fused Conv2D must have one extra argument: bias."));
@ -1369,7 +1216,8 @@ class MklFusedConvOp
float leakyrelu_alpha; float leakyrelu_alpha;
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha)); context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
this->set_fuse_activation(true, ALGORITHM::eltwise_relu, leakyrelu_alpha); this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
leakyrelu_alpha);
OP_REQUIRES(context, num_args == 1, OP_REQUIRES(context, num_args == 1,
errors::InvalidArgument( errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias.")); "Fused Conv2D must have one extra argument: bias."));
@ -1383,7 +1231,7 @@ class MklFusedConvOp
} else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"}) { } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"}) {
this->set_fuse_biasadd(true); this->set_fuse_biasadd(true);
this->set_fuse_add(true); this->set_fuse_add(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_relu); this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
OP_REQUIRES( OP_REQUIRES(
context, num_args == 2, context, num_args == 2,
errors::InvalidArgument( errors::InvalidArgument(
@ -1391,7 +1239,8 @@ class MklFusedConvOp
} else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"}) { } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"}) {
this->set_fuse_biasadd(true); this->set_fuse_biasadd(true);
this->set_fuse_add(true); this->set_fuse_add(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0); this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
6.0);
OP_REQUIRES( OP_REQUIRES(
context, num_args == 2, context, num_args == 2,
errors::InvalidArgument( errors::InvalidArgument(
@ -1399,7 +1248,7 @@ class MklFusedConvOp
} else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"}) { } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"}) {
this->set_fuse_biasadd(true); this->set_fuse_biasadd(true);
this->set_fuse_add(true); this->set_fuse_add(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0); this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
OP_REQUIRES( OP_REQUIRES(
context, num_args == 2, context, num_args == 2,
errors::InvalidArgument( errors::InvalidArgument(
@ -1411,7 +1260,8 @@ class MklFusedConvOp
float leakyrelu_alpha; float leakyrelu_alpha;
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha)); context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
this->set_fuse_activation(true, ALGORITHM::eltwise_relu, leakyrelu_alpha); this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
leakyrelu_alpha);
OP_REQUIRES( OP_REQUIRES(
context, num_args == 2, context, num_args == 2,
errors::InvalidArgument( errors::InvalidArgument(
@ -1459,13 +1309,14 @@ class MklFusedDepthwiseConvOp
this->set_fuse_biasadd(true); this->set_fuse_biasadd(true);
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) { } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
this->set_fuse_biasadd(true); this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_relu); this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) { } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
this->set_fuse_biasadd(true); this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0); this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
6.0);
} else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) { } else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
this->set_fuse_biasadd(true); this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0); this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
} else { } else {
OP_REQUIRES(context, false, OP_REQUIRES(context, false,
errors::Unimplemented("Fusion is not implemented: [", errors::Unimplemented("Fusion is not implemented: [",
@ -1642,8 +1493,8 @@ class MklQuantizedConv2DOp
param_key.AddAsKey<float>(max_freezed_output); param_key.AddAsKey<float>(max_freezed_output);
param_key.AddAsKey<const float*>(min_filter); param_key.AddAsKey<const float*>(min_filter);
param_key.AddAsKey<const float*>(max_filter); param_key.AddAsKey<const float*>(max_filter);
params.post_op_params.push_back( params.post_op_params.push_back({"output_scale", mkldnn::algorithm::undef,
{"output_scale", ALGORITHM_UNDEF, scales, param_key.GetKey()}); scales, param_key.GetKey()});
} }
} }
@ -1696,31 +1547,27 @@ class MklQuantizedConv2DOp
bias_attr.set_output_scales(1, scales_); bias_attr.set_output_scales(1, scales_);
} }
auto bias_md = auto bias_md = memory::desc({static_cast<int>(bias_tensor.NumElements())},
MEMORY_PD_CONSTRUCTOR(static_cast<int>(bias_tensor.NumElements()), MklDnnType<Tbias>(), memory::format_tag::x);
Tbias, MEMORY_FORMAT::x, this->cpu_engine_);
void* bias_buf = static_cast<void*>( void* bias_buf = static_cast<void*>(
const_cast<Tbias*>(bias_tensor.flat<Tbias>().data())); const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
if (!input_bias_) { if (!input_bias_) {
input_bias_ = input_bias_ = new memory(bias_md, this->cpu_engine_, bias_buf);
new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_, bias_buf);
} else { } else {
input_bias_->set_data_handle(bias_buf); input_bias_->set_data_handle(bias_buf);
} }
if (!scaled_bias_buf_) if (!scaled_bias_buf_)
AllocTmpBuffer<Tbias>(context, &scaled_bias_tensor_, AllocTmpBuffer<Tbias>(context, &scaled_bias_tensor_,
GET_BIAS_DESC_FROM_OP_PD(conv_fwd_pd), conv_fwd_pd->bias_desc(), &scaled_bias_buf_);
&scaled_bias_buf_);
if (!scaled_bias_) { if (!scaled_bias_) {
scaled_bias_ = new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_, scaled_bias_ = new memory(bias_md, this->cpu_engine_, scaled_bias_buf_);
scaled_bias_buf_);
} else { } else {
scaled_bias_->set_data_handle(scaled_bias_buf_); scaled_bias_->set_data_handle(scaled_bias_buf_);
} }
auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR( auto reorder_desc =
input_bias_->GET_DESC, scaled_bias_->GET_DESC, this->cpu_engine_, ReorderPd(this->cpu_engine_, input_bias_->get_desc(),
bias_attr); this->cpu_engine_, scaled_bias_->get_desc(), bias_attr);
CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_, CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_,
this->cpu_engine_, context); this->cpu_engine_, context);
@@ -1754,7 +1601,7 @@ class MklQuantizedConv2DOp
     DCHECK(bias_tensor);
     TensorShape bias_tf_shape;
     bias_tf_shape.AddDim(
-        (conv_prim_desc.PRIMITIVE_DESC_BIAS.get_size() / sizeof(Tbias)));
+        (conv_prim_desc.bias_desc().get_size() / sizeof(Tbias)));
     OP_REQUIRES_OK(context, context->allocate_persistent(
                                 DataTypeToEnum<Tbias>::value, bias_tf_shape,
                                 &cached_bias_data_ptensor_, bias_tensor));
@@ -1787,7 +1634,7 @@ class MklQuantizedConv2DOp
     AllocatePersistentTensor(context, *conv_fwd_pd, &bias_tensor_ptr);
     void* cached_bias_data = const_cast<void*>(
         static_cast<const void*>(bias_tensor_ptr->flat<Tbias>().data()));
-    size_t cached_bias_data_size = scaled_bias->GET_DESC.get_size();
+    size_t cached_bias_data_size = scaled_bias->get_desc().get_size();
     memcpy(cached_bias_data, bias_data, cached_bias_data_size);
   }
@@ -1822,7 +1669,7 @@ class MklQuantizedConv2DReluOp
                is_depthwise>::ExtendConvFwdParams(context, params);
     params.post_op_params.push_back(
-        {"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}, ""});
+        {"activation", mkldnn::algorithm::eltwise_relu, {1.0, 0.0, 0.0}, ""});
   }
 };
@ -1868,26 +1715,30 @@ class MklQuantizedConv2DSumReluOp
// if summand_type is also DT_QUINT8 as the scale_output, // if summand_type is also DT_QUINT8 as the scale_output,
// the scaling factor of 255.0f cancels each other and thus is avoided. // the scaling factor of 255.0f cancels each other and thus is avoided.
// If it is not then it is DT_INT8 and is scaled appropriately. // If it is not then it is DT_INT8 and is scaled appropriately.
if (summand_type == DT_QUINT8) if (summand_type == DT_QUINT8) {
params.post_op_params.push_back( params.post_op_params.push_back({"sum",
{"sum", ALGORITHM_UNDEF, {scale_summand / scale_output}, ""}); mkldnn::algorithm::undef,
else {scale_summand / scale_output},
""});
} else {
params.post_op_params.push_back( params.post_op_params.push_back(
{"sum", {"sum",
ALGORITHM_UNDEF, mkldnn::algorithm::undef,
{255.0f * scale_summand / (scale_output * 127.0f)}, {255.0f * scale_summand / (scale_output * 127.0f)},
""}); ""});
}
} else { } else {
params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}, ""}); params.post_op_params.push_back(
{"sum", mkldnn::algorithm::undef, {1.0}, ""});
} }
params.post_op_params.push_back( params.post_op_params.push_back(
{"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}, ""}); {"activation", mkldnn::algorithm::eltwise_relu, {1.0, 0.0, 0.0}, ""});
} }
void AllocateOutputTensor(OpKernelContext* context, void AllocateOutputTensor(OpKernelContext* context,
const ConvFwdPd& conv_prim_desc, const ConvFwdPd& conv_prim_desc,
const memory::dims& output_dims_mkl_order, const memory::dims& output_dims_mkl_order,
MKL_TENSOR_FORMAT output_tf_format, MklTensorFormat output_tf_format,
MklDnnShape* output_mkl_shape, MklDnnShape* output_mkl_shape,
Tensor** output_tensor) override { Tensor** output_tensor) override {
int summand_idx = context->num_inputs() / 2 - 1; int summand_idx = context->num_inputs() / 2 - 1;
@ -1966,21 +1817,17 @@ class MklQuantizedConv2DSumReluOp
summand_mkl_shape.IsMklTensor() summand_mkl_shape.IsMklTensor()
? summand_mkl_shape.GetMklLayout() ? summand_mkl_shape.GetMklLayout()
: memory::desc(output_dims_mkl_order, MklDnnType<Tbias>(), : memory::desc(output_dims_mkl_order, MklDnnType<Tbias>(),
MEMORY_FORMAT::nhwc); memory::format_tag::nhwc);
#ifndef ENABLE_MKLDNN_V1
auto summand_pd = memory::primitive_desc(summand_md, this->cpu_engine_);
#endif // !ENABLE_MKLDNN_V1
void* summand_buf = void* summand_buf =
static_cast<void*>(const_cast<Tbias*>(summand.flat<Tbias>().data())); static_cast<void*>(const_cast<Tbias*>(summand.flat<Tbias>().data()));
void* dst_buf = void* dst_buf =
static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data()); static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
summand_.reset( summand_.reset(new memory(summand_md, this->cpu_engine_, summand_buf));
new MEMORY_CONSTRUCTOR(SUMMAND_MD, this->cpu_engine_, summand_buf)); dst_.reset(
dst_.reset(new MEMORY_CONSTRUCTOR(conv_prim_desc.PRIMITIVE_DESC_DST, new memory(conv_prim_desc.dst_desc(), this->cpu_engine_, dst_buf));
this->cpu_engine_, dst_buf)); auto reorder_desc =
auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR( ReorderPd(this->cpu_engine_, summand_md, this->cpu_engine_,
SUMMAND_MD, conv_prim_desc.PRIMITIVE_DESC_DST, this->cpu_engine_, conv_prim_desc.dst_desc(), reorder_attr);
reorder_attr);
CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_, CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_,
context); context);
} }
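Editor's note: the `ReorderPd` / `CreateAndExecuteReorder` helpers used above wrap the plain oneDNN v1.x reorder API. A hedged, self-contained sketch of what such an attribute-carrying reorder boils down to (illustrative function name, not the TensorFlow helper):

// Hedged sketch of a v1.x reorder with a primitive attribute (e.g. output
// scales): build the reorder primitive descriptor from the two memory
// objects' descriptors, then execute it on a stream.
#include "mkldnn.hpp"

void ReorderWithAttr(mkldnn::memory& src, mkldnn::memory& dst,
                     const mkldnn::primitive_attr& attr, mkldnn::engine& eng) {
  mkldnn::reorder::primitive_desc reorder_pd(eng, src.get_desc(), eng,
                                             dst.get_desc(), attr);
  mkldnn::stream strm(eng);
  mkldnn::reorder(reorder_pd).execute(strm, src, dst);
  strm.wait();
}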
@@ -42,20 +42,13 @@ limitations under the License.
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"
-#ifndef ENABLE_MKLDNN_V1
-using mkldnn::convolution_direct;
-#endif  // !ENABLE_MKLDNN_V1
 using mkldnn::convolution_forward;
 using mkldnn::prop_kind;
 using mkldnn::stream;
 namespace tensorflow {
-#ifdef ENABLE_MKLDNN_V1
 #define MKLDNN_SIZE_DTYPE memory::dim
-#else
-#define MKLDNN_SIZE_DTYPE int
-#endif  // ENABLE_MKLDNN_V1
 using ConvFwdDesc = mkldnn::convolution_forward::desc;
 using ConvFwdPd = mkldnn::convolution_forward::primitive_desc;
@@ -134,6 +134,7 @@ tf_kernel_library(
         "gpu_op_bitwise_and.cc",
         "gpu_op_bitwise_or.cc",
         "gpu_op_bitwise_xor.cc",
+        "gpu_op_div.cc",
         "gpu_op_equal.cc",
         "gpu_op_floor_div.cc",
         "gpu_op_greater.cc",
@@ -146,6 +147,7 @@ tf_kernel_library(
        "gpu_op_mul.cc",
        "gpu_op_not_equal.cc",
        "gpu_op_right_shift.cc",
+        "gpu_op_sub.cc",
    ],
    tags = [
        "manual",
@@ -155,6 +157,7 @@ tf_kernel_library(
        ":bitwise_and_kernels",
        ":bitwise_or_kernels",
        ":bitwise_xor_kernels",
+        ":div_kernels",
        ":equal_kernels",
        ":floor_div_kernels",
        ":gpu_ops_base",
@@ -170,6 +173,7 @@ tf_kernel_library(
        ":mul_kernels",
        ":not_equal_kernels",
        ":right_shift_kernels",
+        ":sub_kernels",
        "//third_party/eigen3",
    ],
 )
@@ -366,18 +370,24 @@ gen_kernel_library(
     unroll_factors = "4",
 )
-gen_kernel_library(
-    name = "add_v2",
-    tile_size = "256,1,1",
-    types = [
-        "f16",
-        "f32",
-        "f64",
-        "i64",
-    ],
-    # TODO(b/174543802): Enable once fusion heuristics is better.
-    # unroll_factors = "4",
-)
+[
+    gen_kernel_library(
+        name = name,
+        tile_size = "256,1,1",
+        types = [
+            "f16",
+            "f32",
+            "f64",
+            "i64",
+        ],
+        # TODO(b/174543802): Enable once fusion heuristics is better.
+        # unroll_factors = "4",
+    )
+    for name in [
+        "add_v2",
+        "sub",
+    ]
+]
 gen_kernel_library(
     name = "complex",
@@ -390,6 +400,20 @@ gen_kernel_library(
     # unroll_factors = "2",
 )
+gen_kernel_library(
+    name = "div",
+    tile_size = "256,1,1",
+    types = [
+        "f16",
+        "f32",
+        "f64",
+        "i16",
+        "i64",
+    ],
+    # TODO(b/174543802): Enable once fusion heuristics is better.
+    # unroll_factors = "4",
+)
 gen_kernel_library(
     name = "mul",
     tile_size = "256,1,1",
@@ -48,14 +48,17 @@ class GpuBinaryOpTest : public OpsTestBase {
   void SetOpKernel(const std::string& op_name, const TensorShape& lhs_shape,
                    const absl::InlinedVector<T, 10>& lhs_input,
                    const TensorShape& rhs_shape,
-                   const absl::InlinedVector<T, 10>& rhs_input,
-                   bool use_constraint) {
+                   const absl::InlinedVector<T, 10>& rhs_input, bool add_t,
+                   bool add_tout) {
     auto builder = NodeDefBuilder("some_name", op_name)
                        .Input(FakeInput(DataTypeToEnum<T>::v()))
                        .Input(FakeInput(DataTypeToEnum<T>::v()));
-    if (use_constraint) {
+    if (add_t) {
       builder.Attr("T", DataTypeToEnum<T>::v());
     }
+    if (add_tout) {
+      builder.Attr("Tout", DataTypeToEnum<OutT>::v());
+    }
     TF_ASSERT_OK(builder.Finalize(node_def()));
     TF_ASSERT_OK(InitOp());
@@ -73,16 +76,20 @@ class GpuBinaryOpTest : public OpsTestBase {
                           const absl::InlinedVector<T, 10>& rhs_input,
                           const TensorShape& expected_shape,
                           const absl::InlinedVector<OutT, 10>& expected_output,
-                          bool use_constraint = true) {
+                          const test::GpuOpsTestConfig& config) {
     SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input,
-                         use_constraint);
+                         config.add_t, config.add_tout);
     TF_ASSERT_OK(RunOpKernel());
     // Compare output to expectation.
     Tensor expected_tensor(allocator(), DataTypeToEnum<OutT>::value,
                            expected_shape);
     test::FillValues<OutT>(&expected_tensor, expected_output);
-    test::ExpectEqual(expected_tensor, *GetOutput(0));
+    if (config.expect_strictly_equal) {
+      test::ExpectEqual(expected_tensor, *GetOutput(0));
+    } else {
+      test::ExpectClose(expected_tensor, *GetOutput(0));
+    }
   }
   template <typename T, typename OutT>
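Editor's note: the `test::GpuOpsTestConfig` threaded through these helpers is not shown in this diff. Based only on the members used here (`add_t`, `add_tout`, `expect_strictly_equal`), it is presumably a small options struct along these lines; the defaults and anything beyond those three fields are hypothetical.

// Hedged reconstruction of the config object used by the tests above; the
// real struct lives elsewhere in the test library and may carry more members
// and factory helpers.
namespace test {
struct GpuOpsTestConfig {
  bool add_t = true;                  // emit the "T" type attr on the NodeDef
  bool add_tout = false;              // also emit an explicit "Tout" attr
  bool expect_strictly_equal = true;  // exact vs. approximate output check
};
}  // namespace test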
@ -91,9 +98,9 @@ class GpuBinaryOpTest : public OpsTestBase {
const absl::InlinedVector<T, 10>& lhs_input, const absl::InlinedVector<T, 10>& lhs_input,
const TensorShape& rhs_shape, const TensorShape& rhs_shape,
const absl::InlinedVector<T, 10>& rhs_input, const absl::InlinedVector<T, 10>& rhs_input,
bool use_constraint = true) { const test::GpuOpsTestConfig& config) {
SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input, SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input,
use_constraint); config.add_t, config.add_tout);
auto status = RunOpKernel(); auto status = RunOpKernel();
EXPECT_FALSE(status.ok()); EXPECT_FALSE(status.ok());
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@ -105,7 +112,7 @@ class GpuBinaryOpTest : public OpsTestBase {
void TestIncompatibleShapes(const std::string& op_name, void TestIncompatibleShapes(const std::string& op_name,
const absl::InlinedVector<T, 10>& lhs_input, const absl::InlinedVector<T, 10>& lhs_input,
const absl::InlinedVector<T, 10>& rhs_input, const absl::InlinedVector<T, 10>& rhs_input,
bool use_constraint = true) { const test::GpuOpsTestConfig& config) {
// Prepare incompatibly shaped inputs. // Prepare incompatibly shaped inputs.
TensorShape lhs_shape{3}; TensorShape lhs_shape{3};
TensorShape rhs_shape{2}; TensorShape rhs_shape{2};
@ -115,8 +122,7 @@ class GpuBinaryOpTest : public OpsTestBase {
test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements()); test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements());
RunAndExpectInvalidArgument<T, OutT>(op_name, lhs_shape, repeated_lhs_input, RunAndExpectInvalidArgument<T, OutT>(op_name, lhs_shape, repeated_lhs_input,
rhs_shape, repeated_rhs_input, rhs_shape, repeated_rhs_input, config);
use_constraint);
} }
template <typename T, typename BaselineT, typename OutT, template <typename T, typename BaselineT, typename OutT,
@ -125,7 +131,7 @@ class GpuBinaryOpTest : public OpsTestBase {
const absl::InlinedVector<T, 10>& lhs_input, const absl::InlinedVector<T, 10>& lhs_input,
const absl::InlinedVector<T, 10>& rhs_input, const absl::InlinedVector<T, 10>& rhs_input,
BaselineOutT (*baseline_callback)(BaselineT, BaselineT), BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
bool use_constraint = true) { const test::GpuOpsTestConfig& config) {
// Prepare inputs. // Prepare inputs.
int input_size = shape.num_elements(); int input_size = shape.num_elements();
auto repeated_lhs_input = auto repeated_lhs_input =
@ -147,7 +153,7 @@ class GpuBinaryOpTest : public OpsTestBase {
RunAndExpectResult<T, OutT>(op_name, shape, repeated_lhs_input, shape, RunAndExpectResult<T, OutT>(op_name, shape, repeated_lhs_input, shape,
repeated_rhs_input, shape, expected_output, repeated_rhs_input, shape, expected_output,
use_constraint); config);
} }
template <typename T, typename BaselineT, typename OutT, template <typename T, typename BaselineT, typename OutT,
@ -156,7 +162,7 @@ class GpuBinaryOpTest : public OpsTestBase {
const TensorShape& other_shape, const TensorShape& other_shape,
const absl::InlinedVector<T, 10>& other_input, const absl::InlinedVector<T, 10>& other_input,
BaselineOutT (*baseline_callback)(BaselineT, BaselineT), BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
bool use_constraint = true) { const test::GpuOpsTestConfig& config) {
// Prepare inputs. // Prepare inputs.
TensorShape scalar_shape{}; TensorShape scalar_shape{};
auto repeated_other_input = auto repeated_other_input =
@ -177,7 +183,7 @@ class GpuBinaryOpTest : public OpsTestBase {
RunAndExpectResult<T, OutT>(op_name, scalar_shape, scalar_input_vector, RunAndExpectResult<T, OutT>(op_name, scalar_shape, scalar_input_vector,
other_shape, repeated_other_input, other_shape, repeated_other_input,
/*expected_shape=*/other_shape, expected_output, /*expected_shape=*/other_shape, expected_output,
use_constraint); config);
} }
template <typename T, typename BaselineT, typename OutT, template <typename T, typename BaselineT, typename OutT,
@ -187,7 +193,7 @@ class GpuBinaryOpTest : public OpsTestBase {
const absl::InlinedVector<T, 10>& rhs_input, const absl::InlinedVector<T, 10>& rhs_input,
BaselineOutT (*baseline_callback)(BaselineT, BaselineOutT (*baseline_callback)(BaselineT,
BaselineT), BaselineT),
bool use_constraint = true) { const test::GpuOpsTestConfig& config) {
// Prepare inputs. // Prepare inputs.
TensorShape lhs_shape{1}; TensorShape lhs_shape{1};
TensorShape rhs_shape{6}; TensorShape rhs_shape{6};
@ -206,7 +212,7 @@ class GpuBinaryOpTest : public OpsTestBase {
RunAndExpectResult<T, OutT>( RunAndExpectResult<T, OutT>(
op_name, lhs_shape, repeated_lhs_input, rhs_shape, repeated_rhs_input, op_name, lhs_shape, repeated_lhs_input, rhs_shape, repeated_rhs_input,
/*expected_shape=*/rhs_shape, expected_output, use_constraint); /*expected_shape=*/rhs_shape, expected_output, config);
} }
template <typename T, typename BaselineT, typename OutT, template <typename T, typename BaselineT, typename OutT,
@ -216,7 +222,7 @@ class GpuBinaryOpTest : public OpsTestBase {
const absl::InlinedVector<T, 10>& rhs_input, const absl::InlinedVector<T, 10>& rhs_input,
BaselineOutT (*baseline_callback)(BaselineT, BaselineOutT (*baseline_callback)(BaselineT,
BaselineT), BaselineT),
bool use_constraint = true) { const test::GpuOpsTestConfig& config) {
// Prepare inputs. // Prepare inputs.
TensorShape lhs_shape{3}; TensorShape lhs_shape{3};
TensorShape rhs_shape{2, 3}; TensorShape rhs_shape{2, 3};
@ -235,7 +241,7 @@ class GpuBinaryOpTest : public OpsTestBase {
RunAndExpectResult<T, OutT>( RunAndExpectResult<T, OutT>(
op_name, lhs_shape, repeated_lhs_input, rhs_shape, repeated_rhs_input, op_name, lhs_shape, repeated_lhs_input, rhs_shape, repeated_rhs_input,
/*expected_shape=*/rhs_shape, expected_output, use_constraint); /*expected_shape=*/rhs_shape, expected_output, config);
} }
template <typename T, typename BaselineT, typename OutT, template <typename T, typename BaselineT, typename OutT,
@ -244,7 +250,7 @@ class GpuBinaryOpTest : public OpsTestBase {
const absl::InlinedVector<T, 10>& lhs_input, const absl::InlinedVector<T, 10>& lhs_input,
const absl::InlinedVector<T, 10>& rhs_input, const absl::InlinedVector<T, 10>& rhs_input,
BaselineOutT (*baseline_callback)(BaselineT, BaselineT), BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
bool use_constraint = true) { const test::GpuOpsTestConfig& config) {
// Prepare inputs. // Prepare inputs.
TensorShape lhs_shape{2, 1}; TensorShape lhs_shape{2, 1};
TensorShape rhs_shape{3}; TensorShape rhs_shape{3};
@ -264,7 +270,7 @@ class GpuBinaryOpTest : public OpsTestBase {
RunAndExpectResult<T, OutT>(op_name, lhs_shape, repeated_lhs_input, RunAndExpectResult<T, OutT>(op_name, lhs_shape, repeated_lhs_input,
rhs_shape, repeated_rhs_input, expected_shape, rhs_shape, repeated_rhs_input, expected_shape,
expected_output, use_constraint); expected_output, config);
} }
template <typename T, typename BaselineT, typename OutT, template <typename T, typename BaselineT, typename OutT,
@ -272,7 +278,7 @@ class GpuBinaryOpTest : public OpsTestBase {
void TestEmptyShapeBroadcasting(const std::string& op_name, void TestEmptyShapeBroadcasting(const std::string& op_name,
const absl::InlinedVector<T, 10>& lhs_input, const absl::InlinedVector<T, 10>& lhs_input,
const absl::InlinedVector<T, 10>& rhs_input, const absl::InlinedVector<T, 10>& rhs_input,
bool use_constraint = true) { const test::GpuOpsTestConfig& config) {
// Prepare inputs. // Prepare inputs.
TensorShape lhs_shape{2, 0, 1}; TensorShape lhs_shape{2, 0, 1};
TensorShape rhs_shape{2, 0, 5}; TensorShape rhs_shape{2, 0, 5};
@ -284,7 +290,7 @@ class GpuBinaryOpTest : public OpsTestBase {
RunAndExpectResult<T, OutT>(op_name, lhs_shape, empty_input, rhs_shape, RunAndExpectResult<T, OutT>(op_name, lhs_shape, empty_input, rhs_shape,
empty_input, expected_shape, expected_output, empty_input, expected_shape, expected_output,
use_constraint); config);
} }
private: private:
@ -309,60 +315,60 @@ class GpuBinaryOpTest : public OpsTestBase {
// define your own test fixtures. // define your own test fixtures.
#define GENERATE_DEFAULT_TESTS_2(op_name, test_name, T, BaselineT, OutT, \ #define GENERATE_DEFAULT_TESTS_2(op_name, test_name, T, BaselineT, OutT, \
BaselineOutT, baseline_callback, \ BaselineOutT, baseline_callback, config) \
use_constraint) \
TEST_F(GpuBinaryOpTest, op_name##EqShapes##test_name) { \ TEST_F(GpuBinaryOpTest, op_name##EqShapes##test_name) { \
TestEqualShapes<T, BaselineT, OutT, BaselineOutT>( \ TestEqualShapes<T, BaselineT, OutT, BaselineOutT>( \
#op_name, /*shape=*/test::DefaultInputShape(), \ #op_name, /*shape=*/test::DefaultInputShape(), \
/*lhs_input=*/test::DefaultInput<T>(#op_name), \ /*lhs_input=*/test::DefaultInput<T>(#op_name), \
/*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \ /*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
use_constraint); \ config); \
} \ } \
\ \
TEST_F(GpuBinaryOpTest, op_name##OneScalar##test_name) { \ TEST_F(GpuBinaryOpTest, op_name##OneScalar##test_name) { \
TestOneScalar<T, BaselineT, OutT, BaselineOutT>( \ TestOneScalar<T, BaselineT, OutT, BaselineOutT>( \
#op_name, /*scalar_input=*/test::DefaultScalarInput<T>(), \ #op_name, /*scalar_input=*/test::DefaultInput<T>(#op_name).front(), \
/*other_shape=*/test::DefaultInputShape(), \ /*other_shape=*/test::DefaultInputShape(), \
/*other_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \ /*other_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
use_constraint); \ config); \
} \ } \
\ \
TEST_F(GpuBinaryOpTest, op_name##IncompatibleShapes##test_name) { \ TEST_F(GpuBinaryOpTest, op_name##IncompatibleShapes##test_name) { \
TestIncompatibleShapes<T, OutT>( \ TestIncompatibleShapes<T, OutT>( \
#op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \ #op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \
/*rhs_input=*/test::DefaultInput<T>(#op_name), use_constraint); \ /*rhs_input=*/test::DefaultInput<T>(#op_name), config); \
} \ } \
\ \
TEST_F(GpuBinaryOpTest, op_name##BroadcastingExpand##test_name) { \ TEST_F(GpuBinaryOpTest, op_name##BroadcastingExpand##test_name) { \
TestBroadcastingExpand<T, BaselineT, OutT, BaselineOutT>( \ TestBroadcastingExpand<T, BaselineT, OutT, BaselineOutT>( \
#op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \ #op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \
/*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \ /*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
use_constraint); \ config); \
} \ } \
\ \
TEST_F(GpuBinaryOpTest, op_name##BroadcastingInDim##test_name) { \ TEST_F(GpuBinaryOpTest, op_name##BroadcastingInDim##test_name) { \
TestBroadcastingInDim<T, BaselineT, OutT, BaselineOutT>( \ TestBroadcastingInDim<T, BaselineT, OutT, BaselineOutT>( \
#op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \ #op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \
/*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \ /*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
use_constraint); \ config); \
} \ } \
\ \
TEST_F(GpuBinaryOpTest, op_name##Broadcasting##test_name) { \ TEST_F(GpuBinaryOpTest, op_name##Broadcasting##test_name) { \
TestBroadcasting<T, BaselineT, OutT, BaselineOutT>( \ TestBroadcasting<T, BaselineT, OutT, BaselineOutT>( \
#op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \ #op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \
/*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \ /*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
use_constraint); \ config); \
} \ } \
\ \
TEST_F(GpuBinaryOpTest, op_name##EmptyShapeBroadcasting##test_name) { \ TEST_F(GpuBinaryOpTest, op_name##EmptyShapeBroadcasting##test_name) { \
TestEmptyShapeBroadcasting<T, BaselineT, OutT, BaselineOutT>( \ TestEmptyShapeBroadcasting<T, BaselineT, OutT, BaselineOutT>( \
#op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \ #op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \
/*rhs_input=*/test::DefaultInput<T>(#op_name), use_constraint); \ /*rhs_input=*/test::DefaultInput<T>(#op_name), config); \
} }
#define GENERATE_DEFAULT_TESTS(op_name, test_name, T, OutT, baseline_callback) \ #define GENERATE_DEFAULT_TESTS(op_name, test_name, T, OutT, baseline_callback) \
GENERATE_DEFAULT_TESTS_2(op_name, test_name, T, T, OutT, OutT, \ GENERATE_DEFAULT_TESTS_2(op_name, test_name, T, T, OutT, OutT, \
baseline_callback, /*use_constraint=*/false) baseline_callback, \
test::GpuOpsTestConfig().ExpectStrictlyEqual())
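For reference, a minimal sketch of what one generated fixture looks like after preprocessing; this approximates the expansion of the macro above for the Div/Float instantiation used later in this file and is illustrative only, not code from this change.
TEST_F(GpuBinaryOpTest, DivEqShapesFloat) {
  // Approximate expansion of
  // GENERATE_DEFAULT_TESTS(Div, /*test_name=*/Float, float, float,
  //                        baseline_div) for the equal-shapes case.
  TestEqualShapes<float, float, float, float>(
      "Div", /*shape=*/test::DefaultInputShape(),
      /*lhs_input=*/test::DefaultInput<float>("Div"),
      /*rhs_input=*/test::DefaultInput<float>("Div"), baseline_div,
      test::GpuOpsTestConfig().ExpectStrictlyEqual());
}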
#define GENERATE_DEFAULT_TESTS_SAME_INPUT_AND_OUTPUT_TYPE( \ #define GENERATE_DEFAULT_TESTS_SAME_INPUT_AND_OUTPUT_TYPE( \
op_name, test_name, T, baseline_callback) \ op_name, test_name, T, baseline_callback) \
@ -433,37 +439,23 @@ GENERATE_DEFAULT_TESTS(BitwiseXor,
GENERATE_DEFAULT_TESTS(BitwiseXor, GENERATE_DEFAULT_TESTS(BitwiseXor,
/*test_name=*/Int64, int64, int64, baseline_bitwise_xor) /*test_name=*/Int64, int64, int64, baseline_bitwise_xor)
/// Test `tf.LeftShift`. /// Test `tf.Div`.
template <typename T> template <typename T>
T baseline_left_shift(T lhs, T rhs) { T baseline_div(T lhs, T rhs) {
return lhs << rhs; return lhs / rhs;
} }
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int8, int8, int8, GENERATE_DEFAULT_TESTS(Div,
baseline_left_shift) /*test_name=*/Half, Eigen::half, Eigen::half,
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int16, int16, int16, baseline_div);
baseline_left_shift) GENERATE_DEFAULT_TESTS(Div,
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int32, int32, int32, /*test_name=*/Float, float, float, baseline_div);
baseline_left_shift) GENERATE_DEFAULT_TESTS(Div,
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int64, int64, int64, /*test_name=*/Double, double, double, baseline_div);
baseline_left_shift) GENERATE_DEFAULT_TESTS(Div,
/*test_name=*/Int16, int16, int16, baseline_div);
/// Test `tf.RightShift`. GENERATE_DEFAULT_TESTS(Div,
/*test_name=*/Int64, int64, int64, baseline_div);
template <typename T>
T baseline_right_shift(T lhs, T rhs) {
return lhs >> rhs;
}
GENERATE_DEFAULT_TESTS(RightShift,
/*test_name=*/Int8, int8, int8, baseline_right_shift)
GENERATE_DEFAULT_TESTS(RightShift,
/*test_name=*/Int16, int16, int16, baseline_right_shift)
GENERATE_DEFAULT_TESTS(RightShift,
/*test_name=*/Int32, int32, int32, baseline_right_shift)
GENERATE_DEFAULT_TESTS(RightShift,
/*test_name=*/Int64, int64, int64, baseline_right_shift)
/// Test `tf.Equal`. /// Test `tf.Equal`.
@ -482,27 +474,25 @@ GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int8, int8, bool, baseline_equal)
GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int16, int16, bool, baseline_equal) GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int16, int16, bool, baseline_equal)
GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int64, int64, bool, baseline_equal) GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int64, int64, bool, baseline_equal)
/// Test `tf.NotEqual`. /// Test `tf.FloorDiv`.
template <typename T> template <typename T>
bool baseline_not_equal(T lhs, T rhs) { T baseline_floor_div(T lhs, T rhs) {
return lhs != rhs; return std::floor(lhs / rhs);
} }
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Half, Eigen::half, bool, template <>
baseline_not_equal) Eigen::half baseline_floor_div(Eigen::half lhs, Eigen::half rhs) {
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Float, float, bool, return static_cast<Eigen::half>(std::floor(static_cast<float>(lhs / rhs)));
baseline_not_equal) }
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Double, double, bool,
baseline_not_equal) GENERATE_DEFAULT_TESTS(FloorDiv,
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Bool, bool, bool, /*test_name=*/Half, Eigen::half, Eigen::half,
baseline_not_equal) baseline_floor_div)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int8, int8, bool, GENERATE_DEFAULT_TESTS(FloorDiv,
baseline_not_equal) /*test_name=*/Float, float, float, baseline_floor_div)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int16, int16, bool, GENERATE_DEFAULT_TESTS(FloorDiv,
baseline_not_equal) /*test_name=*/Double, double, double, baseline_floor_div)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int64, int64, bool,
baseline_not_equal)
/// Test `tf.Greater`. /// Test `tf.Greater`.
@ -544,6 +534,22 @@ GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Int16, int16, bool,
GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Int64, int64, bool, GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Int64, int64, bool,
baseline_greater_equal) baseline_greater_equal)
/// Test `tf.LeftShift`.
template <typename T>
T baseline_left_shift(T lhs, T rhs) {
return lhs << rhs;
}
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int8, int8, int8,
baseline_left_shift)
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int16, int16, int16,
baseline_left_shift)
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int32, int32, int32,
baseline_left_shift)
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int64, int64, int64,
baseline_left_shift)
/// Test `tf.Less`. /// Test `tf.Less`.
template <typename T> template <typename T>
@ -586,7 +592,7 @@ bool baseline_logical_and(bool lhs, bool rhs) { return lhs && rhs; }
GENERATE_DEFAULT_TESTS_2(LogicalAnd, /*test_name=*/Bool, /*T=*/bool, GENERATE_DEFAULT_TESTS_2(LogicalAnd, /*test_name=*/Bool, /*T=*/bool,
/*BaselineT=*/bool, /*OutT=*/bool, /*BaselineT=*/bool, /*OutT=*/bool,
/*BaselineOutT=*/bool, baseline_logical_and, /*BaselineOutT=*/bool, baseline_logical_and,
/*use_constraint=*/false) test::GpuOpsTestConfig().ExpectStrictlyEqual().NoT())
/// Test `tf.LogicalOr`. /// Test `tf.LogicalOr`.
@ -595,27 +601,77 @@ bool baseline_logical_or(bool lhs, bool rhs) { return lhs || rhs; }
GENERATE_DEFAULT_TESTS_2(LogicalOr, /*test_name=*/Bool, /*T=*/bool, GENERATE_DEFAULT_TESTS_2(LogicalOr, /*test_name=*/Bool, /*T=*/bool,
/*BaselineT=*/bool, /*OutT=*/bool, /*BaselineT=*/bool, /*OutT=*/bool,
/*BaselineOutT=*/bool, baseline_logical_or, /*BaselineOutT=*/bool, baseline_logical_or,
/*use_constraint=*/false) test::GpuOpsTestConfig().ExpectStrictlyEqual().NoT())
/// Test `tf.Mul`.
/// Test `tf.FloorDiv`.
template <typename T> template <typename T>
T baseline_floor_div(T lhs, T rhs) { T baseline_mul(T lhs, T rhs) {
return std::floor(lhs / rhs); return lhs * rhs;
} }
template <> GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Half, Eigen::half, Eigen::half,
Eigen::half baseline_floor_div(Eigen::half lhs, Eigen::half rhs) { baseline_mul)
return static_cast<Eigen::half>(std::floor(static_cast<float>(lhs / rhs))); GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Float, float, float, baseline_mul)
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Double, double, double, baseline_mul)
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Int8, int8, int8, baseline_mul)
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Int16, int16, int16, baseline_mul)
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Int64, int64, int64, baseline_mul)
/// Test `tf.NotEqual`.
template <typename T>
bool baseline_not_equal(T lhs, T rhs) {
return lhs != rhs;
} }
GENERATE_DEFAULT_TESTS(FloorDiv, GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Half, Eigen::half, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Float, float, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Double, double, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Bool, bool, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int8, int8, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int16, int16, bool,
baseline_not_equal)
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int64, int64, bool,
baseline_not_equal)
/// Test `tf.RightShift`.
template <typename T>
T baseline_right_shift(T lhs, T rhs) {
return lhs >> rhs;
}
GENERATE_DEFAULT_TESTS(RightShift,
/*test_name=*/Int8, int8, int8, baseline_right_shift)
GENERATE_DEFAULT_TESTS(RightShift,
/*test_name=*/Int16, int16, int16, baseline_right_shift)
GENERATE_DEFAULT_TESTS(RightShift,
/*test_name=*/Int32, int32, int32, baseline_right_shift)
GENERATE_DEFAULT_TESTS(RightShift,
/*test_name=*/Int64, int64, int64, baseline_right_shift)
/// Test `tf.Sub`.
template <typename T>
T baseline_sub(T lhs, T rhs) {
return lhs - rhs;
}
GENERATE_DEFAULT_TESTS(Sub,
/*test_name=*/Half, Eigen::half, Eigen::half, /*test_name=*/Half, Eigen::half, Eigen::half,
baseline_floor_div); baseline_sub)
GENERATE_DEFAULT_TESTS(FloorDiv, GENERATE_DEFAULT_TESTS(Sub,
/*test_name=*/Float, float, float, baseline_floor_div); /*test_name=*/Float, float, float, baseline_sub)
GENERATE_DEFAULT_TESTS(FloorDiv, GENERATE_DEFAULT_TESTS(Sub,
/*test_name=*/Double, double, double, /*test_name=*/Double, double, double, baseline_sub)
baseline_floor_div); GENERATE_DEFAULT_TESTS(Sub,
/*test_name=*/Int64, int64, int64, baseline_sub)
} // namespace } // namespace
} // end namespace tensorflow } // end namespace tensorflow
@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, i64, DT_INT64, int64);
} // namespace tensorflow
@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, i64, DT_INT64, int64);
} // namespace tensorflow
@ -57,6 +57,7 @@ TensorShape DefaultInputShape();
struct GpuOpsTestConfig { struct GpuOpsTestConfig {
bool add_t = true; bool add_t = true;
bool add_tout = false; bool add_tout = false;
// Only used for gpu_unary_ops_test.
bool expect_buffer_reuse = true; bool expect_buffer_reuse = true;
bool expect_strictly_equal = false; bool expect_strictly_equal = false;
GpuOpsTestConfig ExpectStrictlyEqual() { GpuOpsTestConfig ExpectStrictlyEqual() {
@ -119,33 +120,10 @@ absl::InlinedVector<T, 10> DefaultInputGreaterOrEqualToZero() {
/// Helper functions to get default input data. /// Helper functions to get default input data.
template <typename T,
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
bool> = true>
T DefaultScalarInput() {
return static_cast<T>(3);
}
template <typename T, std::enable_if_t<
llvm::is_one_of<T, Eigen::half, float, double>::value,
bool> = true>
T DefaultScalarInput() {
return static_cast<T>(2.0);
}
template <typename T,
std::enable_if_t<llvm::is_one_of<T, bool>::value, bool> = true>
T DefaultScalarInput() {
return static_cast<T>(true);
}
template <typename T, template <typename T,
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value, std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
bool> = true> bool> = true>
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) { absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
if (op_name == "Abs") {
return NearZeroAndExtremeInput<T>();
}
// Only generate values less than the bitwidth of the data type. // Only generate values less than the bitwidth of the data type.
if (op_name == "LeftShift" || op_name == "RightShift") { if (op_name == "LeftShift" || op_name == "RightShift") {
auto max_shift = sizeof(T) * 8 - 1; auto max_shift = sizeof(T) * 8 - 1;
@ -153,6 +131,9 @@ absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
for (auto i = 0; i < max_shift; ++i) v.push_back(i); for (auto i = 0; i < max_shift; ++i) v.push_back(i);
return v; return v;
} }
if (op_name == "Div") {
return InputAsVector<T, int>({-18, -9, 9, 18});
}
return InputAsVector<T, int>({-18, -9, -1, 0, 0, 1, 1, 2, 3, 5, 7, 9, 9, 18}); return InputAsVector<T, int>({-18, -9, -1, 0, 0, 1, 1, 2, 3, 5, 7, 9, 9, 18});
} }
@ -160,16 +141,7 @@ template <typename T, std::enable_if_t<
llvm::is_one_of<T, Eigen::half, float, double>::value, llvm::is_one_of<T, Eigen::half, float, double>::value,
bool> = true> bool> = true>
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) { absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
if (op_name == "Abs") { if (op_name == "Div" || op_name == "FloorDiv") {
return NearZeroAndExtremeInput<T>();
}
if (op_name == "Log" || op_name == "Rsqrt") {
return DefaultInputGreaterThanZero<T>();
}
if (op_name == "Sqrt") {
return DefaultInputGreaterOrEqualToZero<T>();
}
if (op_name == "FloorDiv") {
return InputAsVector<T, double>({-18.0, -9.0, -1e-6, -0.1, 0.1, 1e-6, 0.1, return InputAsVector<T, double>({-18.0, -9.0, -1e-6, -0.1, 0.1, 1e-6, 0.1,
0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0}); 0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
} }
@ -125,37 +125,55 @@ class GpuUnaryOpTest : public OpsTestBase {
// define your own test fixtures. // define your own test fixtures.
#define GENERATE_DEFAULT_TEST(op_name, InT, OutT, baseline_callback, config) \ #define GENERATE_DEFAULT_TEST(op_name, InT, OutT, baseline_callback, config) \
GENERATE_DEFAULT_TEST2(op_name, InT, InT, OutT, OutT, baseline_callback, \ GENERATE_DEFAULT_TEST_2(op_name, InT, InT, OutT, OutT, baseline_callback, \
config) config)
#define GENERATE_DEFAULT_TEST2(op_name, InT, BaselineT, OutT, BaselineOutT, \ #define GENERATE_DEFAULT_TEST_2(op_name, InT, BaselineT, OutT, BaselineOutT, \
baseline_callback, config) \ baseline_callback, config) \
TEST_F(GpuUnaryOpTest, op_name##InT) { \ GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \
using NativeT = EnumToDataType<InT>::Type; \ op_name, InT, BaselineT, OutT, BaselineOutT, \
using NativeBaselineT = EnumToDataType<BaselineT>::Type; \ test::DefaultInput<NativeT>(#op_name), baseline_callback, config)
using NativeOutT = EnumToDataType<OutT>::Type; \
using NativeBaselineOutT = EnumToDataType<BaselineOutT>::Type; \ #define GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES( \
Test<NativeT, NativeBaselineT, NativeOutT, NativeBaselineOutT>( \ op_name, InT, OutT, input_values, baseline_callback, config) \
#op_name, test::DefaultInputShape(), \ GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \
test::DefaultInput<NativeT>(#op_name), baseline_callback, config); \ op_name, InT, InT, OutT, OutT, input_values, baseline_callback, config)
#define GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \
op_name, InT, BaselineT, OutT, BaselineOutT, input_values, \
baseline_callback, config) \
TEST_F(GpuUnaryOpTest, op_name##InT) { \
using NativeT = EnumToDataType<InT>::Type; \
using NativeBaselineT = EnumToDataType<BaselineT>::Type; \
using NativeOutT = EnumToDataType<OutT>::Type; \
using NativeBaselineOutT = EnumToDataType<BaselineOutT>::Type; \
Test<NativeT, NativeBaselineT, NativeOutT, NativeBaselineOutT>( \
#op_name, test::DefaultInputShape(), input_values, baseline_callback, \
config); \
} }
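For reference, a minimal sketch of the fixture produced by GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES for the Abs/DT_FLOAT instantiation below; this is an approximation of the macro expansion, not code from this change.
TEST_F(GpuUnaryOpTest, AbsDT_FLOAT) {
  // Approximate expansion: DT_FLOAT maps to the native type float for all
  // four template parameters.
  Test<float, float, float, float>(
      "Abs", test::DefaultInputShape(), test::NearZeroAndExtremeInput<float>(),
      std::abs, test::GpuOpsTestConfig().ExpectStrictlyEqual());
}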
/// Test `tf.Abs`. /// Test `tf.Abs`.
GENERATE_DEFAULT_TEST(Abs, DT_FLOAT, DT_FLOAT, std::abs, GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
test::GpuOpsTestConfig().ExpectStrictlyEqual()) Abs, DT_FLOAT, DT_FLOAT, test::NearZeroAndExtremeInput<float>(), std::abs,
test::GpuOpsTestConfig().ExpectStrictlyEqual())
GENERATE_DEFAULT_TEST(Abs, DT_DOUBLE, DT_DOUBLE, std::abs, GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
test::GpuOpsTestConfig().ExpectStrictlyEqual()) Abs, DT_DOUBLE, DT_DOUBLE, test::NearZeroAndExtremeInput<double>(),
std::abs, test::GpuOpsTestConfig().ExpectStrictlyEqual())
GENERATE_DEFAULT_TEST2(Abs, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::abs, GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
test::GpuOpsTestConfig().ExpectStrictlyEqual()) Abs, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
test::NearZeroAndExtremeInput<Eigen::half>(), std::abs,
test::GpuOpsTestConfig().ExpectStrictlyEqual())
GENERATE_DEFAULT_TEST(Abs, DT_INT32, DT_INT32, std::abs, GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
test::GpuOpsTestConfig().ExpectStrictlyEqual()) Abs, DT_INT32, DT_INT32, test::NearZeroAndExtremeInput<int32>(), std::abs,
test::GpuOpsTestConfig().ExpectStrictlyEqual())
GENERATE_DEFAULT_TEST(Abs, DT_INT64, DT_INT64, std::abs, GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
test::GpuOpsTestConfig().ExpectStrictlyEqual()) Abs, DT_INT64, DT_INT64, test::NearZeroAndExtremeInput<int64>(), std::abs,
test::GpuOpsTestConfig().ExpectStrictlyEqual())
/// Test `tf.Ceil`. /// Test `tf.Ceil`.
@ -165,8 +183,8 @@ GENERATE_DEFAULT_TEST(Ceil, DT_FLOAT, DT_FLOAT, std::ceil,
GENERATE_DEFAULT_TEST(Ceil, DT_DOUBLE, DT_DOUBLE, std::ceil, GENERATE_DEFAULT_TEST(Ceil, DT_DOUBLE, DT_DOUBLE, std::ceil,
test::GpuOpsTestConfig().ExpectStrictlyEqual()) test::GpuOpsTestConfig().ExpectStrictlyEqual())
GENERATE_DEFAULT_TEST2(Ceil, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::ceil, GENERATE_DEFAULT_TEST_2(Ceil, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::ceil,
test::GpuOpsTestConfig().ExpectStrictlyEqual()) test::GpuOpsTestConfig().ExpectStrictlyEqual())
/// Test `tf.Conj`. /// Test `tf.Conj`.
@ -189,8 +207,8 @@ GENERATE_DEFAULT_TEST(Cos, DT_FLOAT, DT_FLOAT, std::cos,
GENERATE_DEFAULT_TEST(Cos, DT_DOUBLE, DT_DOUBLE, std::cos, GENERATE_DEFAULT_TEST(Cos, DT_DOUBLE, DT_DOUBLE, std::cos,
test::GpuOpsTestConfig()) test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST2(Cos, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::cos, GENERATE_DEFAULT_TEST_2(Cos, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::cos,
test::GpuOpsTestConfig()) test::GpuOpsTestConfig())
/// Test `tf.Exp`. /// Test `tf.Exp`.
@ -200,8 +218,8 @@ GENERATE_DEFAULT_TEST(Exp, DT_FLOAT, DT_FLOAT, std::exp,
GENERATE_DEFAULT_TEST(Exp, DT_DOUBLE, DT_DOUBLE, std::exp, GENERATE_DEFAULT_TEST(Exp, DT_DOUBLE, DT_DOUBLE, std::exp,
test::GpuOpsTestConfig()) test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST2(Exp, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::exp, GENERATE_DEFAULT_TEST_2(Exp, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::exp,
test::GpuOpsTestConfig()) test::GpuOpsTestConfig())
/// Test `tf.Floor`. /// Test `tf.Floor`.
@ -211,8 +229,8 @@ GENERATE_DEFAULT_TEST(Floor, DT_FLOAT, DT_FLOAT, std::floor,
GENERATE_DEFAULT_TEST(Floor, DT_DOUBLE, DT_DOUBLE, std::floor, GENERATE_DEFAULT_TEST(Floor, DT_DOUBLE, DT_DOUBLE, std::floor,
test::GpuOpsTestConfig().ExpectStrictlyEqual()) test::GpuOpsTestConfig().ExpectStrictlyEqual())
GENERATE_DEFAULT_TEST2(Floor, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::floor, GENERATE_DEFAULT_TEST_2(Floor, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::floor,
test::GpuOpsTestConfig().ExpectStrictlyEqual()) test::GpuOpsTestConfig().ExpectStrictlyEqual())
/// Test `tf.Imag`. /// Test `tf.Imag`.
@ -260,14 +278,18 @@ TEST_F(GpuUnaryOpTest, DISABLED_IsInfHalf) {
/// Test `tf.Log`. /// Test `tf.Log`.
GENERATE_DEFAULT_TEST(Log, DT_FLOAT, DT_FLOAT, std::log, GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
test::GpuOpsTestConfig()) Log, DT_FLOAT, DT_FLOAT, test::DefaultInputGreaterThanZero<float>(),
std::log, test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST(Log, DT_DOUBLE, DT_DOUBLE, std::log, GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
test::GpuOpsTestConfig()) Log, DT_DOUBLE, DT_DOUBLE, test::DefaultInputGreaterThanZero<double>(),
std::log, test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST2(Log, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::log, GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
test::GpuOpsTestConfig()) Log, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
test::DefaultInputGreaterThanZero<Eigen::half>(), std::log,
test::GpuOpsTestConfig())
/// Test `tf.LogicalNot` /// Test `tf.LogicalNot`
@ -290,8 +312,8 @@ GENERATE_DEFAULT_TEST(Neg, DT_FLOAT, DT_FLOAT, baseline_neg,
GENERATE_DEFAULT_TEST(Neg, DT_DOUBLE, DT_DOUBLE, baseline_neg, GENERATE_DEFAULT_TEST(Neg, DT_DOUBLE, DT_DOUBLE, baseline_neg,
test::GpuOpsTestConfig().ExpectStrictlyEqual()) test::GpuOpsTestConfig().ExpectStrictlyEqual())
GENERATE_DEFAULT_TEST2(Neg, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, baseline_neg, GENERATE_DEFAULT_TEST_2(Neg, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, baseline_neg,
test::GpuOpsTestConfig()) test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST(Neg, DT_INT8, DT_INT8, baseline_neg, GENERATE_DEFAULT_TEST(Neg, DT_INT8, DT_INT8, baseline_neg,
test::GpuOpsTestConfig().ExpectStrictlyEqual()) test::GpuOpsTestConfig().ExpectStrictlyEqual())
@ -323,14 +345,18 @@ T baseline_rsqrt(T x) {
return 1.0 / std::sqrt(x); return 1.0 / std::sqrt(x);
} }
GENERATE_DEFAULT_TEST(Rsqrt, DT_FLOAT, DT_FLOAT, baseline_rsqrt, GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
test::GpuOpsTestConfig()) Rsqrt, DT_FLOAT, DT_FLOAT, test::DefaultInputGreaterThanZero<float>(),
baseline_rsqrt, test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST(Rsqrt, DT_DOUBLE, DT_DOUBLE, baseline_rsqrt, GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
test::GpuOpsTestConfig()) Rsqrt, DT_DOUBLE, DT_DOUBLE, test::DefaultInputGreaterThanZero<double>(),
baseline_rsqrt, test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST2(Rsqrt, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
baseline_rsqrt, test::GpuOpsTestConfig()) Rsqrt, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
test::DefaultInputGreaterThanZero<Eigen::half>(), baseline_rsqrt,
test::GpuOpsTestConfig())
/// Test `tf.Sign`. /// Test `tf.Sign`.
@ -350,8 +376,8 @@ GENERATE_DEFAULT_TEST(Sign, DT_DOUBLE, DT_DOUBLE, baseline_sign,
// TODO(b/162577610): We should actually use ExpectStrictlyEqual() // TODO(b/162577610): We should actually use ExpectStrictlyEqual()
// here. This requires returning 0.0 for input -0.0. // here. This requires returning 0.0 for input -0.0.
GENERATE_DEFAULT_TEST2(Sign, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, GENERATE_DEFAULT_TEST_2(Sign, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
baseline_sign, test::GpuOpsTestConfig()) baseline_sign, test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST(Sign, DT_INT64, DT_INT64, baseline_sign, GENERATE_DEFAULT_TEST(Sign, DT_INT64, DT_INT64, baseline_sign,
test::GpuOpsTestConfig().ExpectStrictlyEqual()) test::GpuOpsTestConfig().ExpectStrictlyEqual())
@ -364,19 +390,24 @@ GENERATE_DEFAULT_TEST(Sin, DT_FLOAT, DT_FLOAT, std::sin,
GENERATE_DEFAULT_TEST(Sin, DT_DOUBLE, DT_DOUBLE, std::sin, GENERATE_DEFAULT_TEST(Sin, DT_DOUBLE, DT_DOUBLE, std::sin,
test::GpuOpsTestConfig()) test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST2(Sin, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::sin, GENERATE_DEFAULT_TEST_2(Sin, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::sin,
test::GpuOpsTestConfig()) test::GpuOpsTestConfig())
/// Test `tf.Sqrt`. /// Test `tf.Sqrt`.
GENERATE_DEFAULT_TEST(Sqrt, DT_FLOAT, DT_FLOAT, std::sqrt, GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
test::GpuOpsTestConfig()) Sqrt, DT_FLOAT, DT_FLOAT, test::DefaultInputGreaterOrEqualToZero<float>(),
std::sqrt, test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST(Sqrt, DT_DOUBLE, DT_DOUBLE, std::sqrt, GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
test::GpuOpsTestConfig()) Sqrt, DT_DOUBLE, DT_DOUBLE,
test::DefaultInputGreaterOrEqualToZero<double>(), std::sqrt,
test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST2(Sqrt, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::sqrt, GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
test::GpuOpsTestConfig()) Sqrt, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
test::DefaultInputGreaterOrEqualToZero<Eigen::half>(), std::sqrt,
test::GpuOpsTestConfig())
/// Test `tf.Tanh`. /// Test `tf.Tanh`.
@ -386,8 +417,8 @@ GENERATE_DEFAULT_TEST(Tanh, DT_FLOAT, DT_FLOAT, std::tanh,
GENERATE_DEFAULT_TEST(Tanh, DT_DOUBLE, DT_DOUBLE, std::tanh, GENERATE_DEFAULT_TEST(Tanh, DT_DOUBLE, DT_DOUBLE, std::tanh,
test::GpuOpsTestConfig()) test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST2(Tanh, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::tanh, GENERATE_DEFAULT_TEST_2(Tanh, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::tanh,
test::GpuOpsTestConfig()) test::GpuOpsTestConfig())
} // namespace } // namespace
} // end namespace tensorflow } // end namespace tensorflow
@ -0,0 +1,6 @@
func @Div_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Div"(%arg0, %arg1) {T = elem_type, device = ""}
: (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>
}
@ -0,0 +1,6 @@
func @Sub_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
-> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
%0 = "tf.Sub"(%arg0, %arg1) {T = elem_type, device = ""}
: (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
return %0 : tensor<*xelem_type>
}
@ -137,4 +137,14 @@ REGISTER_KERNEL_BUILDER(
Name("QuantizedMaxPool").Device(DEVICE_CPU).TypeConstraint<quint8>("T"), Name("QuantizedMaxPool").Device(DEVICE_CPU).TypeConstraint<quint8>("T"),
QuantizedMaxPoolingOp<CPUDevice, quint8>); QuantizedMaxPoolingOp<CPUDevice, quint8>);
#ifdef INTEL_MKL
REGISTER_KERNEL_BUILDER(
Name("QuantizedAvgPool").Device(DEVICE_CPU).TypeConstraint<qint8>("T"),
QuantizedAvgPoolingOp<CPUDevice, qint8>);
REGISTER_KERNEL_BUILDER(
Name("QuantizedMaxPool").Device(DEVICE_CPU).TypeConstraint<qint8>("T"),
QuantizedMaxPoolingOp<CPUDevice, qint8>);
#endif
} // namespace tensorflow } // namespace tensorflow
@ -232,11 +232,7 @@ REGISTER_OP("_FusedBatchNormEx")
.Output("reserve_space_1: U") .Output("reserve_space_1: U")
.Output("reserve_space_2: U") .Output("reserve_space_2: U")
.Output("reserve_space_3: U") .Output("reserve_space_3: U")
#ifdef ENABLE_MKLDNN_V1
.Attr("T: {half, float, bfloat16}") .Attr("T: {half, float, bfloat16}")
#else
.Attr("T: {half, float}")
#endif
.Attr("U: {float}") .Attr("U: {float}")
.Attr("epsilon: float = 0.0001") .Attr("epsilon: float = 0.0001")
.Attr("exponential_avg_factor: float = 1.0") .Attr("exponential_avg_factor: float = 1.0")
@ -981,6 +981,17 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "enable_tf2_utils",
srcs = ["enable_tf2_utils.cc"],
hdrs = ["enable_tf2_utils.h"],
copts = tf_copts(),
deps = [
"//tensorflow/core/util:env_var",
],
alwayslink = 1,
)
alias( alias(
name = "profile_utils_cpu_utils", name = "profile_utils_cpu_utils",
actual = "//tensorflow/core/platform/profile_utils:profile_utils_cpu_utils", actual = "//tensorflow/core/platform/profile_utils:profile_utils_cpu_utils",
@ -992,6 +1003,12 @@ filegroup(
compatible_with = get_compatible_with_portable(), compatible_with = get_compatible_with_portable(),
) )
filegroup(
name = "enable_tf2_hdr",
srcs = ["enable_tf2_utils.h"],
compatible_with = get_compatible_with_portable(),
)
tf_cc_tests( tf_cc_tests(
name = "low_level_library_tests", name = "low_level_library_tests",
size = "small", size = "small",
@ -1047,6 +1064,20 @@ tf_cc_test(
], ],
) )
tf_cc_test(
name = "enable_tf2_utils_test",
size = "small",
srcs = [
"enable_tf2_utils_test.cc",
],
deps = [
":enable_tf2_utils",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/util:env_var",
],
)
tf_cc_tests( tf_cc_tests(
name = "stacktrace_handler_test", name = "stacktrace_handler_test",
size = "small", size = "small",
@ -149,9 +149,7 @@ class PosixEnv : public Env {
return true; return true;
} }
} }
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) #if defined(__GLIBC__) || defined(__FreeBSD__)
return false;
#else
char buf[100]; char buf[100];
#ifdef __FreeBSD__ #ifdef __FreeBSD__
int res = 0; int res = 0;
@ -164,6 +162,8 @@ class PosixEnv : public Env {
} }
*name = buf; *name = buf;
return true; return true;
#else
return false;
#endif #endif
} }
@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/enable_tf2_utils.h"
#include <atomic>
#include "tensorflow/core/util/env_var.h"
namespace tensorflow {
enum Enablement : uint8 { kFalse = 0, kTrue = 1, undefined = 2 };
// If this flag is set, we will use it as a signal to decide whether to use
// the MLIR-based TF-XLA bridge.
static std::atomic<Enablement> tf2_enabled{undefined};
// Records whether the user has explicitly asked for tf2 execution. The value
// is later used to decide whether to use the MLIR-based bridge.
void set_tf2_execution(bool enabled) {
tf2_enabled = (enabled) ? Enablement::kTrue : Enablement::kFalse;
}
bool tf2_execution_enabled() {
if (tf2_enabled == Enablement::undefined) {
static bool tf2_behavior_env_enabled = [] {
string tf2_env;
TF_CHECK_OK(ReadStringFromEnvVar("TF2_BEHAVIOR", "0", &tf2_env));
return tf2_env != "0";
}();
tf2_enabled =
(tf2_behavior_env_enabled) ? Enablement::kTrue : Enablement::kFalse;
}
return tf2_enabled;
}
} // namespace tensorflow
@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TF_CORE_PLATFORM_TF2_UTILS_H_
#define TF_CORE_PLATFORM_TF2_UTILS_H_
namespace tensorflow {
// Sets the tf2 execution state. This can be used to indicate whether the user
// has explicitly asked for tf2 execution.
void set_tf2_execution(bool enabled);
// Returns true if tf2 execution was requested, either via set_tf2_execution()
// or the TF2_BEHAVIOR environment variable; otherwise returns false.
bool tf2_execution_enabled();
} // namespace tensorflow
#endif // TF_CORE_PLATFORM_TF2_UTILS_H_
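A minimal sketch of how a caller might consult this flag; the helper below is hypothetical and only illustrates the intended use together with set_tf2_execution().
#include "tensorflow/core/platform/enable_tf2_utils.h"

namespace tensorflow {

// Hypothetical helper, sketch only.
bool ShouldUseMlirBridge() {
  // Honors an explicit set_tf2_execution(true/false) call if one was made;
  // otherwise falls back to the TF2_BEHAVIOR environment variable.
  return tf2_execution_enabled();
}

}  // namespace tensorflow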
@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Testing TF2 enablement.
#include "tensorflow/core/platform/enable_tf2_utils.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/env_var.h"
namespace tensorflow {
TEST(TF2EnabledTest, enabled_behavior) {
string tf2_env;
TF_CHECK_OK(ReadStringFromEnvVar("TF2_BEHAVIOR", "0", &tf2_env));
bool expected = (tf2_env != "0");
EXPECT_EQ(tensorflow::tf2_execution_enabled(), expected);
tensorflow::set_tf2_execution(true);
EXPECT_TRUE(tensorflow::tf2_execution_enabled());
tensorflow::set_tf2_execution(false);
EXPECT_FALSE(tensorflow::tf2_execution_enabled());
}
} // namespace tensorflow
@ -108,7 +108,7 @@ limitations under the License.
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
#define TF_GRAPH_DEF_VERSION 634 // Updated: 2021/1/2 #define TF_GRAPH_DEF_VERSION 637 // Updated: 2021/1/5
// Checkpoint compatibility versions (the versions field in SavedSliceMeta). // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
// //
@ -386,6 +386,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
return ParseSoftmax(op, error_reporter, allocator, builtin_data); return ParseSoftmax(op, error_reporter, allocator, builtin_data);
} }
case BuiltinOperator_SPACE_TO_DEPTH: {
return ParseSpaceToDepth(op, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_SPLIT: { case BuiltinOperator_SPLIT: {
return ParseSplit(op, error_reporter, allocator, builtin_data); return ParseSplit(op, error_reporter, allocator, builtin_data);
} }
@ -596,16 +600,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
*builtin_data = params.release(); *builtin_data = params.release();
return kTfLiteOk; return kTfLiteOk;
} }
case BuiltinOperator_SPACE_TO_DEPTH: {
auto params = safe_allocator.Allocate<TfLiteSpaceToDepthParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* schema_params =
op->builtin_options_as_SpaceToDepthOptions()) {
params->block_size = schema_params->block_size();
}
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_GATHER: { case BuiltinOperator_GATHER: {
auto params = safe_allocator.Allocate<TfLiteGatherParams>(); auto params = safe_allocator.Allocate<TfLiteGatherParams>();
@ -1684,6 +1678,31 @@ TfLiteStatus ParseSoftmax(const Operator* op, ErrorReporter* error_reporter,
return kTfLiteOk; return kTfLiteOk;
} }
TfLiteStatus ParseSpaceToDepth(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data) {
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
SafeBuiltinDataAllocator safe_allocator(allocator);
std::unique_ptr<TfLiteSpaceToDepthParams,
SafeBuiltinDataAllocator::BuiltinDataDeleter>
params = safe_allocator.Allocate<TfLiteSpaceToDepthParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
const auto* schema_params = op->builtin_options_as_SpaceToDepthOptions();
if (schema_params != nullptr) {
params->block_size = schema_params->block_size();
} else {
// TODO(b/157480169): We should either return kTfLiteError or fill in some
// reasonable defaults in the params struct. We are not doing so until we
// better understand the ramifications of changing the legacy behavior.
}
*builtin_data = params.release();
return kTfLiteOk;
}
TfLiteStatus ParseSplit(const Operator* op, ErrorReporter* error_reporter, TfLiteStatus ParseSplit(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data) { BuiltinDataAllocator* allocator, void** builtin_data) {
CheckParsePointerParams(op, error_reporter, allocator, builtin_data); CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
@ -254,6 +254,11 @@ TfLiteStatus ParseSin(const Operator* op, ErrorReporter* error_reporter,
TfLiteStatus ParseSoftmax(const Operator* op, ErrorReporter* error_reporter, TfLiteStatus ParseSoftmax(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data); BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseSpaceToDepth(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseSplit(const Operator* op, ErrorReporter* error_reporter, TfLiteStatus ParseSplit(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data); BuiltinDataAllocator* allocator, void** builtin_data);
@ -111,11 +111,15 @@ class Subgraph {
inline TfLiteStatus SetTensorParametersReadWrite( inline TfLiteStatus SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name, int tensor_index, TfLiteType type, const char* name,
const std::vector<int>& dims, TfLiteQuantization quantization, const std::vector<int>& dims, TfLiteQuantization quantization,
bool is_variable = false, const size_t rank_dims_signature = 0, bool is_variable = false, const std::vector<int>& dims_signature = {}) {
const int* dims_signature = nullptr) { if (dims_signature.empty()) {
return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(), return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
dims.data(), quantization, is_variable, dims.data(), quantization,
rank_dims_signature, dims_signature); is_variable);
}
return SetTensorParametersReadWrite(
tensor_index, type, name, dims.size(), dims.data(), quantization,
is_variable, dims_signature.size(), dims_signature.data());
} }
TfLiteStatus SetTensorParametersReadWrite( TfLiteStatus SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name, const size_t rank, int tensor_index, TfLiteType type, const char* name, const size_t rank,
@ -45,6 +45,56 @@ int3 GetWorkGroupsCount(int grid_dimension, const int3& grid_size,
} }
return work_groups_count; return work_groups_count;
} }
std::string GetCommonOpenCLDefines(CalculationsPrecision precision) {
std::string result;
result += "#define GLOBAL_ID_0 get_global_id(0)\n";
result += "#define GLOBAL_ID_1 get_global_id(1)\n";
result += "#define GLOBAL_ID_2 get_global_id(2)\n";
result += "#define MAIN_FUNCTION __kernel void main_function\n";
switch (precision) {
case CalculationsPrecision::F32:
result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
result += "#define ACCUM_FLT4 float4\n";
result += "#define FLT float\n";
result += "#define FLT2 float2\n";
result += "#define FLT3 float3\n";
result += "#define FLT4 float4\n";
result += "#define TO_FLT4 convert_float4\n";
result += "#define TO_ACCUM_TYPE convert_float4\n";
result += "#define TO_ACCUM_FLT convert_float\n";
result += "#define INIT_FLT4(value) (float4)(value)\n";
break;
case CalculationsPrecision::F16:
result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
result += "#define ACCUM_FLT4 half4\n";
result += "#define FLT half\n";
result += "#define FLT2 half2\n";
result += "#define FLT3 half3\n";
result += "#define FLT4 half4\n";
result += "#define TO_FLT4 convert_half4\n";
result += "#define TO_ACCUM_TYPE convert_half4\n";
result += "#define TO_ACCUM_FLT convert_half\n";
result += "#define INIT_FLT4(value) (half4)(value)\n";
break;
case CalculationsPrecision::F32_F16:
result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
result += "#define ACCUM_FLT4 float4\n";
result += "#define FLT half\n";
result += "#define FLT2 half2\n";
result += "#define FLT3 half3\n";
result += "#define FLT4 half4\n";
result += "#define TO_FLT4 convert_half4\n";
result += "#define TO_ACCUM_TYPE convert_float4\n";
result += "#define TO_ACCUM_FLT convert_float\n";
result += "#define INIT_FLT4(value) (half4)(value)\n";
break;
}
return result;
}
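To make the effect of these defines concrete, an abridged sketch of the source string a compiled F16 elementwise kernel would start with once the prefix is prepended in ClOperation::Compile below; the exact code is assembled by AssembleCode, so this is illustrative only.
// Abridged illustration, not generated verbatim by this code.
const char kExampleF16KernelSource[] =
    "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n"
    "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
    "#define FLT4 half4\n"
    "#define INIT_FLT4(value) (half4)(value)\n"
    "#define GLOBAL_ID_0 get_global_id(0)\n"
    "#define MAIN_FUNCTION __kernel void main_function\n"
    "MAIN_FUNCTION($0) {\n"
    "  int X = GLOBAL_ID_0;\n"
    "  FLT4 src = INIT_FLT4(0.0f);\n"
    "  /* ... */\n"
    "}\n";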
} // namespace } // namespace
ClOperation::ClOperation(ClOperation&& operation) ClOperation::ClOperation(ClOperation&& operation)
@ -94,6 +144,9 @@ absl::Status ClOperation::UpdateParams() {
absl::Status ClOperation::Compile(const CreationContext& creation_context) { absl::Status ClOperation::Compile(const CreationContext& creation_context) {
operation_->AssembleCode(creation_context.GetGpuInfo()); operation_->AssembleCode(creation_context.GetGpuInfo());
operation_->code_ =
GetCommonOpenCLDefines(operation_->definition_.precision) +
operation_->code_;
RETURN_IF_ERROR(cl_args_.Init( RETURN_IF_ERROR(cl_args_.Init(
creation_context.GetGpuInfo(), creation_context.GetGpuInfo(),
{{operation_->dst_tensors_names_[0], operation_->elementwise_code_}}, {{operation_->dst_tensors_names_[0], operation_->elementwise_code_}},
@ -336,6 +336,8 @@ absl::Status Tensor::GetGPUResources(const GPUObjectDescriptor* obj_ptr,
if (!tensor_desc) { if (!tensor_desc) {
return absl::InvalidArgumentError("Expected TensorDescriptor on input."); return absl::InvalidArgumentError("Expected TensorDescriptor on input.");
} }
resources->ints.push_back(
{"slice_stride", tensor_desc->GetSliceStrideSize(shape_)});
if (descriptor_.HasAxis(Axis::WIDTH)) { if (descriptor_.HasAxis(Axis::WIDTH)) {
resources->ints.push_back({"width", Width()}); resources->ints.push_back({"width", Width()});
resources->ints.push_back({"width_div2", Width() / 2}); resources->ints.push_back({"width_div2", Width() / 2});
@ -22,61 +22,18 @@ limitations under the License.
namespace tflite { namespace tflite {
namespace gpu { namespace gpu {
namespace { namespace {
std::string GetCommonDefines(CalculationsPrecision precision) {
std::string result;
switch (precision) {
case CalculationsPrecision::F32:
result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
result += "#define ACCUM_FLT4 float4\n";
result += "#define FLT float\n";
result += "#define FLT2 float2\n";
result += "#define FLT3 float3\n";
result += "#define FLT4 float4\n";
result += "#define TO_FLT4 convert_float4\n";
result += "#define TO_ACCUM_TYPE convert_float4\n";
result += "#define TO_ACCUM_FLT convert_float\n";
break;
case CalculationsPrecision::F16:
result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
result += "#define ACCUM_FLT4 half4\n";
result += "#define FLT half\n";
result += "#define FLT2 half2\n";
result += "#define FLT3 half3\n";
result += "#define FLT4 half4\n";
result += "#define TO_FLT4 convert_half4\n";
result += "#define TO_ACCUM_TYPE convert_half4\n";
result += "#define TO_ACCUM_FLT convert_half\n";
break;
case CalculationsPrecision::F32_F16:
result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
result += "#define ACCUM_FLT4 float4\n";
result += "#define FLT half\n";
result += "#define FLT2 half2\n";
result += "#define FLT3 half3\n";
result += "#define FLT4 half4\n";
result += "#define TO_FLT4 convert_half4\n";
result += "#define TO_ACCUM_TYPE convert_float4\n";
result += "#define TO_ACCUM_FLT convert_float\n";
break;
}
return result;
}
std::string GetElementWiseCode(const OperationDef& op_def, std::string GetElementWiseCode(const OperationDef& op_def,
bool check_src_slices) { bool check_src_slices) {
std::string c; std::string c;
c += "__kernel void main_function(\n"; c += "MAIN_FUNCTION(\n";
c += "$0) {\n"; c += "$0) {\n";
c += " int X = get_global_id(0);\n"; c += " int X = GLOBAL_ID_0;\n";
c += " int Y = get_global_id(1);\n"; c += " int Y = GLOBAL_ID_1;\n";
c += " int Z = get_global_id(2);\n"; c += " int Z = GLOBAL_ID_2;\n";
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || " c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
"Z >= args.dst_tensor.Slices()) return; \n"; "Z >= args.dst_tensor.Slices()) return; \n";
if (check_src_slices) { if (check_src_slices) {
c += " FLT4 src = (FLT4)(0.0f);\n"; c += " FLT4 src = INIT_FLT4(0.0f);\n";
c += " if (Z < args.src_tensor.Slices()) {\n"; c += " if (Z < args.src_tensor.Slices()) {\n";
c += " src = args.src_tensor.Read(X, Y, Z);\n"; c += " src = args.src_tensor.Read(X, Y, Z);\n";
c += " }\n"; c += " }\n";
@ -240,7 +197,6 @@ void GPUOperation::AssembleCode(const GpuInfo& gpu_info) {
elementwise_code_ = "{\n" + code_ + "\n}\n" + elementwise_code_; elementwise_code_ = "{\n" + code_ + "\n}\n" + elementwise_code_;
code_ = GetElementWiseCode(definition_, check_src_channels_size_); code_ = GetElementWiseCode(definition_, check_src_channels_size_);
} }
code_ = GetCommonDefines(definition_.precision) + code_;
} }
void GPUOperation::GetPossibleKernelWorkGroups( void GPUOperation::GetPossibleKernelWorkGroups(
@ -94,6 +94,7 @@ TensorDescriptor& TensorDescriptor::operator=(TensorDescriptor&& desc) {
GPUResources TensorDescriptor::GetGPUResources() const { GPUResources TensorDescriptor::GetGPUResources() const {
GPUResources resources; GPUResources resources;
resources.ints.push_back("slice_stride");
if (HasAxis(Axis::WIDTH)) { if (HasAxis(Axis::WIDTH)) {
resources.ints.push_back("width"); resources.ints.push_back("width");
resources.ints.push_back("width_div2"); resources.ints.push_back("width_div2");
@ -175,7 +176,7 @@ absl::Status TensorDescriptor::PerformSelector(
*result = "slices"; *result = "slices";
return absl::OkStatus(); return absl::OkStatus();
} else if (selector == "SliceStride") { } else if (selector == "SliceStride") {
*result = GetSliceStride(); *result = "slice_stride";
return absl::OkStatus(); return absl::OkStatus();
} else if (selector == "Channels") { } else if (selector == "Channels") {
*result = "channels"; *result = "channels";
@ -402,7 +403,7 @@ absl::Status TensorDescriptor::PerformGetPtrWithSliceOffsetSelector(
"GetPtrWithSliceOffset require one argument(slice coordinate), but ", "GetPtrWithSliceOffset require one argument(slice coordinate), but ",
args.size(), " was passed")); args.size(), " was passed"));
} }
*result = absl::StrCat("buffer + ", args[0], " * ", GetSliceStride()); *result = absl::StrCat("buffer + ", args[0], " * slice_stride");
return absl::OkStatus(); return absl::OkStatus();
} }
@ -644,6 +645,35 @@ bool TensorDescriptor::HasAxis(Axis axis) const {
return false; return false;
} }
int TensorDescriptor::GetWidthSize(BHWDC shape) const {
int width = shape.w;
auto it1 = state_vars_.find("ElementsX2");
if (it1 != state_vars_.end() && it1->second == "true") {
width /= 2;
}
auto it2 = state_vars_.find("ElementsX4");
if (it2 != state_vars_.end() && it2->second == "true") {
width /= 4;
}
auto it = state_vars_.find("BatchedWidth");
if (it != state_vars_.end() && it->second == "true") {
width *= shape.b;
}
return width;
}
int TensorDescriptor::GetSliceStrideSize(BHWDC shape) const {
if (IsBatchedWidth()) {
return GetWidthSize(shape) * shape.h;
} else {
if (HasAxis(Axis::BATCH)) {
return GetWidthSize(shape) * shape.h * shape.b;
} else {
return GetWidthSize(shape) * shape.h;
}
}
}
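A standalone sketch of the slice-stride arithmetic above, using a hypothetical shape with no ElementsX2/ElementsX4 packing, BatchedWidth disabled, and a batch axis present.
#include <iostream>

int main() {
  const int b = 2, h = 4, w = 8;  // hypothetical BHWDC extents (d, c unused)
  const int width_size = w;       // no /2 or /4 packing, width not batched
  const int slice_stride = width_size * h * b;  // mirrors GetSliceStrideSize()
  std::cout << "slice_stride = " << slice_stride << "\n";  // prints 64
  return 0;
}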
void TensorDescriptor::SetAddressMode(AddressMode mode) { void TensorDescriptor::SetAddressMode(AddressMode mode) {
if (mode == AddressMode::kZero) { if (mode == AddressMode::kZero) {
state_vars_["TextureMode"] = "ZERO"; state_vars_["TextureMode"] = "ZERO";
@ -719,18 +749,6 @@ std::string TensorDescriptor::GetWidth() const {
} }
} }
std::string TensorDescriptor::GetSliceStride() const {
if (IsBatchedWidth()) {
return GetWidth() + " * height";
} else {
if (HasAxis(Axis::BATCH)) {
return GetWidth() + " * height * batch";
} else {
return GetWidth() + " * height";
}
}
}
AddressMode TensorDescriptor::AddressModeFromState() const { AddressMode TensorDescriptor::AddressModeFromState() const {
auto it = state_vars_.find("TextureMode"); auto it = state_vars_.find("TextureMode");
if (it != state_vars_.end()) { if (it != state_vars_.end()) {


@ -70,6 +70,8 @@ struct TensorDescriptor : public GPUObjectDescriptor {
bool HasAxis(Axis axis) const; bool HasAxis(Axis axis) const;
void SetAddressMode(AddressMode mode); void SetAddressMode(AddressMode mode);
int GetWidthSize(BHWDC shape) const;
int GetSliceStrideSize(BHWDC shape) const;
absl::Status GetLinkingContextFromWriteSelector( absl::Status GetLinkingContextFromWriteSelector(
const std::vector<std::string>& args, std::string* value_name, const std::vector<std::string>& args, std::string* value_name,
@ -136,7 +138,6 @@ struct TensorDescriptor : public GPUObjectDescriptor {
bool IsBatchedWidth() const; bool IsBatchedWidth() const;
std::string GetWidth() const; std::string GetWidth() const;
std::string GetSliceStride() const;
AddressMode AddressModeFromState() const; AddressMode AddressModeFromState() const;


@ -677,6 +677,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:shape",
"//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/common:util", "//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/common/task:util",
"//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
], ],
) )


@ -32,59 +32,56 @@ namespace gpu {
namespace metal { namespace metal {
std::string GetResizeBilinearCode(const Resize2DAttributes& attr) { std::string GetResizeBilinearCode(const Resize2DAttributes& attr) {
std::string code = R"( std::string c = R"(
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
$0 $0
kernel void ComputeFunction( kernel void ComputeFunction(
$1 $1
uint3 gid[[thread_position_in_grid]]) { uint3 gid[[thread_position_in_grid]]) {
if (int(gid.x) >= size.z || int(gid.y) >= size.w) { if (int(gid.x) >= args.dst_tensor.Width() || int(gid.y) >= args.dst_tensor.Height()) {
return; return;
})";
if (attr.half_pixel_centers) {
code += "const float2 tex_coord = (float2(gid.xy) + 0.5f) * scale - 0.5f;";
} else {
code += "const float2 tex_coord = float2(gid.xy) * scale;";
} }
code += R"( )";
const float2 tex_coord_floor = floor(tex_coord); if (attr.half_pixel_centers) {
const int2 itex_coord_floor = int2(tex_coord_floor); c += " float2 tex_coord = (float2(gid.xy) + 0.5f) * scale - 0.5f;";
const int2 borders = size.xy - int2(1, 1); } else {
int4 st; c += " float2 tex_coord = float2(gid.xy) * scale;";
st.xy = max(itex_coord_floor, int2(0, 0)); }
st.zw = min(itex_coord_floor + int2(1, 1), borders); c += R"(
const float2 t = tex_coord - tex_coord_floor; // interpolating factors float2 tex_coord_floor = floor(tex_coord);
const int src_index0 = (gid.z * size.y + st.y) * size.x + st.x; int2 itex_coord_floor = int2(tex_coord_floor);
const int src_index1 = (gid.z * size.y + st.y) * size.x + st.z; int2 borders = int2(args.src_tensor.Width() - 1, args.src_tensor.Height() - 1);
const int src_index2 = (gid.z * size.y + st.w) * size.x + st.x; int4 st;
const int src_index3 = (gid.z * size.y + st.w) * size.x + st.z; st.xy = max(itex_coord_floor, int2(0, 0));
FLT4 tex11 = src_tensor[src_index0]; st.zw = min(itex_coord_floor + int2(1, 1), borders);
FLT4 tex21 = src_tensor[src_index1]; float2 t = tex_coord - tex_coord_floor; // interpolating factors
FLT4 tex12 = src_tensor[src_index2]; FLT4 tex11 = args.src_tensor.Read(st.x, st.y, gid.z);
FLT4 tex22 = src_tensor[src_index3]; FLT4 tex21 = args.src_tensor.Read(st.z, st.y, gid.z);
// bilinear interpolation FLT4 tex12 = args.src_tensor.Read(st.x, st.w, gid.z);
FLT4 value = mix(mix(tex11, tex21, static_cast<FLT>(t.x)), FLT4 tex22 = args.src_tensor.Read(st.z, st.w, gid.z);
mix(tex12, tex22, static_cast<FLT>(t.x)), static_cast<FLT>(t.y)); // bilinear interpolation
const int linear_index = (gid.z * size.w + gid.y) * size.z + gid.x; FLT4 value = mix(mix(tex11, tex21, static_cast<FLT>(t.x)),
$2 mix(tex12, tex22, static_cast<FLT>(t.x)), static_cast<FLT>(t.y));
dst_tensor[linear_index] = value; args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, gid.z);
} $2
)"; args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
return code; }
)";
return c;
} }
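
For reference, the Read-based kernel above still performs standard bilinear interpolation; with the interpolation factors t = tex_coord - floor(tex_coord), the nested Metal mix() calls expand to:

    \[ \text{value} = (1 - t_y)\big[(1 - t_x)\,\text{tex11} + t_x\,\text{tex21}\big] + t_y\big[(1 - t_x)\,\text{tex12} + t_x\,\text{tex22}\big] \]
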
std::string GetResizeNearestCode(const Resize2DAttributes& attr) { std::string GetResizeNearestCode(const Resize2DAttributes& attr) {
std::string code = R"( std::string c = R"(
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
$0 $0
kernel void ComputeFunction( kernel void ComputeFunction(
$1 $1
uint3 gid[[thread_position_in_grid]]) { uint3 gid[[thread_position_in_grid]]) {
if (int(gid.x) >= size.z || int(gid.y) >= size.w) { if (int(gid.x) >= args.dst_tensor.Width() || int(gid.y) >= args.dst_tensor.Height()) {
return; return;
} }
)"; )";
std::string fxc; std::string fxc;
std::string fyc; std::string fyc;
@ -99,27 +96,27 @@ std::string GetResizeNearestCode(const Resize2DAttributes& attr) {
fxc += " + 0.5f"; fxc += " + 0.5f";
fyc += " + 0.5f"; fyc += " + 0.5f";
} }
code += " int2 coord;\n"; c += " int2 coord;\n";
code += " coord.x = static_cast<int>(" + fxc + ");\n"; c += " coord.x = static_cast<int>(" + fxc + ");\n";
code += " coord.y = static_cast<int>(" + fyc + ");\n"; c += " coord.y = static_cast<int>(" + fyc + ");\n";
code += " coord.x = max(0, coord.x);\n"; c += " coord.x = max(0, coord.x);\n";
code += " coord.y = max(0, coord.y);\n"; c += " coord.y = max(0, coord.y);\n";
code += " coord.x = min(coord.x, size.x - 1);\n"; c += " coord.x = min(coord.x, args.src_tensor.Width() - 1);\n";
code += " coord.y = min(coord.y, size.y - 1);\n"; c += " coord.y = min(coord.y, args.src_tensor.Height() - 1);\n";
code += R"( c += R"(
const int src_index = (gid.z * size.y + coord.y) * size.x + coord.x; FLT4 value = args.src_tensor.Read(coord.x, coord.y, gid.z);
FLT4 value = src_tensor[src_index]; args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, gid.z);
const int linear_index = (gid.z * size.w + gid.y) * size.z + gid.x; $2
$2 args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
dst_tensor[linear_index] = value; }
} )";
)"; return c;
return code;
} }
ComputeTaskDescriptor Resize(const OperationDef& definition, ComputeTaskDescriptor Resize(const OperationDef& definition,
const Resize2DAttributes& attr) { const Resize2DAttributes& attr) {
ComputeTaskDescriptor desc(definition); ComputeTaskDescriptor desc(definition);
desc.tensors_as_args = true;
switch (attr.type) { switch (attr.type) {
case SamplingType::BILINEAR: case SamplingType::BILINEAR:
desc.shader_source = GetResizeBilinearCode(attr); desc.shader_source = GetResizeBilinearCode(attr);
@ -136,17 +133,6 @@ ComputeTaskDescriptor Resize(const OperationDef& definition,
desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]); desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
desc.uniform_buffers = { desc.uniform_buffers = {
{"constant int4& size",
[](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) {
std::vector<int> sizes = {
src_shapes[0].w,
src_shapes[0].h,
dst_shapes[0].w,
dst_shapes[0].h,
};
return GetByteBuffer(sizes);
}},
{"constant float2& scale", {"constant float2& scale",
[attr](const std::vector<BHWC>& src_shapes, [attr](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) { const std::vector<BHWC>& dst_shapes) {
@ -160,10 +146,10 @@ ComputeTaskDescriptor Resize(const OperationDef& definition,
desc.resize_function = [](const std::vector<BHWC>& src_shapes, desc.resize_function = [](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) { const std::vector<BHWC>& dst_shapes) {
const uint3 groups_size{16, 16, 1}; const uint3 groups_size{8, 8, 1};
const int dst_layers = DivideRoundUp(dst_shapes[0].c, 4);
int groups_x = DivideRoundUp(dst_shapes[0].w, groups_size.x); int groups_x = DivideRoundUp(dst_shapes[0].w, groups_size.x);
int groups_y = DivideRoundUp(dst_shapes[0].h, groups_size.y); int groups_y = DivideRoundUp(dst_shapes[0].h, groups_size.y);
const int dst_layers = DivideRoundUp(dst_shapes[0].c, 4);
int groups_z = DivideRoundUp(dst_layers, groups_size.z); int groups_z = DivideRoundUp(dst_layers, groups_size.z);
return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
}; };
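
The resize_function change above only shrinks the threadgroup from 16x16x1 to 8x8x1; the dispatch math itself is unchanged. With DivideRoundUp the grid is:

    \[ \text{groups}_x = \lceil W_{dst}/8 \rceil, \quad \text{groups}_y = \lceil H_{dst}/8 \rceil, \quad \text{groups}_z = \lceil C_{dst}/4 \rceil \]
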


@ -35,123 +35,146 @@ namespace gpu {
namespace metal { namespace metal {
namespace { namespace {
std::string GetSliceCode(const SliceAttributes& attr) { namespace {
std::stringstream code; bool Is4Aligned(const SliceAttributes& attr) {
return attr.strides.c == 1 && attr.starts.c % 4 == 0;
}
code << R"( int4 GetOffset(const SliceAttributes& attr, int src_width, int src_height,
#include <metal_stdlib> int src_channels, int src_batch) {
using namespace metal; int4 offset;
struct uniforms {
int4 src_size;
int4 dst_size;
};
constant int4 width = int4($0, $1, $2, 0);
constant int4 height = int4($3, $4, $5, 0);
constant int4 channels = int4($6, $7, $8, 0);
constant FLT4 null_vec = FLT4(0.0f, 0.0f, 0.0f, 0.0f);
$$0
kernel void ComputeFunction(
$$1
uint3 gid[[thread_position_in_grid]]) {
if (static_cast<int>(gid.x) >= params.dst_size.x ||
static_cast<int>(gid.y) >= params.dst_size.y) {
return;
}
FLT4 value;
short2 offset;
)";
if (attr.strides.w > 0) { if (attr.strides.w > 0) {
code << " offset.x = width.x;" << std::endl; offset.x = attr.starts.w;
} else { } else {
if (attr.ends.w > 0) { if (attr.ends.w > 0) {
code << " offset.x = width.z;" << std::endl; offset.x = attr.ends.w;
} else { } else {
code << " offset.x = params.src_size.x + width.z;" << std::endl; offset.x = src_width + attr.ends.w;
} }
} }
if (attr.strides.h > 0) { if (attr.strides.h > 0) {
code << " offset.y = height.x;" << std::endl; offset.y = attr.starts.h;
} else { } else {
if (attr.ends.h > 0) { if (attr.ends.h > 0) {
code << " offset.y = height.z;" << std::endl; offset.y = attr.ends.h;
} else { } else {
code << " offset.y = params.src_size.y + height.z;" << std::endl; offset.y = src_height + attr.ends.h;
} }
} }
code << std::endl; if (attr.strides.c > 0) {
code << " short2 stride = short2(width.y, height.y);" << std::endl; offset.z = attr.starts.c;
} else {
if (attr.ends.c > 0) {
offset.z = attr.ends.c;
} else {
offset.z = src_channels + attr.ends.c;
}
}
if (Is4Aligned(attr)) {
offset.z /= 4;
}
if (attr.strides.b > 0) {
offset.w = attr.starts.b;
} else {
if (attr.ends.b > 0) {
offset.w = attr.ends.b;
} else {
offset.w = src_batch + attr.ends.b;
}
}
return offset;
}
code << " const short2 s_c = offset + short2(gid.xy) * stride;" } // namespace
<< std::endl;
code << " bool outside = false;" << std::endl; std::string GetSliceCode(const OperationDef& op_def, bool alignedx4) {
code << " int step = gid.z * 4;" << std::endl; const std::string batch_id =
code << " FLT4 tmp;" << std::endl; op_def.dst_tensors[0].HasAxis(Axis::BATCH) ? "B" : "0";
code << " int buffer_index = 0;" << std::endl; std::string c = R"(
code << " int addr = 0;" << std::endl; #include <metal_stdlib>
code << std::endl; using namespace metal;
for (int i = 0; i < 4; i++) {
code << " addr = step * channels.y;" << std::endl; struct uniforms {
if (attr.strides.c > 0) { int4 offset;
code << " addr += channels.x;" << std::endl; int4 stride;
} else { };
if (attr.ends.c > 0) {
code << " addr += channels.z;" << std::endl; $0
} else { kernel void ComputeFunction($1
code << " addr += params.src_size.z + channels.z;" << std::endl; uint3 gid[[thread_position_in_grid]]) {
} )";
} if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
code << " buffer_index = ((addr / 4) * params.src_size.y + s_c.y) * " c += " int linear_id = static_cast<int>(gid.x);\n";
"params.src_size.x + " c += " int X = linear_id / args.dst_tensor.Batch();\n";
"s_c.x;" c += " int B = linear_id % args.dst_tensor.Batch();\n";
<< std::endl; c += " args.dst_tensor.SetBatchRef(B);\n";
code << " outside = step >= params.dst_size.z;" << std::endl; } else {
code << " tmp = outside ? null_vec : src_tensor[buffer_index];" c += " int X = static_cast<int>(gid.x);\n";
<< std::endl; }
code << " value[" << i << "] = tmp[addr % 4];" << std::endl; c += " int Y = static_cast<int>(gid.y);\n";
if (i != 3) { c += " int Z = static_cast<int>(gid.z);\n";
code << " step++;" << std::endl; c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
code << std::endl; "Z >= args.dst_tensor.Slices()) { \n";
c += " return; \n";
c += " } \n";
c += " int s_x = X * params.stride.x + params.offset.x;\n";
c += " int s_y = Y * params.stride.y + params.offset.y;\n";
if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) {
c += " int s_b = " + batch_id + " * params.stride.w + params.offset.w;\n";
c += " args.src_tensor.SetBatchRef(s_b);\n";
}
if (alignedx4) {
c += " int s_z = Z + params.offset.z;\n";
c += " FLT4 result = args.src_tensor.Read(s_x, s_y, s_z);\n";
} else {
c += " FLT4 result;\n";
const std::string postfixes[] = {"x", "y", "z", "w"};
for (int i = 0; i < 4; ++i) {
c += " {\n";
const std::string ch = "(Z * 4 + " + std::to_string(i) + ")";
c += " int s_ch = " + ch + " * params.stride.z + params.offset.z;\n";
c += " int s_z = min(s_ch >> 2, args.src_tensor.Slices() - 1);\n";
c += " int s_z_rem = s_ch & 3;\n";
c += " FLT4 t = args.src_tensor.Read(s_x, s_y, s_z);\n";
c += " result." + postfixes[i] + " = t[s_ch & 3];\n";
c += " }\n";
} }
} }
code << R"( c += " FLT4 value = result;\n";
int linear_index = (gid.z * params.dst_size.y + int(gid.y)) * c += " args.dst_tensor.GetAddress(linear_index, X, Y, Z);\n";
params.dst_size.x + int(gid.x); c += " $2\n";
$$2 c += " args.dst_tensor.Write(value, X, Y, Z);\n";
dst_tensor[linear_index] = value; c += "}\n";
})"; return c;
return absl::Substitute(
code.str(), attr.starts.w, attr.strides.w, attr.ends.w, attr.starts.h,
attr.strides.h, attr.ends.h, attr.starts.c, attr.strides.c, attr.ends.c);
} }
} // namespace } // namespace
ComputeTaskDescriptor Slice(const OperationDef& definition, ComputeTaskDescriptor Slice(const OperationDef& definition,
const SliceAttributes& attr) { const SliceAttributes& attr) {
ComputeTaskDescriptor desc(definition); ComputeTaskDescriptor desc(definition);
desc.shader_source = GetSliceCode(attr); desc.tensors_as_args = true;
desc.shader_source = GetSliceCode(definition, Is4Aligned(attr));
desc.AddSrcTensor("src_tensor", definition.src_tensors[0]); desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]); desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
desc.uniform_buffers = { desc.uniform_buffers = {
{"constant uniforms& params", {"constant uniforms& params",
[](const std::vector<BHWC>& src_shapes, [attr](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) { const std::vector<BHWC>& dst_shapes) {
int4 offset = GetOffset(attr, src_shapes[0].w, src_shapes[0].h,
src_shapes[0].c, src_shapes[0].b);
std::vector<int> uniform_params{ std::vector<int> uniform_params{
// int4 src_size // int4 offset
src_shapes[0].w, offset.x,
src_shapes[0].h, offset.y,
src_shapes[0].c, offset.z,
DivideRoundUp(src_shapes[0].c, 4), offset.w,
// int4 dst_size // int4 stride
dst_shapes[0].w, attr.strides.w,
dst_shapes[0].h, attr.strides.h,
dst_shapes[0].c, attr.strides.c,
DivideRoundUp(dst_shapes[0].c, 4), attr.strides.b,
}; };
return GetByteBuffer(uniform_params); return GetByteBuffer(uniform_params);
}}, }},
@ -159,10 +182,10 @@ ComputeTaskDescriptor Slice(const OperationDef& definition,
desc.resize_function = [attr](const std::vector<BHWC>& src_shapes, desc.resize_function = [attr](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) { const std::vector<BHWC>& dst_shapes) {
const uint3 groups_size{16, 16, 1}; const uint3 groups_size{8, 4, 1};
const int dst_layers = DivideRoundUp(dst_shapes[0].c, 4);
int groups_x = DivideRoundUp(dst_shapes[0].w, groups_size.x); int groups_x = DivideRoundUp(dst_shapes[0].w, groups_size.x);
int groups_y = DivideRoundUp(dst_shapes[0].h, groups_size.y); int groups_y = DivideRoundUp(dst_shapes[0].h, groups_size.y);
const int dst_layers = DivideRoundUp(dst_shapes[0].c, 4);
int groups_z = DivideRoundUp(dst_layers, groups_size.z); int groups_z = DivideRoundUp(dst_layers, groups_size.z);
return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
}; };
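
In the rewritten slice kernel above, GetOffset folds the start/end handling for every axis into a single host-computed offset, and the shader then addresses the source as follows (for positive channel strides; s_ch is the absolute source channel):

    \[ \text{aligned case } (\text{strides.c} = 1,\ \text{starts.c} \bmod 4 = 0):\quad s_z = Z + \text{starts.c}/4 \]
    \[ \text{general case, component } i \text{ of output slice } Z:\quad s_{ch} = (4Z + i)\cdot\text{stride}_z + \text{offset}_z,\quad s_z = \min(\lfloor s_{ch}/4 \rfloor,\ \text{Slices}-1),\quad \text{lane} = s_{ch} \bmod 4 \]
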


@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/task/util.h"
#include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h" #include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
@ -40,7 +41,6 @@ std::string GetSoftmax1x1Code(const GpuInfo& gpu_info) {
using namespace metal; using namespace metal;
struct uniforms { struct uniforms {
int4 size;
float4 mask; float4 mask;
}; };
@ -51,11 +51,11 @@ kernel void ComputeFunction($1
uint3 ugid[[thread_position_in_grid]]) uint3 ugid[[thread_position_in_grid]])
{ {
float4 maxx4 = float4(src_tensor[0].x); float4 maxx4 = float4(args.src_tensor.Read(0, 0, 0).x);
for (int s = int(tid); s < params.size.x; s += 32) { for (int s = int(tid); s < args.src_tensor.Slices(); s += 32) {
float4 mask_a = s == params.size.x - 1 ? params.mask : float4(1.0f); float4 mask_a = s == args.src_tensor.Slices() - 1 ? params.mask : float4(1.0f);
float4 mask_b = float4(1.0f) - mask_a; float4 mask_b = float4(1.0f) - mask_a;
float4 src = float4(src_tensor[s]); float4 src = float4(args.src_tensor.Read(0, 0, s));
src = src * mask_a + mask_b * src.x; src = src * mask_a + mask_b * src.x;
maxx4 = max(maxx4, src); maxx4 = max(maxx4, src);
} }
@ -89,9 +89,9 @@ kernel void ComputeFunction($1
maximum = tmpx1[0]; maximum = tmpx1[0];
float sum = 0.0f; float sum = 0.0f;
for (int s = int(tid); s < params.size.x; s += 32) { for (int s = int(tid); s < args.src_tensor.Slices(); s += 32) {
float4 mask_temp = s == params.size.x - 1 ? params.mask : float4(1.0f); float4 mask_temp = s == args.src_tensor.Slices() - 1 ? params.mask : float4(1.0f);
float4 src = float4(src_tensor[s]) - float4(maximum); float4 src = float4(args.src_tensor.Read(0, 0, s)) - float4(maximum);
sum += dot(mask_temp, exp(src)); sum += dot(mask_temp, exp(src));
} }
@ -120,13 +120,13 @@ kernel void ComputeFunction($1
sum = tmpx1[0]; sum = tmpx1[0];
int dst_s = int(ugid.x); int dst_s = int(ugid.x);
if (dst_s < params.size.x) { if (dst_s < args.src_tensor.Slices()) {
int linear_index = dst_s; float4 src = float4(args.src_tensor.Read(0, 0, dst_s)) - float4(maximum);
float4 src = float4(src_tensor[linear_index]) - float4(maximum);
FLT4 value = FLT4(exp(src) * sum); FLT4 value = FLT4(exp(src) * sum);
uint3 gid = uint3(0, 0, linear_index); uint3 gid = uint3(0, 0, dst_s);
args.dst_tensor.GetAddress(linear_index, 0, 0, dst_s);
$2 $2
dst_tensor[linear_index] = value; args.dst_tensor.Write(value, 0, 0, dst_s);
} }
})"; })";
return code; return code;
@ -135,28 +135,27 @@ kernel void ComputeFunction($1
ComputeTaskDescriptor Softmax(const OperationDef& definition) { ComputeTaskDescriptor Softmax(const OperationDef& definition) {
ComputeTaskDescriptor desc(definition); ComputeTaskDescriptor desc(definition);
desc.tensors_as_args = true;
desc.shader_source = R"( desc.shader_source = R"(
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
struct uniforms { struct uniforms {
int4 size;
float4 mask; float4 mask;
}; };
$0 $0
kernel void ComputeFunction( kernel void ComputeFunction(
$1 $1
uint3 gid[[thread_position_in_grid]]) { uint3 gid[[thread_position_in_grid]]) {
if (int(gid.x) >= params.size.x || int(gid.y) >= params.size.y) { if (int(gid.x) >= args.dst_tensor.Width() || int(gid.y) >= args.dst_tensor.Height()) {
return; return;
} }
float maximum = src_tensor[gid.y * params.size.x + gid.x].x; float maximum = args.src_tensor.Read(gid.x, gid.y, 0).x;
for (int d = 0; d < params.size.z; ++d) { for (int d = 0; d < args.dst_tensor.Slices(); ++d) {
int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x; float4 mask_a = d == args.dst_tensor.Slices() - 1 ? params.mask : float4(1.0f);
float4 mask_a = d == params.size.z - 1 ? params.mask : float4(1.0f);
float4 mask_b = float4(1.0f) - mask_a; float4 mask_b = float4(1.0f) - mask_a;
float4 src = float4(src_tensor[buffer_index]); float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d));
src = src * mask_a + mask_b * src.x; src = src * mask_a + mask_b * src.x;
maximum = max(maximum, src.x); maximum = max(maximum, src.x);
maximum = max(maximum, src.y); maximum = max(maximum, src.y);
@ -165,19 +164,18 @@ kernel void ComputeFunction(
} }
float sum = 0.0f; float sum = 0.0f;
for (int d = 0; d < params.size.z; ++d) { for (int d = 0; d < args.dst_tensor.Slices(); ++d) {
int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x; float4 mask_temp = d == args.dst_tensor.Slices() - 1 ? params.mask : float4(1.0f);
float4 mask_temp = d == params.size.z - 1 ? params.mask : float4(1.0f); float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d)) - float4(maximum);
float4 src = float4(src_tensor[buffer_index]) - float4(maximum);
sum += dot(mask_temp, exp(src)); sum += dot(mask_temp, exp(src));
} }
for (int d = 0; d < params.size.z; ++d) { for (int d = 0; d < args.dst_tensor.Slices(); ++d) {
const int linear_index = (d * params.size.y + gid.y) * params.size.x + gid.x; float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d)) - float4(maximum);
float4 src = float4(src_tensor[linear_index]) - float4(maximum);
FLT4 value = FLT4(exp(src) / sum); FLT4 value = FLT4(exp(src) / sum);
args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, d);
$2 $2
dst_tensor[linear_index] = value; args.dst_tensor.Write(value, gid.x, gid.y, d);
} }
} }
)"; )";
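
Both softmax shaders above keep the usual max-subtraction trick for numerical stability; per channel the value written out is

    \[ \text{softmax}(x_i) = \frac{e^{\,x_i - \max_j x_j}}{\sum_j e^{\,x_j - \max_j x_j}} \]

with the `mask` uniform (now the only uniform) keeping the padded lanes of the last slice from influencing the maximum and the sum when the channel count is not a multiple of 4.
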
@ -189,20 +187,9 @@ kernel void ComputeFunction(
{"constant uniforms& params", {"constant uniforms& params",
[](const std::vector<BHWC>& src_shapes, [](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) { const std::vector<BHWC>& dst_shapes) {
const int dst_depth = DivideRoundUp(dst_shapes[0].c, 4); float4 mask = GetMaskForLastPlane(dst_shapes[0].c);
struct uniforms { const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&mask);
int4 size; return std::vector<uint8_t>(ptr, ptr + sizeof(float4));
float4 mask;
};
uniforms params;
params.size = {dst_shapes[0].w, dst_shapes[0].h, dst_depth, 1};
params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
int reminder = dst_shapes[0].c % 4 == 0 ? 4 : dst_shapes[0].c % 4;
for (int i = 0; i < reminder; ++i) {
params.mask[i] = 1.0f;
}
const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&params);
return std::vector<uint8_t>(ptr, ptr + sizeof(uniforms));
}}, }},
}; };
@ -220,6 +207,7 @@ kernel void ComputeFunction(
ComputeTaskDescriptor Softmax1x1(const OperationDef& definition, ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
const GpuInfo& gpu_info) { const GpuInfo& gpu_info) {
ComputeTaskDescriptor desc(definition); ComputeTaskDescriptor desc(definition);
desc.tensors_as_args = true;
desc.shader_source = GetSoftmax1x1Code(gpu_info); desc.shader_source = GetSoftmax1x1Code(gpu_info);
desc.AddSrcTensor("src_tensor", definition.src_tensors[0]); desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
@ -229,20 +217,9 @@ ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
{"constant uniforms& params", {"constant uniforms& params",
[](const std::vector<BHWC>& src_shapes, [](const std::vector<BHWC>& src_shapes,
const std::vector<BHWC>& dst_shapes) { const std::vector<BHWC>& dst_shapes) {
const int src_depth = DivideRoundUp(dst_shapes[0].c, 4); float4 mask = GetMaskForLastPlane(dst_shapes[0].c);
struct uniforms { const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&mask);
int4 size; return std::vector<uint8_t>(ptr, ptr + sizeof(float4));
float4 mask;
};
uniforms params;
params.size = {src_depth, DivideRoundUp(src_depth, 32), 1, 1};
params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
int reminder = dst_shapes[0].c % 4 == 0 ? 4 : dst_shapes[0].c % 4;
for (int i = 0; i < reminder; ++i) {
params.mask[i] = 1.0f;
}
const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&params);
return std::vector<uint8_t>(ptr, ptr + sizeof(uniforms));
}}, }},
}; };


@ -124,6 +124,8 @@ absl::Status MetalSpatialTensor::GetGPUResources(
if (!tensor_desc) { if (!tensor_desc) {
return absl::InvalidArgumentError("Expected TensorDescriptor on input."); return absl::InvalidArgumentError("Expected TensorDescriptor on input.");
} }
resources->ints.push_back(
{"slice_stride", tensor_desc->GetSliceStrideSize(shape_)});
if (descriptor_.HasAxis(Axis::WIDTH)) { if (descriptor_.HasAxis(Axis::WIDTH)) {
resources->ints.push_back({"width", Width()}); resources->ints.push_back({"width", Width()});
resources->ints.push_back({"width_div2", Width() / 2}); resources->ints.push_back({"width_div2", Width() / 2});


@ -310,7 +310,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Downloading data from https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media\u0026token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8\n", "Downloading data from https://dl.fbaipublicfiles.com/glue/data/SST-2.zip\n",
"7446528/7439277 [==============================] - 0s 0us/step\n" "7446528/7439277 [==============================] - 0s 0us/step\n"
] ]
} }
@ -318,7 +318,7 @@
"source": [ "source": [
"data_dir = tf.keras.utils.get_file(\n", "data_dir = tf.keras.utils.get_file(\n",
" fname='SST-2.zip',\n", " fname='SST-2.zip',\n",
" origin='https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media\u0026token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',\n", " origin='https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',\n",
" extract=True)\n", " extract=True)\n",
"data_dir = os.path.join(os.path.dirname(data_dir), 'SST-2')" "data_dir = os.path.join(os.path.dirname(data_dir), 'SST-2')"
] ]


@ -587,11 +587,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
status = kTfLiteError; status = kTfLiteError;
} }
size_t dims_signature_rank = 0; std::vector<int> dims_signature = {};
const int* dims_signature_data = nullptr;
if (tensor->shape_signature()) { if (tensor->shape_signature()) {
dims_signature_rank = tensor->shape_signature()->size(); dims_signature = FlatBufferIntArrayToVector(tensor->shape_signature());
dims_signature_data = tensor->shape_signature()->data();
} }
bool is_variable = tensor->is_variable(); bool is_variable = tensor->is_variable();
@ -623,7 +621,7 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
} else { } else {
if (subgraph->SetTensorParametersReadWrite( if (subgraph->SetTensorParametersReadWrite(
i, type, get_name(tensor), dims, quantization, is_variable, i, type, get_name(tensor), dims, quantization, is_variable,
dims_signature_rank, dims_signature_data) != kTfLiteOk) { dims_signature) != kTfLiteOk) {
error_reporter_->Report("Tensor %d is invalidly specified in schema.\n", error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
i); i);
status = kTfLiteError; status = kTfLiteError;


@ -80,6 +80,9 @@ TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
case kTfLiteInt32: case kTfLiteInt32:
copyCast(in, out->data.i32, num_elements); copyCast(in, out->data.i32, num_elements);
break; break;
case kTfLiteInt16:
copyCast(in, out->data.i16, num_elements);
break;
case kTfLiteUInt8: case kTfLiteUInt8:
copyCast(in, out->data.uint8, num_elements); copyCast(in, out->data.uint8, num_elements);
break; break;
@ -113,6 +116,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return copyToTensor(context, input->data.i64, output, num_elements); return copyToTensor(context, input->data.i64, output, num_elements);
case kTfLiteInt32: case kTfLiteInt32:
return copyToTensor(context, input->data.i32, output, num_elements); return copyToTensor(context, input->data.i32, output, num_elements);
case kTfLiteInt16:
return copyToTensor(context, input->data.i16, output, num_elements);
case kTfLiteUInt8: case kTfLiteUInt8:
return copyToTensor(context, input->data.uint8, output, num_elements); return copyToTensor(context, input->data.uint8, output, num_elements);
case kTfLiteFloat32: case kTfLiteFloat32:


@ -46,6 +46,22 @@ class CastOpModel : public SingleOpModel {
int output_; int output_;
}; };
TEST(CastOpModel, CastInt16ToFloat) {
CastOpModel m({TensorType_INT16, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
m.PopulateTensor<int16_t>(m.input(), {100, 200, 300, 400, 500, 600});
m.Invoke();
EXPECT_THAT(m.ExtractVector<float>(m.output()),
ElementsAreArray({100.f, 200.f, 300.f, 400.f, 500.f, 600.f}));
}
TEST(CastOpModel, CastInt16ToInt32) {
CastOpModel m({TensorType_INT16, {2, 3}}, {TensorType_INT32, {2, 3}});
m.PopulateTensor<int16_t>(m.input(), {100, 200, 300, 400, 500, 600});
m.Invoke();
EXPECT_THAT(m.ExtractVector<int32_t>(m.output()),
ElementsAreArray({100, 200, 300, 400, 500, 600}));
}
TEST(CastOpModel, CastInt32ToFloat) { TEST(CastOpModel, CastInt32ToFloat) {
CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_FLOAT32, {2, 3}}); CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
m.PopulateTensor<int32_t>(m.input(), {100, 200, 300, 400, 500, 600}); m.PopulateTensor<int32_t>(m.input(), {100, 200, 300, 400, 500, 600});
@ -62,6 +78,14 @@ TEST(CastOpModel, CastFloatToInt32) {
ElementsAreArray({100, 20, 3, 0, 0, 1})); ElementsAreArray({100, 20, 3, 0, 0, 1}));
} }
TEST(CastOpModel, CastFloatToInt16) {
CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT16, {3, 2}});
m.PopulateTensor<float>(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
m.Invoke();
EXPECT_THAT(m.ExtractVector<int16_t>(m.output()),
ElementsAreArray({100, 20, 3, 0, 0, 1}));
}
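
The expectations in CastFloatToInt16 above mirror the existing CastFloatToInt32 test: the element-wise cast converts float to integer the way a C++ static_cast does, truncating toward zero rather than rounding, so

    \[ 0.4 \mapsto 0, \qquad 0.999 \mapsto 0, \qquad 1.1 \mapsto 1 \]
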
TEST(CastOpModel, CastInt64ToFloat) { TEST(CastOpModel, CastInt64ToFloat) {
CastOpModel m({TensorType_INT64, {2, 3}}, {TensorType_FLOAT32, {2, 3}}); CastOpModel m({TensorType_INT64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
m.PopulateTensor<int64_t>(m.input(), {100, 200, 300, 400, 500, 600}); m.PopulateTensor<int64_t>(m.input(), {100, 200, 300, 400, 500, 600});


@ -121,7 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should // TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should
// allocate and populate these during Prepare(). // allocate and populate these during Prepare().
// TODO(ycling): Activation function parameter is ignored. For now we dont have // TODO(ycling): Activation function parameter is ignored. For now we don't have
// a model with a Concatenation with fused activation function. // a model with a Concatenation with fused activation function.
#define TF_LITE_CONCATENATION(scalar) \ #define TF_LITE_CONCATENATION(scalar) \
{ \ { \


@ -494,6 +494,7 @@ cc_library(
"reference/resize_nearest_neighbor.h", "reference/resize_nearest_neighbor.h",
"reference/round.h", "reference/round.h",
"reference/softmax.h", "reference/softmax.h",
"reference/space_to_depth.h",
"reference/strided_slice.h", "reference/strided_slice.h",
"reference/sub.h", "reference/sub.h",
"reference/svdf.h", "reference/svdf.h",
@ -511,13 +512,14 @@ cc_library(
}), }),
compatible_with = get_compatible_with_portable(), compatible_with = get_compatible_with_portable(),
copts = tflite_copts(), copts = tflite_copts(),
# We are disabling parse_headers for the tf_lite_static_memory build to # We are disabling parse_headers for this header-only target so that the
# allow it to be consistent with the OSS bazel build. See b/175817116 # external and internal builds are consistent. The primary issue here is
# for more details. # that parse_headers is not supported with bazel and the TFLM team would
features = select({ # really like to have all build errors in shared Micro/Lite code be
":tf_lite_static_memory": ["-parse_headers"], # reproducible from the OSS build as well.
"//conditions:default": [], #
}), # See b/175817116 for more details.
features = ["-parse_headers"],
deps = [ deps = [
":common", ":common",
":compatibility", ":compatibility",
@ -588,6 +590,7 @@ cc_library(
"reference/resize_nearest_neighbor.h", "reference/resize_nearest_neighbor.h",
"reference/round.h", "reference/round.h",
"reference/softmax.h", "reference/softmax.h",
"reference/space_to_depth.h",
"reference/strided_slice.h", "reference/strided_slice.h",
"reference/string_comparisons.h", "reference/string_comparisons.h",
"reference/sub.h", "reference/sub.h",


@ -15,16 +15,13 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV_H_ #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV_H_
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite { namespace tflite {
namespace reference_ops { namespace reference_ops {
inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& filter_shape, const float* input_data, const RuntimeShape& filter_shape,
const float* filter_data, const RuntimeShape& bias_shape, const float* filter_data, const RuntimeShape& bias_shape,
@ -108,8 +105,8 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
uint8_t* output_data, const RuntimeShape& im2col_shape, uint8_t* output_data, const RuntimeShape& im2col_shape,
uint8_t* im2col_data, void* cpu_backend_context) { uint8_t* im2col_data, void* cpu_backend_context) {
(void)cpu_backend_context; // only used in optimized code. (void)cpu_backend_context; // only used in optimized code.
(void)im2col_data; // only used in optimized code. (void)im2col_data; // only used in optimized code.
(void)im2col_shape; // only used in optimized code. (void)im2col_shape; // only used in optimized code.
const int stride_width = params.stride_width; const int stride_width = params.stride_width;
const int stride_height = params.stride_height; const int stride_height = params.stride_height;
const int dilation_width_factor = params.dilation_width_factor; const int dilation_width_factor = params.dilation_width_factor;


@ -61,6 +61,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h" #include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
#include "tensorflow/lite/kernels/internal/reference/round.h" #include "tensorflow/lite/kernels/internal/reference/round.h"
#include "tensorflow/lite/kernels/internal/reference/softmax.h" #include "tensorflow/lite/kernels/internal/reference/softmax.h"
#include "tensorflow/lite/kernels/internal/reference/space_to_depth.h"
#include "tensorflow/lite/kernels/internal/reference/strided_slice.h" #include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
#include "tensorflow/lite/kernels/internal/reference/string_comparisons.h" #include "tensorflow/lite/kernels/internal/reference/string_comparisons.h"
#include "tensorflow/lite/kernels/internal/reference/sub.h" #include "tensorflow/lite/kernels/internal/reference/sub.h"
@ -126,58 +127,6 @@ inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
} }
} }
template <typename T>
inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
const RuntimeShape& unextended_input_shape,
const T* input_data,
const RuntimeShape& unextended_output_shape,
T* output_data) {
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int input_depth = input_shape.Dims(3);
const int input_width = input_shape.Dims(2);
const int input_height = input_shape.Dims(1);
const int input_batch = input_shape.Dims(0);
const int output_depth = output_shape.Dims(3);
const int output_width = output_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_batch = output_shape.Dims(0);
const int32 block_size = op_params.block_size;
TFLITE_DCHECK_EQ(input_width, output_width * block_size);
TFLITE_DCHECK_EQ(input_height, output_height * block_size);
TFLITE_DCHECK_EQ(input_depth * block_size * block_size, output_depth);
TFLITE_DCHECK_EQ(input_batch, output_batch);
for (int in_b = 0; in_b < input_batch; ++in_b) {
for (int in_h = 0; in_h < input_height; ++in_h) {
for (int in_w = 0; in_w < input_width; ++in_w) {
for (int in_d = 0; in_d < input_depth; ++in_d) {
const int out_d =
in_d + ((in_h % block_size) * block_size + in_w % block_size) *
input_depth;
const int out_w = in_w / block_size;
const int out_h = in_h / block_size;
const int out_b = in_b;
const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d);
const int output_index =
Offset(output_shape, out_b, out_h, out_w, out_d);
output_data[output_index] = input_data[input_index];
}
}
}
}
}
inline void Elu(const RuntimeShape& input_shape, const float* input_data, inline void Elu(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) { const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape); const int flat_size = MatchingFlatSize(input_shape, output_shape);


@ -0,0 +1,78 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPACE_TO_DEPTH_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPACE_TO_DEPTH_H_
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {
template <typename T>
inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
const RuntimeShape& unextended_input_shape,
const T* input_data,
const RuntimeShape& unextended_output_shape,
T* output_data) {
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int input_depth = input_shape.Dims(3);
const int input_width = input_shape.Dims(2);
const int input_height = input_shape.Dims(1);
const int input_batch = input_shape.Dims(0);
const int output_depth = output_shape.Dims(3);
const int output_width = output_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_batch = output_shape.Dims(0);
const int32 block_size = op_params.block_size;
TFLITE_DCHECK_EQ(input_width, output_width * block_size);
TFLITE_DCHECK_EQ(input_height, output_height * block_size);
TFLITE_DCHECK_EQ(input_depth * block_size * block_size, output_depth);
TFLITE_DCHECK_EQ(input_batch, output_batch);
for (int in_b = 0; in_b < input_batch; ++in_b) {
for (int in_h = 0; in_h < input_height; ++in_h) {
for (int in_w = 0; in_w < input_width; ++in_w) {
for (int in_d = 0; in_d < input_depth; ++in_d) {
const int out_d =
in_d + ((in_h % block_size) * block_size + in_w % block_size) *
input_depth;
const int out_w = in_w / block_size;
const int out_h = in_h / block_size;
const int out_b = in_b;
const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d);
const int output_index =
Offset(output_shape, out_b, out_h, out_w, out_d);
output_data[output_index] = input_data[input_index];
}
}
}
}
}
} // namespace reference_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPACE_TO_DEPTH_H_
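
A small worked example of the index mapping in the relocated SpaceToDepth reference (shape values assumed for illustration): with block_size = 2 and an input of shape 1x4x4x1, the output shape is 1x2x2x4, and input element (in_h, in_w, in_d) = (1, 3, 0) lands at

    \[ out_d = in_d + ((in_h \bmod 2)\cdot 2 + in_w \bmod 2)\cdot\text{input\_depth} = 0 + (2 + 1)\cdot 1 = 3, \qquad out_h = \lfloor 1/2 \rfloor = 0, \qquad out_w = \lfloor 3/2 \rfloor = 1 \]
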


@ -170,7 +170,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
intermediate_zp.push_back(0); intermediate_zp.push_back(0);
} }
} }
// In the absense of projection, hidden becomes output and this intermediate // In the absence of projection, hidden becomes output and this intermediate
// is ignored. // is ignored.
TfLiteTensor* hidden; TfLiteTensor* hidden;
TF_LITE_ENSURE_OK(context, GetIntermediatesSafe(context, node, 4, &hidden)); TF_LITE_ENSURE_OK(context, GetIntermediatesSafe(context, node, 4, &hidden));


@ -204,7 +204,7 @@ to determine if the requested feature aligns with the TFLM roadmap.
1. Run all the tests for x86, and any other platform that you are modifying. 1. Run all the tests for x86, and any other platform that you are modifying.
``` ```
tensorflow/lite/micro/tools/make/tools/ci_build/test_x86.sh tensorflow/lite/micro/tools/ci_build/test_x86.sh
``` ```
Please check the READMEs in the optimized kernel directories for specific Please check the READMEs in the optimized kernel directories for specific


@ -1,2 +1,2 @@
numpy==1.16.2 numpy==1.16.2
tensorflow==2.0.0-beta1 tensorflow==2.4.0


@ -33,7 +33,7 @@ set +e
# The pigweed scripts only work from a git repository and the Tensorflow CI # The pigweed scripts only work from a git repository and the Tensorflow CI
# infrastructure does not always guarantee that. As an ugly workaround, we # infrastructure does not always guarantee that. As an ugly workaround, we
# create our own git repo when running on the CI servers. # create our own git repo when running on the CI servers.
pushd tensorflow/lite/micro/ pushd tensorflow/lite/
if [[ ${1} == "PRESUBMIT" ]]; then if [[ ${1} == "PRESUBMIT" ]]; then
git init . git init .
git config user.email "tflm@google.com" git config user.email "tflm@google.com"
@ -43,9 +43,12 @@ if [[ ${1} == "PRESUBMIT" ]]; then
fi fi
# Check for license with the necessary exclusions. # Check for license with the necessary exclusions.
tools/make/downloads/pigweed/pw_presubmit/py/pw_presubmit/pigweed_presubmit.py \ micro/tools/make/downloads/pigweed/pw_presubmit/py/pw_presubmit/pigweed_presubmit.py \
. \ kernels/internal/reference/ \
micro/ \
-p copyright_notice \ -p copyright_notice \
-e kernels/internal/reference/integer_ops/ \
-e kernels/internal/reference/reference_ops.h \
-e tools/make/downloads \ -e tools/make/downloads \
-e tools/make/targets/ecm3531 \ -e tools/make/targets/ecm3531 \
-e BUILD\ -e BUILD\
@ -66,8 +69,11 @@ LICENSE_CHECK_RESULT=$?
# Python files (with yapf as the formatter) because that needs additional setup. # Python files (with yapf as the formatter) because that needs additional setup.
# We are also ignoring the markdown files to allow for a more gradual rollout of # We are also ignoring the markdown files to allow for a more gradual rollout of
# this presubmit check. # this presubmit check.
tools/make/downloads/pigweed/pw_presubmit/py/pw_presubmit/format_code.py \ micro/tools/make/downloads/pigweed/pw_presubmit/py/pw_presubmit/format_code.py \
. \ kernels/internal/reference/ \
micro/ \
-e kernels/internal/reference/integer_ops/ \
-e kernels/internal/reference/reference_ops.h \
-e "\.inc" \ -e "\.inc" \
-e "\.md" \ -e "\.md" \
-e "\.py" -e "\.py"
@ -76,7 +82,7 @@ CLANG_FORMAT_RESULT=$?
popd popd
if [[ ${1} == "PRESUBMIT" ]]; then if [[ ${1} == "PRESUBMIT" ]]; then
rm -rf tensorflow/lite/micro/.git rm -rf tensorflow/lite/.git
fi fi
# Re-enable exit on error now that we are done with the temporary git repo. # Re-enable exit on error now that we are done with the temporary git repo.


@ -116,7 +116,7 @@ download_and_extract() {
local tempdir=$(mktemp -d) local tempdir=$(mktemp -d)
local tempdir2=$(mktemp -d) local tempdir2=$(mktemp -d)
local tempfile=${tempdir}/temp_file local tempfile=${tempdir}/temp_file
local curl_retries=3 local curl_retries=5
# Destination already downloaded. # Destination already downloaded.
if [ -d ${dir} ]; then if [ -d ${dir} ]; then
@ -131,24 +131,21 @@ download_and_extract() {
mkdir -p "${dir}" mkdir -p "${dir}"
# We've been seeing occasional 56 errors from valid URLs, so set up a retry # We've been seeing occasional 56 errors from valid URLs, so set up a retry
# loop to attempt to recover from them. # loop to attempt to recover from them.
for (( i=1; i<=$curl_retries; ++i )) for (( i=1; i<=$curl_retries; ++i )); do
do
# We have to use this approach because we normally halt the script when # We have to use this approach because we normally halt the script when
# there's an error, and instead we want to catch errors so we can retry. # there's an error, and instead we want to catch errors so we can retry.
set +e set +ex
curl -Ls --fail --retry 5 "${url}" > ${tempfile} curl -LsS --fail --retry 5 "${url}" > ${tempfile}
CURL_RESULT=$? CURL_RESULT=$?
set -e set -ex
# Was the command successful? If so, continue. # Was the command successful? If so, continue.
if [[ $CURL_RESULT -eq 0 ]] if [[ $CURL_RESULT -eq 0 ]]; then
then
break break
fi fi
# Keep trying if we see the '56' error code. # Keep trying if we see the '56' error code.
if [[ ( $CURL_RESULT -ne 56 ) || ( $i -eq $curl_retries ) ]] if [[ ( $CURL_RESULT -ne 56 ) || ( $i -eq $curl_retries ) ]]; then
then
echo "Error $CURL_RESULT downloading '${url}'" echo "Error $CURL_RESULT downloading '${url}'"
exit 1 exit 1
fi fi


@ -110,7 +110,7 @@ class OpsSet(enum.Enum):
"EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8" "EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8"
def __str__(self): def __str__(self):
return self.value return str(self.value)
@staticmethod @staticmethod
def get_options(): def get_options():
@ -394,22 +394,22 @@ def build_toco_convert_protos(input_tensors,
input_tensors: List of input tensors. Type and shape are computed using input_tensors: List of input tensors. Type and shape are computed using
`foo.shape` and `foo.dtype`. `foo.shape` and `foo.dtype`.
output_tensors: List of output tensors (only .name is used from this). output_tensors: List of output tensors (only .name is used from this).
inference_type: Target data type of real-number arrays in the output file. inference_type: Data type of numeric arrays, excluding the input layer.
Must be `{tf.float32, tf.uint8, tf.int8}`. (default tf.float32) (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
inference_input_type: Target data type of real-number input arrays. Allows inference_input_type: Data type of the numeric arrays in the input layer. If
for a different type for input arrays in the case of quantization. Must be `inference_input_type` is in {tf.int8, tf.uint8}, then
`{tf.float32, tf.uint8, tf.int8}`. (default `inference_type`) `quantized_input_stats` must be provided. (default is the value assigned
input_format: Type of data to read Currently must be to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8})
`{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF) input_format: Type of data to read.
input_shapes: Input array shape. It needs to be a list of the same length as (default TENSORFLOW_GRAPHDEF, must be in {TENSORFLOW_GRAPHDEF})
`input_tensors`, or None. (default None) input_shapes: Input array shape. (default None, must be None or a list of
output_format: Output file format. Currently must be `{TFLITE, the same length as `input_tensors`.)
GRAPHVIZ_DOT}`. (default TFLITE) output_format: Output file format. (default TFLITE, must be in
quantized_input_stats: List of tuples of floats representing the mean and {TFLITE, GRAPHVIZ_DOT})
standard deviation. Each tuple maps to the corresponding input tensor. quantized_input_stats: Map of input tensor names to a tuple of floats
Only need if `inference_input_type` is `QUANTIZED_UINT8` or `INT8`. representing the mean and standard deviation of the training data.
real_input_value = (quantized_input_value - mean_value) / std_dev_value. (e.g., {"foo" : (0., 1.)}). Required if `inference_input_type` is tf.int8
(default None) or tf.uint8. (default None)
default_ranges_stats: Tuple of integers representing (min, max) range values default_ranges_stats: Tuple of integers representing (min, max) range values
for all arrays without a specified range. Intended for experimenting with for all arrays without a specified range. Intended for experimenting with
quantization via "dummy quantization". (default None) quantization via "dummy quantization". (default None)
@ -574,8 +574,10 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
if _requires_input_stats(toco_flags): if _requires_input_stats(toco_flags):
if (("quantized_input_stats" not in kwargs) or if (("quantized_input_stats" not in kwargs) or
(not kwargs["quantized_input_stats"])): (not kwargs["quantized_input_stats"])):
raise ValueError("std_dev and mean must be defined when inference_type " raise ValueError(
"or inference_input_type is QUANTIZED_UINT8 or INT8.") "The `quantized_input_stats` flag must be defined when either "
"`inference_type` flag or `inference_input_type` flag is set to "
"tf.int8 or tf.uint8.")
input_array.mean_value, input_array.std_value = kwargs[ input_array.mean_value, input_array.std_value = kwargs[
"quantized_input_stats"][idx] "quantized_input_stats"][idx]
input_array.name = name input_array.name = name
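
A minimal sketch of how these flags fit together in practice, using the TF1-style converter; the graph file, array names, and stats values below are placeholders, not part of this change:

    import tensorflow as tf

    # Hypothetical frozen graph with a single input array named "input".
    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        "frozen_graph.pb", input_arrays=["input"], output_arrays=["output"])
    converter.inference_type = tf.uint8
    # Required whenever inference_type or inference_input_type is tf.int8/tf.uint8:
    # maps each input array name to (mean, std_dev) of the training data.
    converter.quantized_input_stats = {"input": (0.0, 1.0)}
    tflite_model = converter.convert()
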
@ -661,7 +663,7 @@ def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
Typically this function is used to convert from TensorFlow GraphDef to TFLite. Typically this function is used to convert from TensorFlow GraphDef to TFLite.
Conversion can be customized by providing arguments that are forwarded to Conversion can be customized by providing arguments that are forwarded to
`build_toco_convert_protos` (see documentation for details). This function has `build_toco_convert_protos` (see documentation for details). This function has
been deprecated. Please use `lite.TFLiteConverter` instead. been deprecated. Please use `tf.lite.TFLiteConverter` instead.
Args: Args:
input_data: Input data (i.e. often `sess.graph_def`), input_data: Input data (i.e. often `sess.graph_def`),


@ -137,7 +137,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
self.assertEqual("output", output_details[0]["name"]) self.assertEqual("output", output_details[0]["name"])
self.assertEqual(np.uint8, output_details[0]["dtype"]) self.assertEqual(np.uint8, output_details[0]["dtype"])
self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all()) self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all())
self.assertTrue(output_details[0]["quantization"][0] > 0) # scale self.assertGreater(output_details[0]["quantization"][0], 0) # scale
def testGraphDefQuantizationInvalid(self): def testGraphDefQuantizationInvalid(self):
with ops.Graph().as_default(): with ops.Graph().as_default():
@ -159,9 +159,9 @@ class ConvertTest(test_util.TensorFlowTestCase):
enable_mlir_converter=False, enable_mlir_converter=False,
inference_type=dtypes.uint8) inference_type=dtypes.uint8)
self.assertEqual( self.assertEqual(
"std_dev and mean must be defined when inference_type or " "The `quantized_input_stats` flag must be defined when either "
"inference_input_type is QUANTIZED_UINT8 or INT8.", "`inference_type` flag or `inference_input_type` flag is set to "
str(error.exception)) "tf.int8 or tf.uint8.", str(error.exception))
class ConvertTestOpHint(test_util.TensorFlowTestCase): class ConvertTestOpHint(test_util.TensorFlowTestCase):


@ -61,6 +61,7 @@ from tensorflow.lite.python.util import get_debug_info as _get_debug_info
from tensorflow.lite.python.util import get_grappler_config as _get_grappler_config from tensorflow.lite.python.util import get_grappler_config as _get_grappler_config
from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
from tensorflow.lite.python.util import get_tf_type_name as _get_tf_type_name
from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
from tensorflow.lite.python.util import model_input_signature as _model_input_signature from tensorflow.lite.python.util import model_input_signature as _model_input_signature
from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type
@ -89,19 +90,14 @@ from tensorflow.python.util.tf_export import tf_export as _tf_export
@_tf_export("lite.Optimize") @_tf_export("lite.Optimize")
class Optimize(enum.Enum): class Optimize(enum.Enum):
"""Enum defining the optimizations to apply when generating tflite graphs. """Enum defining the optimizations to apply when generating a tflite model.
Some optimizations may come at the cost of accuracy.
DEFAULT DEFAULT
Default optimization strategy. Default optimization strategy that quantizes model weights. Enhanced
optimizations are gained by providing a representative dataset that
Converter will do its best to improve size and latency based on the quantizes biases and activations as well.
information provided. Converter will do its best to reduce size and latency, while minimizing
Enhanced optimizations are gained by providing a representative_dataset. the loss in accuracy.
This is recommended, and is currently equivalent to the modes below.
Currently, weights will be quantized and if representative_dataset is
provided, activations for quantizable operations will also be quantized.
OPTIMIZE_FOR_SIZE OPTIMIZE_FOR_SIZE
Deprecated. Does the same as DEFAULT. Deprecated. Does the same as DEFAULT.
@ -110,14 +106,11 @@ class Optimize(enum.Enum):
Deprecated. Does the same as DEFAULT. Deprecated. Does the same as DEFAULT.
""" """
# Default optimization strategy. # Default optimization strategy that quantizes model weights. Enhanced
# # optimizations are gained by providing a representative dataset that
# Converter will do its best to improve size and latency based on the # quantizes biases and activations as well.
# information provided. # Converter will do its best to reduce size and latency, while minimizing
# Enhanced optimizations can be gained by providing a representative_dataset. # the loss in accuracy.
# This is recommended, and is currently equivalent to the modes below.
# Currently, weights will be quantized and if representative_dataset is
# provided, activations for quantizable operations will also be quantized.
DEFAULT = "DEFAULT" DEFAULT = "DEFAULT"
# Deprecated. Does the same as DEFAULT. # Deprecated. Does the same as DEFAULT.
@ -132,48 +125,47 @@ class Optimize(enum.Enum):
@_tf_export("lite.RepresentativeDataset") @_tf_export("lite.RepresentativeDataset")
class RepresentativeDataset(object): class RepresentativeDataset(object):
"""Representative dataset to evaluate optimizations. """Representative dataset used to optimize the model.
A representative dataset that can be used to evaluate optimizations by the This is a generator function that provides a small dataset to calibrate or
converter. E.g. converter can use these examples to estimate (min, max) ranges estimate the range, i.e, (min, max) of all floating-point arrays in the model
by calibrating the model on inputs. This can allow converter to quantize a (such as model input, activation outputs of intermediate layers, and model
converted floating point model. output) for quantization. Usually, this is a small subset of a few hundred
samples randomly chosen, in no particular order, from the training or
evaluation dataset.
""" """
def __init__(self, input_gen): def __init__(self, input_gen):
"""Creates a representative dataset. """Creates a representative dataset.
Args: Args:
input_gen: an input generator that can be used to generate input samples input_gen: A generator function that generates input samples for the
for the model. This must be a callable object that returns an object model and has the same order, type and shape as the inputs to the model.
that supports the `iter()` protocol (e.g. a generator function). The Usually, this is a small subset of a few hundred samples randomly
elements generated must have same type and shape as inputs to the model. chosen, in no particular order, from the training or evaluation dataset.
""" """
self.input_gen = input_gen self.input_gen = input_gen
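To make the calibration flow described above concrete, here is a minimal sketch of a generator wired into a converter; `calibration_images` and the sample count are assumptions, not part of this change.

```python
import numpy as np
import tensorflow as tf

def representative_dataset():
  # Yield a few hundred samples with the same order, type and shape as the
  # model inputs; `calibration_images` is a hypothetical NumPy array.
  for sample in calibration_images[:200]:
    yield [np.expand_dims(sample, axis=0).astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/my_saved_model")
converter.optimizations = {tf.lite.Optimize.DEFAULT}
converter.representative_dataset = representative_dataset
tflite_model = converter.convert()
```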
@_tf_export("lite.TargetSpec") @_tf_export("lite.TargetSpec")
class TargetSpec(object): class TargetSpec(object):
"""Specification of target device. """Specification of target device used to optimize the model.
Details about target device. Converter optimizes the generated model for
specific device.
Attributes: Attributes:
supported_ops: Experimental flag, subject to change. Set of OpsSet options supported_ops: Experimental flag, subject to change. Set of `tf.lite.OpsSet`
supported by the device. (default set([OpsSet.TFLITE_BUILTINS])) options, where each option represents a set of operators supported by the
supported_types: List of types for constant values on the target device. target device. (default {tf.lite.OpsSet.TFLITE_BUILTINS})
Frequently, an optimization choice is driven by the most compact supported_types: Set of `tf.dtypes.DType` data types supported on the target
(i.e. smallest) type in this list (default [tf.float32]) device. If initialized, optimization might be driven by the smallest type
in this set. (default set())
experimental_select_user_tf_ops: Experimental flag, subject to change. Set experimental_select_user_tf_ops: Experimental flag, subject to change. Set
of user's TensorFlow operators' names that are required in the TensorFlow of user's TensorFlow operators' names that are required in the TensorFlow
Lite runtime. These ops will be exported as select TensorFlow ops in the Lite runtime. These ops will be exported as select TensorFlow ops in the
model (in conjunction with the OpsSet.SELECT_TF_OPS flag). This is an model (in conjunction with the tf.lite.OpsSet.SELECT_TF_OPS flag). This is
advanced feature that should only be used if the client is using TF ops an advanced feature that should only be used if the client is using TF ops
that may not be linked in by default with the TF ops that are provided that may not be linked in by default with the TF ops that are provided
when using the SELECT_TF_OPS path. The client is responsible for linking when using the SELECT_TF_OPS path. The client is responsible for linking
these ops into the target runtime. these ops into the target runtime.
""" """
def __init__(self, def __init__(self,
@ -181,17 +173,17 @@ class TargetSpec(object):
supported_types=None, supported_types=None,
experimental_select_user_tf_ops=None): experimental_select_user_tf_ops=None):
if supported_ops is None: if supported_ops is None:
supported_ops = set([OpsSet.TFLITE_BUILTINS]) supported_ops = {OpsSet.TFLITE_BUILTINS}
self.supported_ops = supported_ops self.supported_ops = supported_ops
if supported_types is None: if supported_types is None:
supported_types = [] supported_types = set()
self.supported_types = supported_types self.supported_types = supported_types
if experimental_select_user_tf_ops is None: if experimental_select_user_tf_ops is None:
self.experimental_select_user_tf_ops = [] self.experimental_select_user_tf_ops = set()
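Assuming the representative dataset sketched earlier, restricting the target spec to built-in int8 ops is what forces full-integer quantization with no float fallback; a hedged sketch:

```python
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/my_saved_model")
converter.optimizations = {tf.lite.Optimize.DEFAULT}
converter.representative_dataset = representative_dataset  # as sketched above
# Only built-in int8 ops are allowed, so no float fallback is possible.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_model = converter.convert()
```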
class QuantizationMode(object): class QuantizationMode(object):
"""QuantizationMode determines the quantized conversion from user options.""" """QuantizationMode determines the quantization type from user options."""
def __init__(self, optimizations, target_spec, representative_dataset, def __init__(self, optimizations, target_spec, representative_dataset,
graph_def): graph_def):
@ -205,7 +197,6 @@ class QuantizationMode(object):
# TODO(b/162537905): Refactor the following quantization functions - # TODO(b/162537905): Refactor the following quantization functions -
# re-organize and refactor for better readability. # re-organize and refactor for better readability.
def post_training_int8_no_float(self): def post_training_int8_no_float(self):
"""Post training int8 quantize, disallow float fallback."""
return (self._any_optimization_enabled() and return (self._any_optimization_enabled() and
self._is_int8_target_required() and self._is_int8_target_required() and
not self._is_int16x8_target_required() and not self._is_int16x8_target_required() and
@ -213,19 +204,16 @@ class QuantizationMode(object):
self._representative_dataset is not None) self._representative_dataset is not None)
def post_training_int8_allow_float(self): def post_training_int8_allow_float(self):
"""Post training int8 quantize, allow float fallback."""
return (self._any_optimization_enabled() and return (self._any_optimization_enabled() and
not self._is_int16x8_target_required() and not self._is_int16x8_target_required() and
self._representative_dataset is not None and self._representative_dataset is not None and
self._smallest_supported_type() == _dtypes.int8) self._smallest_supported_type() == _dtypes.int8)
def is_post_training_integer_quantize_8(self): def is_post_training_integer_quantize_8(self):
"""Post training integer 8 quantization."""
return (self.post_training_int8_no_float() or return (self.post_training_int8_no_float() or
self.post_training_int8_allow_float()) self.post_training_int8_allow_float())
def is_post_training_integer_quantize_16x8(self): def is_post_training_integer_quantize_16x8(self):
"""Post training integer 16x8 quantization."""
return (self.post_training_int16x8_no_float() or return (self.post_training_int16x8_no_float() or
self.post_training_int16x8_allow_float()) self.post_training_int16x8_allow_float())
@ -239,7 +227,6 @@ class QuantizationMode(object):
self.contains_training_quant_op()) self.contains_training_quant_op())
def post_training_int16x8_no_float(self): def post_training_int16x8_no_float(self):
"""Post training int16x8 quantize, disallow float fallback."""
return (self._any_optimization_enabled() and return (self._any_optimization_enabled() and
not self._is_int8_target_required() and not self._is_int8_target_required() and
self._is_int16x8_target_required() and self._is_int16x8_target_required() and
@ -247,13 +234,11 @@ class QuantizationMode(object):
self._representative_dataset is not None) self._representative_dataset is not None)
def post_training_int16x8_allow_float(self): def post_training_int16x8_allow_float(self):
"""Post training int16x8 quantize, allow float fallback."""
return (self._any_optimization_enabled() and return (self._any_optimization_enabled() and
self._is_int16x8_target_required() and self._is_int16x8_target_required() and
self._is_allow_float()) self._is_allow_float())
def post_training_dynamic_range_int8(self): def post_training_dynamic_range_int8(self):
"""Post training int8 const, on-the-fly int8 quantize of dynamic tensors."""
# Post-training dynamic range quantization is only enabled if post-training # Post-training dynamic range quantization is only enabled if post-training
# int8 quantization and training time quantization was not done. # int8 quantization and training time quantization was not done.
return (self._any_optimization_enabled() and return (self._any_optimization_enabled() and
@ -262,7 +247,6 @@ class QuantizationMode(object):
self._smallest_supported_type() == _dtypes.int8) self._smallest_supported_type() == _dtypes.int8)
def post_training_fp16(self): def post_training_fp16(self):
"""Post training fp16 quantize."""
return (self._any_optimization_enabled() and return (self._any_optimization_enabled() and
self._smallest_supported_type() == _dtypes.float16) self._smallest_supported_type() == _dtypes.float16)
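The predicates above map user flags to a quantization mode; for instance, post_training_fp16() is selected when the DEFAULT optimization is combined with float16 as the smallest supported type, roughly as in this sketch:

```python
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/my_saved_model")
converter.optimizations = {tf.lite.Optimize.DEFAULT}
# float16 becomes the smallest supported type, so post_training_fp16() applies.
converter.target_spec.supported_types = [tf.float16]
tflite_model = converter.convert()
```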
@ -416,21 +400,20 @@ class TFLiteConverterBase(object):
"""Converter subclass to share functionality between V1 and V2 converters.""" """Converter subclass to share functionality between V1 and V2 converters."""
def __init__(self): def __init__(self):
self.allow_custom_ops = False self.optimizations = set()
self.target_spec = TargetSpec()
self.optimizations = []
self.representative_dataset = None self.representative_dataset = None
self.target_spec = TargetSpec()
self.allow_custom_ops = False
self.experimental_new_converter = True self.experimental_new_converter = True
self._experimental_new_quantizer = False self._experimental_new_quantizer = False
self._experimental_calibrate_only = False self._experimental_calibrate_only = False
# The 'GraphDebugInfo' contains the stack traces of all the original nodes self._experimental_sparsify_model = False
# in the `GraphDef` to the converter. self._debug_info = None # contains the stack traces of all the original
self._debug_info = None # nodes in the `GraphDef` to the converter.
self.saved_model_dir = None self.saved_model_dir = None
self._saved_model_tags = None self._saved_model_tags = None
self._saved_model_version = 0 self._saved_model_version = 0
self._saved_model_exported_names = [] self._saved_model_exported_names = []
self._experimental_sparsify_model = False
def _grappler_config(self, optimizers=None): def _grappler_config(self, optimizers=None):
"""Creates a tf.compat.v1.ConfigProto for configuring Grappler. """Creates a tf.compat.v1.ConfigProto for configuring Grappler.
@ -684,9 +667,9 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
saved_model_dir: Directory of the SavedModel. saved_model_dir: Directory of the SavedModel.
saved_model_tags: Set of tags identifying the MetaGraphDef within the saved_model_tags: Set of tags identifying the MetaGraphDef within the
SavedModel to analyze. All tags in the tag set must be present. (default SavedModel to analyze. All tags in the tag set must be present. (default
set(SERVING)). {tf.saved_model.SERVING}).
saved_model_exported_names: Names to be exported (default: export all) saved_model_exported_names: Names to be exported when the saved model
when the saved model import path is on. import path is on.
trackable_obj: tf.AutoTrackable object associated with `funcs`. A trackable_obj: tf.AutoTrackable object associated with `funcs`. A
reference to this object needs to be maintained so that Variables do not reference to this object needs to be maintained so that Variables do not
get garbage collected since functions have a weak reference to get garbage collected since functions have a weak reference to
@ -972,20 +955,21 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
"""Converts a TensorFlow model into TensorFlow Lite model. """Converts a TensorFlow model into TensorFlow Lite model.
Attributes: Attributes:
allow_custom_ops: Boolean indicating whether to allow custom operations. optimizations: Experimental flag, subject to change. Set of optimizations
When False, any unknown operation is an error. When True, custom ops are to apply, e.g. {tf.lite.Optimize.DEFAULT}. (default None, must be None or a
created for any op that is unknown. The developer needs to provide these set of values of type `tf.lite.Optimize`)
to the TensorFlow Lite runtime with a custom resolver. (default False) representative_dataset: A generator function used for integer quantization
optimizations: Experimental flag, subject to change. A list of optimizations where each generated sample has the same order, type and shape as the
to apply when converting the model. E.g. `[Optimize.DEFAULT]` inputs to the model. Usually, this is a small subset of a few hundred
representative_dataset: A representative dataset that can be used to samples randomly chosen, in no particular order, from the training or
generate input and output samples for the model. The converter can use the evaluation dataset. This is an optional attribute, but required for full
dataset to evaluate different optimizations. Note that this is an optional integer quantization, i.e., if `tf.int8` is the only supported type in
attribute but it is necessary if INT8 is the only support builtin ops in `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`.
target ops. (default None)
target_spec: Experimental flag, subject to change. Specifications of target target_spec: Experimental flag, subject to change. Specifications of target
device, including supported ops set, supported types and a set of user's device, including supported ops set, supported types and a set of user's
defined TensorFlow operators required in the TensorFlow Lite runtime. defined TensorFlow operators required in the TensorFlow Lite runtime.
Refer to `tf.lite.TargetSpec`.
inference_input_type: Data type of the input layer. Note that integer types inference_input_type: Data type of the input layer. Note that integer types
(tf.int8 and tf.uint8) are currently only supported for post training (tf.int8 and tf.uint8) are currently only supported for post training
integer quantization and quantization aware training. (default tf.float32, integer quantization and quantization aware training. (default tf.float32,
@ -994,8 +978,13 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
types (tf.int8 and tf.uint8) are currently only supported for post types (tf.int8 and tf.uint8) are currently only supported for post
training integer quantization and quantization aware training. (default training integer quantization and quantization aware training. (default
tf.float32, must be in {tf.float32, tf.int8, tf.uint8}) tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
allow_custom_ops: Boolean indicating whether to allow custom operations.
When False, any unknown operation is an error. When True, custom ops are
created for any op that is unknown. The developer needs to provide these
to the TensorFlow Lite runtime with a custom resolver. (default False)
experimental_new_converter: Experimental flag, subject to change. Enables experimental_new_converter: Experimental flag, subject to change. Enables
MLIR-based conversion instead of TOCO conversion. (default True) MLIR-based conversion instead of TOCO conversion. (default True)
Example usage: Example usage:
```python ```python
@ -1063,7 +1052,8 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
`signatures` attribute of the MetaGraphdef is used. (default `signatures` attribute of the MetaGraphdef is used. (default
saved_model.signatures) saved_model.signatures)
tags: Set of tags identifying the MetaGraphDef within the SavedModel to tags: Set of tags identifying the MetaGraphDef within the SavedModel to
analyze. All tags in the tag set must be present. (default set(SERVING)) analyze. All tags in the tag set must be present. (default
{tf.saved_model.SERVING} or {'serve'})
Returns: Returns:
TFLiteConverter object. TFLiteConverter object.
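A hedged usage sketch of the classmethod documented above; the directory and signature key are placeholders, and both keyword arguments can be omitted to take the defaults described in the docstring.

```python
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(
    "/tmp/my_saved_model",
    signature_keys=["serving_default"],  # default: all signatures
    tags={tf.saved_model.SERVING})       # i.e. {'serve'}
tflite_model = converter.convert()
```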
@ -1209,9 +1199,13 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
if (requires_quantized_input_stats and if (requires_quantized_input_stats and
not converter_kwargs["quantized_input_stats"]): not converter_kwargs["quantized_input_stats"]):
raise ValueError("The `quantized_input_stats` flag must be defined when " raise ValueError(
"either `inference_type` flag or `inference_input_type` " "The `quantized_input_stats` flag must be defined when either "
"flag is set to tf.uint8 or tf.int8.") "`inference_type` flag or `inference_input_type` flag is set to "
"tf.int8 or tf.uint8. Currently, `inference_type={}` and "
"`inference_input_type={}`.".format(
_get_tf_type_name(converter_kwargs["inference_type"]),
_get_tf_type_name(converter_kwargs["inference_input_type"])))
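The reworded error now names the offending flags and their current values. A minimal TF1-style sketch that satisfies the check might look like this; the frozen-graph path, tensor names, and stats are assumptions.

```python
import tensorflow as tf

converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file="/tmp/frozen_graph.pb",  # hypothetical frozen graph
    input_arrays=["input"],
    output_arrays=["output"])
converter.inference_type = tf.uint8
# Required whenever `inference_type` or `inference_input_type` is tf.int8 or
# tf.uint8: maps each input tensor name to (mean, std_dev) of training data.
converter.quantized_input_stats = {"input": (0., 1.)}
tflite_model = converter.convert()
```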
def convert(self): def convert(self):
"""Converts a TensorFlow GraphDef based on instance variables. """Converts a TensorFlow GraphDef based on instance variables.
@ -1424,9 +1418,9 @@ class TFLiteSavedModelConverter(TFLiteConverterBaseV1):
saved_model_dir: Directory of the SavedModel. saved_model_dir: Directory of the SavedModel.
saved_model_tags: Set of tags identifying the MetaGraphDef within the saved_model_tags: Set of tags identifying the MetaGraphDef within the
SavedModel to analyze. All tags in the tag set must be present. (default SavedModel to analyze. All tags in the tag set must be present. (default
set(SERVING)). {tf.saved_model.SERVING}).
saved_model_exported_names: Names to be exported (default: export all) saved_model_exported_names: Names to be exported when the saved model
when the saved model import path is on. import path is on.
experimental_debug_info_func: An experimental function to retrieve the experimental_debug_info_func: An experimental function to retrieve the
graph debug info for a set of nodes from the `graph_def`. graph debug info for a set of nodes from the `graph_def`.
@ -1645,33 +1639,42 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
model into either a TFLite FlatBuffer or graph visualization. model into either a TFLite FlatBuffer or graph visualization.
Attributes: Attributes:
inference_type: Target data type of real-number arrays in the output file. optimizations: Experimental flag, subject to change. Set of optimizations to
Must be `{tf.float32, tf.uint8}`. If `optimzations` are provided, this apply, e.g. {tf.lite.Optimize.DEFAULT}. (default None, must be None or a
parameter is ignored. (default tf.float32) set of values of type `tf.lite.Optimize`)
inference_input_type: Target data type of real-number input arrays. Allows representative_dataset: A generator function used for integer quantization
for a different type for input arrays. If an integer type is provided and where each generated sample has the same order, type and shape as the
`optimizations` are not used, `quantized_input_stats` must be provided. If inputs to the model. Usually, this is a small subset of a few hundred
`inference_type` is tf.uint8, signaling conversion to a fully quantized samples randomly chosen, in no particular order, from the training or
model from a quantization-aware trained input model, then evaluation dataset. This is an optional attribute, but required for full
`inference_input_type` defaults to tf.uint8. In all other cases, integer quantization, i.e., if `tf.int8` is the only supported type in
`inference_input_type` defaults to tf.float32. Must be `{tf.float32, `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`.
tf.uint8, tf.int8}` (default None)
inference_output_type: Target data type of real-number output arrays. Allows target_spec: Experimental flag, subject to change. Specifications of target
for a different type for output arrays. If `inference_type` is tf.uint8, device, including supported ops set, supported types and a set of user's
signaling conversion to a fully quantized model from a quantization-aware defined TensorFlow operators required in the TensorFlow Lite runtime.
trained output model, then `inference_output_type` defaults to tf.uint8. Refer to `tf.lite.TargetSpec`.
In all other cases, `inference_output_type` must be tf.float32, an error inference_type: Data type of numeric arrays, excluding the input layer.
will be thrown otherwise. Must be `{tf.float32, tf.uint8, tf.int8}` (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
output_format: Output file format. Currently must be `{TFLITE, inference_input_type: Data type of the numeric arrays in the input layer. If
GRAPHVIZ_DOT}`. (default TFLITE) `inference_input_type` is in {tf.int8, tf.uint8}, then
quantized_input_stats: Dict of strings representing input tensor names `quantized_input_stats` must be provided. (default is the value assigned
mapped to tuple of floats representing the mean and standard deviation to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8})
of the training data (e.g., {"foo" : (0., 1.)}). Only need if inference_output_type: Data type of the numeric arrays in the output layer.
`inference_input_type` is `QUANTIZED_UINT8`. real_input_value = (default is the value assigned to `inference_type`, must be in
(quantized_input_value - mean_value) / std_dev_value. (default {}) {tf.float32, tf.int8, tf.uint8})
default_ranges_stats: Tuple of integers representing (min, max) range values quantized_input_stats: Map of input tensor names to a tuple of floats
for all arrays without a specified range. Intended for experimenting with representing the mean and standard deviation of the training data.
quantization via "dummy quantization". (default None) (e.g., {"foo" : (0., 1.)}). Required if `inference_input_type` is tf.int8
or tf.uint8. (default None)
default_ranges_stats: Tuple of integers (min, max) representing range values
for all numeric arrays without a specified range. Intended for
experimenting with quantization via "dummy quantization". (default None)
allow_custom_ops: Boolean indicating whether to allow custom operations.
When False any unknown operation is an error. When True, custom ops are
created for any op that is unknown. The developer will need to provide
these to the TensorFlow Lite runtime with a custom resolver. (default
False)
drop_control_dependency: Boolean indicating whether to drop control drop_control_dependency: Boolean indicating whether to drop control
dependencies silently. This is due to TFLite not supporting control dependencies silently. This is due to TFLite not supporting control
dependencies. (default True) dependencies. (default True)
@ -1683,37 +1686,25 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
change_concat_input_ranges: Boolean to change behavior of min/max ranges for change_concat_input_ranges: Boolean to change behavior of min/max ranges for
inputs and outputs of the concat operator for quantized models. Changes inputs and outputs of the concat operator for quantized models. Changes
the ranges of concat operator overlap when true. (default False) the ranges of concat operator overlap when true. (default False)
allow_custom_ops: Boolean indicating whether to allow custom operations. output_format: Output file format. (default
When false any unknown operation is an error. When true, custom ops are tf.compat.v1.lite.constants.TFLITE, must be in
created for any op that is unknown. The developer will need to provide {tf.compat.v1.lite.constants.TFLITE,
these to the TensorFlow Lite runtime with a custom resolver. (default tf.compat.v1.lite.constants.GRAPHVIZ_DOT})
False)
post_training_quantize: Deprecated. Please specify `[Optimize.DEFAULT]` for
`optimizations` instead. Boolean indicating whether to quantize the
weights of the converted float model. Model size will be reduced and
there will be latency improvements (at the cost of accuracy). (default
False)
dump_graphviz_dir: Full filepath of folder to dump the graphs at various dump_graphviz_dir: Full filepath of folder to dump the graphs at various
stages of processing GraphViz .dot files. Preferred over stages of processing GraphViz .dot files. Preferred over
--output_format=GRAPHVIZ_DOT in order to keep the requirements of the `output_format=tf.compat.v1.lite.constants.GRAPHVIZ_DOT` in order to keep
output file. (default None) the requirements of the output file. (default None)
dump_graphviz_video: Boolean indicating whether to dump the graph after dump_graphviz_video: Boolean indicating whether to dump the GraphViz .dot
every graph transformation. (default False) files after every graph transformation. Requires the `dump_graphviz_dir`
conversion_summary_dir: A string indicating the path to the generated flag to be specified. (default False)
conversion logs. conversion_summary_dir: Full path of the directory to store conversion logs.
target_ops: Deprecated. Please specify `target_spec.supported_ops` instead. (default None)
Set of OpsSet options indicating which converter to use. (default target_ops: Deprecated. Please use `target_spec.supported_ops` instead.
set([OpsSet.TFLITE_BUILTINS])) post_training_quantize: Deprecated. Please use `optimizations` instead and
target_spec: Experimental flag, subject to change. Specifications of target set it to `{tf.lite.Optimize.DEFAULT}`. (default False)
device, including supported ops set, supported types and a set of user's
defined TensorFlow operators required in the TensorFlow Lite runtime.
optimizations: Experimental flag, subject to change. A list of optimizations
to apply when converting the model. E.g. `[Optimize.DEFAULT]`
representative_dataset: A representative dataset that can be used to
generate input and output samples for the model. The converter can use the
dataset to evaluate different optimizations.
experimental_new_converter: Experimental flag, subject to change. Enables experimental_new_converter: Experimental flag, subject to change. Enables
MLIR-based conversion instead of TOCO conversion. (default True) MLIR-based conversion instead of TOCO conversion. (default True)
Example usage: Example usage:
```python ```python
@ -1911,9 +1902,10 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
output_arrays: List of output tensors to freeze graph with. Uses output output_arrays: List of output tensors to freeze graph with. Uses output
arrays from SignatureDef when none are provided. (default None) arrays from SignatureDef when none are provided. (default None)
tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
analyze. All tags in the tag set must be present. (default set("serve")) analyze. All tags in the tag set must be present. (default
{tf.saved_model.SERVING})
signature_key: Key identifying SignatureDef containing inputs and outputs. signature_key: Key identifying SignatureDef containing inputs and outputs.
(default DEFAULT_SERVING_SIGNATURE_DEF_KEY) (default tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
Returns: Returns:
TFLiteConverter class. TFLiteConverter class.

View File

@ -1239,18 +1239,22 @@ class FromSessionTest(TestModels, parameterized.TestCase):
quantized_converter.inference_type = quantized_type quantized_converter.inference_type = quantized_type
quantized_converter.convert() quantized_converter.convert()
self.assertEqual( self.assertEqual(
'The `quantized_input_stats` flag must be defined when ' 'The `quantized_input_stats` flag must be defined when either '
'either `inference_type` flag or `inference_input_type` ' '`inference_type` flag or `inference_input_type` flag is set to '
'flag is set to tf.uint8 or tf.int8.', str(error.exception)) 'tf.int8 or tf.uint8. Currently, `inference_type=tf.{}` and '
'`inference_input_type=None`.'.format(quantized_type.name),
str(error.exception))
with self.assertRaises(ValueError) as error: with self.assertRaises(ValueError) as error:
quantized_converter.inference_type = dtypes.float32 quantized_converter.inference_type = dtypes.float32
quantized_converter.inference_input_type = quantized_type quantized_converter.inference_input_type = quantized_type
quantized_converter.convert() quantized_converter.convert()
self.assertEqual( self.assertEqual(
'The `quantized_input_stats` flag must be defined when ' 'The `quantized_input_stats` flag must be defined when either '
'either `inference_type` flag or `inference_input_type` ' '`inference_type` flag or `inference_input_type` flag is set to '
'flag is set to tf.uint8 or tf.int8.', str(error.exception)) 'tf.int8 or tf.uint8. Currently, `inference_type=tf.float32` and '
'`inference_input_type=tf.{}`.'.format(quantized_type.name),
str(error.exception))
quantized_converter.inference_type = quantized_type quantized_converter.inference_type = quantized_type
quantized_converter.inference_input_type = quantized_type quantized_converter.inference_input_type = quantized_type

View File

@ -127,9 +127,9 @@ def _convert_tflite_enum_type_to_tf_type(tflite_enum_type):
return tf_type return tf_type
def _get_tf_type_name(tf_type): def get_tf_type_name(tf_type):
"""Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32").""" """Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32")."""
return "tf." + tf_type.name return "tf." + tf_type.name if tf_type else None
def get_tensor_name(tensor): def get_tensor_name(tensor):
@ -674,7 +674,7 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
raise ValueError( raise ValueError(
"Initial model input type must be tf.float32. Expected type for " "Initial model input type must be tf.float32. Expected type for "
"tensor with name '{}' is tf.float32, instead type is {}".format( "tensor with name '{}' is tf.float32, instead type is {}".format(
float_tensor.name, _get_tf_type_name(float_type))) float_tensor.name, get_tf_type_name(float_type)))
# If found, validate that the operator output is quantized and compatible # If found, validate that the operator output is quantized and compatible
# with the final model input type # with the final model input type
quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type) quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
@ -683,17 +683,17 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
"Initial model input is not quantized. Expected type for " "Initial model input is not quantized. Expected type for "
"tensor with name '{}' should be in {}, instead type is {}".format( "tensor with name '{}' should be in {}, instead type is {}".format(
quant_tensor.name, quant_tensor.name,
tuple(_get_tf_type_name(t) for t in tuple(get_tf_type_name(t) for t in
_MAP_QUANT_TO_IO_TYPES.keys()), _MAP_QUANT_TO_IO_TYPES.keys()),
_get_tf_type_name(quant_type))) get_tf_type_name(quant_type)))
else: else:
inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type] inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
if inference_input_type not in inference_io_types: if inference_input_type not in inference_io_types:
raise ValueError( raise ValueError(
"Unsupported `inference_input_type` value. Expected to be in " "Unsupported `inference_input_type` value. Expected to be in "
"{}, instead got {}.".format( "{}, instead got {}.".format(
tuple(_get_tf_type_name(t) for t in inference_io_types), tuple(get_tf_type_name(t) for t in inference_io_types),
_get_tf_type_name(inference_input_type))) get_tf_type_name(inference_input_type)))
input_quant_ops.append(op) input_quant_ops.append(op)
if len(subgraph.inputs) != len(input_quant_ops): if len(subgraph.inputs) != len(input_quant_ops):
@ -725,7 +725,7 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
else: else:
raise ValueError( raise ValueError(
"Unsupported `inference_input_type` value {}.".format( "Unsupported `inference_input_type` value {}.".format(
_get_tf_type_name(inference_input_type))) get_tf_type_name(inference_input_type)))
def _modify_model_output_type(model, inference_output_type=dtypes.float32): def _modify_model_output_type(model, inference_output_type=dtypes.float32):
@ -768,7 +768,7 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
raise ValueError( raise ValueError(
"Initial model output type must be tf.float32. Expected type for " "Initial model output type must be tf.float32. Expected type for "
"tensor with name '{}' is tf.float32, instead type is {}".format( "tensor with name '{}' is tf.float32, instead type is {}".format(
float_tensor.name, _get_tf_type_name(float_type))) float_tensor.name, get_tf_type_name(float_type)))
# If found, validate that the operator input is quantized and compatible # If found, validate that the operator input is quantized and compatible
# with the final model output type # with the final model output type
quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type) quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
@ -777,17 +777,17 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
"Initial model output is not dequantized. Expected type for " "Initial model output is not dequantized. Expected type for "
"tensor with name '{}' should be in {}, instead type is {}".format( "tensor with name '{}' should be in {}, instead type is {}".format(
quant_tensor.name, quant_tensor.name,
tuple(_get_tf_type_name(t) for t in tuple(get_tf_type_name(t) for t in
_MAP_QUANT_TO_IO_TYPES.keys()), _MAP_QUANT_TO_IO_TYPES.keys()),
_get_tf_type_name(quant_type))) get_tf_type_name(quant_type)))
else: else:
inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type] inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
if inference_output_type not in inference_io_types: if inference_output_type not in inference_io_types:
raise ValueError( raise ValueError(
"Unsupported `inference_output_type` value. Expected to be in " "Unsupported `inference_output_type` value. Expected to be in "
"{}, instead got {}.".format( "{}, instead got {}.".format(
tuple(_get_tf_type_name(t) for t in inference_io_types), tuple(get_tf_type_name(t) for t in inference_io_types),
_get_tf_type_name(inference_output_type))) get_tf_type_name(inference_output_type)))
output_dequant_ops.append(op) output_dequant_ops.append(op)
if len(subgraph.outputs) != len(output_dequant_ops): if len(subgraph.outputs) != len(output_dequant_ops):
@ -834,7 +834,7 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
else: else:
raise ValueError( raise ValueError(
"Unsupported `inference_output_type` value {}.".format( "Unsupported `inference_output_type` value {}.".format(
_get_tf_type_name(inference_output_type))) get_tf_type_name(inference_output_type)))
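These validations back the converter's `inference_input_type`/`inference_output_type` flags (lite.py imports `modify_model_io_type` for exactly this, per the imports at the top of this diff). A hedged end-to-end sketch, reusing the earlier placeholder names:

```python
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/my_saved_model")
converter.optimizations = {tf.lite.Optimize.DEFAULT}
converter.representative_dataset = representative_dataset  # as sketched earlier
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Integer I/O types are what trigger the input/output rewriting above.
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()
```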
def modify_model_io_type( def modify_model_io_type(

View File

@ -26,11 +26,30 @@ from tensorflow.lite.testing.zip_test_utils import register_make_test_function
@register_make_test_function() @register_make_test_function()
def make_cast_tests(options): def make_cast_tests(options):
"""Generate examples for cast.""" """Generate examples for cast."""
test_parameters = [{ if options.use_experimental_converter:
"input_dtype": [tf.int32], test_parameters = [
"output_dtype": [tf.float32], {
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], "input_dtype": [tf.float32],
}] "output_dtype": [tf.int16],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.int16],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.int32],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}]
else:
test_parameters = [
{
"input_dtype": [tf.int32],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}]
def build_graph(parameters): def build_graph(parameters):
"""Build the cast testing graph.""" """Build the cast testing graph."""

View File

@ -763,6 +763,7 @@ py_library(
"//tensorflow/python/util:_pywrap_tfprof", "//tensorflow/python/util:_pywrap_tfprof",
"//tensorflow/python/util:_pywrap_transform_graph", "//tensorflow/python/util:_pywrap_transform_graph",
"//tensorflow/python/util:_pywrap_util_port", "//tensorflow/python/util:_pywrap_util_port",
"//tensorflow/python/platform:_pywrap_tf2",
":_pywrap_utils", ":_pywrap_utils",
":composite_tensor", ":composite_tensor",
":config", ":config",
@ -5266,7 +5267,10 @@ pywrap_tensorflow_macro(
tf_additional_plugin_deps() + tf_additional_plugin_deps() +
tf_additional_profiler_deps()) + if_xla_available([ tf_additional_profiler_deps()) + if_xla_available([
"//tensorflow/compiler/aot:tfcompile_lib", "//tensorflow/compiler/aot:tfcompile_lib",
]) + if_static(extra_deps = ["//tensorflow/core/platform:tensor_float_32_utils"]), ]) + if_static(extra_deps = [
"//tensorflow/core/platform:tensor_float_32_utils",
"//tensorflow/core/platform:enable_tf2_utils",
]),
) )
# ** Targets for Windows build (start) ** # ** Targets for Windows build (start) **
@ -6779,6 +6783,8 @@ py_test(
":client_testlib", ":client_testlib",
":framework_combinations", ":framework_combinations",
":tf2", ":tf2",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/data/kernel_tests:test_base",
], ],
) )

View File

@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export
# This value changes every day with an automatic CL. It can be modified in code # This value changes every day with an automatic CL. It can be modified in code
# via `forward_compatibility_horizon()` or with the environment variable # via `forward_compatibility_horizon()` or with the environment variable
# TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2021, 1, 2) _FORWARD_COMPATIBILITY_HORIZON = datetime.date(2021, 1, 5)
_FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
_FORWARD_COMPATIBILITY_DATE_NUMBER = None _FORWARD_COMPATIBILITY_DATE_NUMBER = None

View File

@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python.compat import v2_compat from tensorflow.python.compat import v2_compat
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.platform import _pywrap_tf2
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -29,9 +30,13 @@ class DisableV2BehaviorTest(test.TestCase):
def test_basic(self): def test_basic(self):
t = constant_op.constant([1, 2, 3]) # creates a hidden context t = constant_op.constant([1, 2, 3]) # creates a hidden context
self.assertTrue(isinstance(t, ops.EagerTensor)) self.assertTrue(isinstance(t, ops.EagerTensor))
t = _pywrap_tf2.is_enabled()
self.assertTrue(t)
v2_compat.disable_v2_behavior() v2_compat.disable_v2_behavior()
t = constant_op.constant([1, 2, 3]) t = constant_op.constant([1, 2, 3])
self.assertFalse(isinstance(t, ops.EagerTensor)) self.assertFalse(isinstance(t, ops.EagerTensor))
t = _pywrap_tf2.is_enabled()
self.assertFalse(t)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -36,7 +36,10 @@ py_library(
py_library( py_library(
name = "trt_convert_py", name = "trt_convert_py",
srcs = ["trt_convert.py"], srcs = [
"trt_convert.py",
"utils.py",
],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
"//tensorflow/compiler/tf2tensorrt:_pywrap_py_utils", "//tensorflow/compiler/tf2tensorrt:_pywrap_py_utils",

View File

@ -35,7 +35,9 @@ from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import get_linked_tensorrt
from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.compiler.tensorrt import utils as trt_utils
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.framework import graph_io from tensorflow.python.framework import graph_io
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
@ -331,17 +333,21 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
"""Get config proto based on specific settings.""" """Get config proto based on specific settings."""
conversion_params = self.GetConversionParams(run_params) conversion_params = self.GetConversionParams(run_params)
max_batch_size = self.GetMaxBatchSize(run_params) max_batch_size = self.GetMaxBatchSize(run_params)
if graph_state == GraphState.INFERENCE and run_params.convert_online: if graph_state == GraphState.INFERENCE and run_params.convert_online:
rewriter_cfg = trt_convert.get_tensorrt_rewriter_config( rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
conversion_params, conversion_params,
is_dynamic_op=run_params.dynamic_engine, is_dynamic_op=run_params.dynamic_engine,
max_batch_size=max_batch_size) max_batch_size=max_batch_size,
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg) disable_non_trt_optimizers=self._disable_non_trt_optimizers)
else: else:
graph_options = config_pb2.GraphOptions() rewriter_cfg = rewriter_config_pb2.RewriterConfig()
if self._disable_non_trt_optimizers:
trt_utils.disable_non_trt_optimizers_in_rewriter_config(rewriter_cfg)
config = config_pb2.ConfigProto( config = config_pb2.ConfigProto(
gpu_options=self._GetGPUOptions(), graph_options=graph_options) gpu_options=self._GetGPUOptions(),
graph_options=config_pb2.GraphOptions(rewrite_options=rewriter_cfg))
return config return config
def _GetFeedNames(self): def _GetFeedNames(self):

View File

@ -30,6 +30,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.compiler.tensorrt import utils as trt_utils
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import wrap_function from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import convert_to_constants from tensorflow.python.framework import convert_to_constants
@ -271,11 +272,14 @@ def _get_tensorrt_rewriter_config(conversion_params,
# need to run constant folding again. # need to run constant folding again.
rewriter_config_with_trt.optimizers.extend( rewriter_config_with_trt.optimizers.extend(
["constfold", "layout", "constfold"]) ["constfold", "layout", "constfold"])
rewriter_config_with_trt.meta_optimizer_iterations = ( rewriter_config_with_trt.meta_optimizer_iterations = (
rewriter_config_pb2.RewriterConfig.ONE) rewriter_config_pb2.RewriterConfig.ONE)
optimizer = rewriter_config_with_trt.custom_optimizers.add() optimizer = rewriter_config_with_trt.custom_optimizers.add()
# Add a constfold optimizer to cleanup the unused Const nodes.
rewriter_config_with_trt.custom_optimizers.add().name = "constfold" if not disable_non_trt_optimizers:
# Add a constfold optimizer to cleanup the unused Const nodes.
rewriter_config_with_trt.custom_optimizers.add().name = "constfold"
optimizer.name = "TensorRTOptimizer" optimizer.name = "TensorRTOptimizer"
optimizer.parameter_map[ optimizer.parameter_map[
@ -295,25 +299,11 @@ def _get_tensorrt_rewriter_config(conversion_params,
optimizer.parameter_map["max_batch_size"].i = max_batch_size optimizer.parameter_map["max_batch_size"].i = max_batch_size
optimizer.parameter_map["use_implicit_batch"].b = use_implicit_batch optimizer.parameter_map["use_implicit_batch"].b = use_implicit_batch
# Disabling optimizers should happen after CopyFrom the template # Disabling optimizers should happen after defining the TF-TRT grappler pass
# otherwise the template can overwrite the disablement. # otherwise the template can overwrite the disablement.
if disable_non_trt_optimizers: if disable_non_trt_optimizers:
off = rewriter_config_pb2.RewriterConfig.OFF trt_utils.disable_non_trt_optimizers_in_rewriter_config(
rewriter_config_with_trt.layout_optimizer = off rewriter_config_with_trt)
rewriter_config_with_trt.constant_folding = off
rewriter_config_with_trt.shape_optimization = off
rewriter_config_with_trt.remapping = off
rewriter_config_with_trt.arithmetic_optimization = off
rewriter_config_with_trt.dependency_optimization = off
rewriter_config_with_trt.loop_optimization = off
rewriter_config_with_trt.function_optimization = off
rewriter_config_with_trt.debug_stripper = off
rewriter_config_with_trt.disable_model_pruning = True
rewriter_config_with_trt.scoped_allocator_optimization = off
rewriter_config_with_trt.memory_optimization = (
rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
rewriter_config_with_trt.pin_to_host_optimization = off
rewriter_config_with_trt.auto_parallel.enable = False
return rewriter_config_with_trt return rewriter_config_with_trt
@ -652,10 +642,19 @@ class TrtGraphConverter(object):
return_elements=fetch_names, return_elements=fetch_names,
name="") name="")
calibrate_rewriter_cfg = rewriter_config_pb2.RewriterConfig()
if self._test_only_disable_non_trt_optimizers:
trt_utils.disable_non_trt_optimizers_in_rewriter_config(
calibrate_rewriter_cfg)
# Set allow_soft_placement=True to run the graph for calibration so that # Set allow_soft_placement=True to run the graph for calibration so that
# OPs supported by TensorRT but don't have a GPU implementation are allowed # OPs supported by TensorRT but don't have a GPU implementation are allowed
# to execute on CPU. # to execute on CPU.
calibrate_config = config_pb2.ConfigProto(allow_soft_placement=True) calibrate_config = config_pb2.ConfigProto(
allow_soft_placement=True,
graph_options=config_pb2.GraphOptions(
rewrite_options=calibrate_rewriter_cfg))
with session.Session( with session.Session(
graph=self._calibration_graph, graph=self._calibration_graph,
config=calibrate_config) as calibration_sess: config=calibrate_config) as calibration_sess:

View File

@ -0,0 +1,47 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Exposes the Python wrapper conversion to trt_graph."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import rewriter_config_pb2
def disable_non_trt_optimizers_in_rewriter_config(rewriter_config):
"""Modifies rewriter_config to disable all non-TRT optimizations."""
off = rewriter_config_pb2.RewriterConfig.OFF
rewriter_config.arithmetic_optimization = off
rewriter_config.auto_mixed_precision = off
rewriter_config.auto_parallel.enable = False
rewriter_config.constant_folding = off
rewriter_config.debug_stripper = off
rewriter_config.dependency_optimization = off
# This one needs to be ON to allow TF-TRT
rewriter_config.disable_meta_optimizer = False
rewriter_config.disable_model_pruning = True
rewriter_config.function_optimization = off
rewriter_config.implementation_selector = off
rewriter_config.layout_optimizer = off
rewriter_config.loop_optimization = off
rewriter_config.memory_optimization = (
rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
rewriter_config.min_graph_nodes = -1
rewriter_config.pin_to_host_optimization = off
rewriter_config.remapping = off
rewriter_config.scoped_allocator_optimization = off
rewriter_config.shape_optimization = off
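A usage sketch of the new helper, mirroring how trt_convert.py wires it into a session config elsewhere in this change:

```python
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.compiler.tensorrt import utils as trt_utils

rewriter_cfg = rewriter_config_pb2.RewriterConfig()
# Turn off every non-TRT Grappler pass while keeping the meta optimizer on.
trt_utils.disable_non_trt_optimizers_in_rewriter_config(rewriter_cfg)

config = config_pb2.ConfigProto(
    graph_options=config_pb2.GraphOptions(rewrite_options=rewriter_cfg))
```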

View File

@ -442,7 +442,8 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
results = {} results = {}
for _ in range(elements_to_read): for _ in range(elements_to_read):
val = next(it).numpy() val = next(it).numpy()
if val not in results: results[val] = 0 if val not in results:
results[val] = 0
results[val] += 1 results[val] += 1
for i in range(num_elements): for i in range(num_elements):
self.assertGreater(results[i], elements_to_read / num_elements / 2) self.assertGreater(results[i], elements_to_read / num_elements / 2)
@ -527,6 +528,37 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
data_service_ops.distribute( data_service_ops.distribute(
processing_mode="invalid", service="grpc://localhost:5000")) processing_mode="invalid", service="grpc://localhost:5000"))
@combinations.generate(test_base.eager_only_combinations())
def testZipDifferentProcessingModesDatasets(self):
cluster = self.create_cluster(num_workers=1)
num_elements = 100
ds1 = dataset_ops.Dataset.range(num_elements)
ds1 = self.make_distributed_dataset(
ds1, cluster, processing_mode="distributed_epoch")
ds2 = dataset_ops.Dataset.range(num_elements)
ds2 = self.make_distributed_dataset(
ds2, cluster, processing_mode="parallel_epochs")
ds = dataset_ops.Dataset.zip((ds1, ds2))
self.assertDatasetProduces(
ds,
list(zip(range(num_elements), range(num_elements))),
assert_items_equal=True)
@combinations.generate(test_base.eager_only_combinations())
def testZipDifferentProcessingModesDatasetsSharedJobName(self):
cluster = self.create_cluster(num_workers=1)
num_elements = 100
ds1 = dataset_ops.Dataset.range(num_elements)
ds1 = self.make_distributed_dataset(
ds1, cluster, processing_mode="distributed_epoch", job_name="job_name")
ds2 = dataset_ops.Dataset.range(num_elements)
ds2 = self.make_distributed_dataset(
ds2, cluster, processing_mode="parallel_epochs", job_name="job_name")
ds = dataset_ops.Dataset.zip((ds1, ds2))
with self.assertRaisesRegex(errors.FailedPreconditionError,
"but there is already an existing job"):
self.getDatasetOutput(ds)
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testFromDatasetId(self): def testFromDatasetId(self):
cluster = self.create_cluster(num_workers=1) cluster = self.create_cluster(num_workers=1)

View File

@ -88,7 +88,8 @@ class DispatchServer(object):
>>> dispatcher = tf.data.experimental.service.DispatchServer() >>> dispatcher = tf.data.experimental.service.DispatchServer()
>>> dispatcher_address = dispatcher.target.split("://")[1] >>> dispatcher_address = dispatcher.target.split("://")[1]
>>> worker = tf.data.experimental.service.WorkerServer(WorkerConfig( >>> worker = tf.data.experimental.service.WorkerServer(
... tf.data.experimental.service.WorkerConfig(
... dispatcher_address=dispatcher_address)) ... dispatcher_address=dispatcher_address))
>>> dataset = tf.data.Dataset.range(10) >>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute( >>> dataset = dataset.apply(tf.data.experimental.service.distribute(

View File

@ -169,7 +169,7 @@ class _WorkerContext(object):
def _get_master_target(self): def _get_master_target(self):
"""Return the master target for a task.""" """Return the master target for a task."""
# If cluster_spec is None or empty, we use local master. # If cluster_spec is None or empty, we use local master.
if not self._cluster_spec: if not self._cluster_spec or self._task_type == _TaskType.EVALUATOR:
return "" return ""
# If task_type is None, then it is in-graph replicated training. In this # If task_type is None, then it is in-graph replicated training. In this
@ -842,7 +842,8 @@ def run_distribute_coordinator(worker_fn,
session_config, cluster_spec, session_config, cluster_spec,
task_type, task_id) task_type, task_id)
if not getattr(strategy.extended, "_std_server_started", False): if (task_type != _TaskType.EVALUATOR and
not getattr(strategy.extended, "_std_server_started", False)):
# Right now, with eager mode, context is configured with a std server at # Right now, with eager mode, context is configured with a std server at
# the very beginning while with graph mode the std server is started when # the very beginning while with graph mode the std server is started when
# distribute coordinator is called. We should consolidate these two paths. # distribute coordinator is called. We should consolidate these two paths.

View File

@ -589,8 +589,7 @@ class DistributeCoordinatorTestStandaloneMode(DistributeCoordinatorTestBase):
# and distributed_mode. # and distributed_mode.
self.assertEqual(self._worker_context["None"][0], (_strip_protocol( self.assertEqual(self._worker_context["None"][0], (_strip_protocol(
_bytes_to_str(self._workers[0].target)), 3, True, True)) _bytes_to_str(self._workers[0].target)), 3, True, True))
self.assertEqual(self._worker_context[EVALUATOR][0], self.assertEqual(self._worker_context[EVALUATOR][0], ("", 3, True, False))
("fake_evaluator", 3, True, False))
class DistributeCoordinatorTestIndependentWorkerMode( class DistributeCoordinatorTestIndependentWorkerMode(
@ -755,19 +754,15 @@ class DistributeCoordinatorTestIndependentWorkerMode(
# and distributed_mode. # and distributed_mode.
self.assertEqual(self._worker_context["None"][0], self.assertEqual(self._worker_context["None"][0],
(_bytes_to_str(cluster_spec[WORKER][0]), 3, True, True)) (_bytes_to_str(cluster_spec[WORKER][0]), 3, True, True))
self.assertEqual(self._worker_context[EVALUATOR][0], self.assertEqual(self._worker_context[EVALUATOR][0], ("", 3, True, False))
(cluster_spec[EVALUATOR][0], 3, True, False))
# Make sure each worker runs a std server. # Make sure each worker runs a std server.
self.assertEqual(len(self._std_servers), 2) self.assertEqual(len(self._std_servers), 1)
self.assertTrue(WORKER in self._std_servers) self.assertTrue(WORKER in self._std_servers)
self.assertTrue(EVALUATOR in self._std_servers)
self.assertEqual(len(self._std_servers[WORKER]), 3) self.assertEqual(len(self._std_servers[WORKER]), 3)
self.assertEqual(len(self._std_servers[EVALUATOR]), 1)
self.assertFalse(self._std_servers[WORKER][0].joined) self.assertFalse(self._std_servers[WORKER][0].joined)
self.assertTrue(self._std_servers[WORKER][1].joined) self.assertTrue(self._std_servers[WORKER][1].joined)
self.assertTrue(self._std_servers[WORKER][2].joined) self.assertTrue(self._std_servers[WORKER][2].joined)
self.assertFalse(self._std_servers[EVALUATOR][0].joined)
def testRunStdServerInGoogleEnvironment(self): def testRunStdServerInGoogleEnvironment(self):
cluster_spec = {"worker": ["fake_worker"], "ps": ["localhost:0"]} cluster_spec = {"worker": ["fake_worker"], "ps": ["localhost:0"]}

View File

@ -1007,6 +1007,7 @@ class GradientTape(object):
Raises: Raises:
RuntimeError: If called on a used, non-persistent tape. RuntimeError: If called on a used, non-persistent tape.
RuntimeError: If called inside the context of the tape. RuntimeError: If called inside the context of the tape.
TypeError: If the target is a None object.
ValueError: If the target is a variable or if unconnected gradients is ValueError: If the target is a variable or if unconnected gradients is
called with an unknown value. called with an unknown value.
""" """
@ -1028,6 +1029,11 @@ class GradientTape(object):
"gradient in order to compute higher order " "gradient in order to compute higher order "
"derivatives.", 1) "derivatives.", 1)
if target is None:
raise TypeError("Target should be a list or nested structure"
" of Tensors or Variables to be differentiated,"
" but recieved %r" % (target))
num_ndarrays = 0 num_ndarrays = 0
flat_targets = [] flat_targets = []
for t in nest.flatten(target): for t in nest.flatten(target):

View File

@ -1000,38 +1000,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3) self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3)
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2) self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
def test_experimental_get_tracing_count_function(self):
@def_function.function
def double(a):
return a + a
double(constant_op.constant(1))
double(constant_op.constant(2))
self.assertAllEqual(double.experimental_get_tracing_count(), 1)
double(constant_op.constant('a'))
self.assertAllEqual(double.experimental_get_tracing_count(), 2)
def test_experimental_get_tracing_count_method(self):
class TestClass():
@def_function.function
def testDouble(self, a):
return a + a
obj1 = TestClass()
obj1.testDouble(constant_op.constant(1))
obj1.testDouble(constant_op.constant(2))
obj1.testDouble(constant_op.constant(1.1))
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
obj2 = TestClass()
obj2.testDouble(constant_op.constant(1))
obj2.testDouble(constant_op.constant(1.1))
obj2.testDouble(constant_op.constant('a'))
self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3)
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
if __name__ == '__main__': if __name__ == '__main__':
ops.enable_eager_execution() ops.enable_eager_execution()

View File

@@ -2047,6 +2047,10 @@ class TensorFlowTestCase(googletest.TestCase):
self._tempdir = None
self._cached_session = None
self._test_start_time = None
# This flag provides the ability to control whether the graph mode gets
# initialized for TF1 or not. Initializing for TF1, which is what was
# happening earlier, was preventing enablement of 'eager mode' in the test.
self._set_default_seed = True
def setUp(self):
super(TensorFlowTestCase, self).setUp()
@@ -2061,7 +2065,8 @@ class TensorFlowTestCase(googletest.TestCase):
# cleared first.
ops._default_graph_stack.reset()  # pylint: disable=protected-access
ops.reset_default_graph()
random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
if self._set_default_seed:
random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
# Reset summary writer in case another test used set_as_default() with their
# summary writer.
summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
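The new _set_default_seed attribute lets a test case opt out of the TF1-style graph seeding done in setUp, which is what the rewritten tf2 tests below rely on. A hedged sketch of a subclass using it; the class name and trivial test are hypothetical:

from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class EagerOnlyTest(test_util.TensorFlowTestCase):  # hypothetical subclass
  """Opts out of the default graph seed so setUp leaves TF1 graph state alone."""

  def __init__(self, methodName='runTest'):
    super().__init__(methodName)
    # Skip random_seed.set_random_seed(DEFAULT_GRAPH_SEED) in setUp.
    self._set_default_seed = False

  def test_something(self):
    self.assertEqual(1 + 1, 2)


if __name__ == '__main__':
  test.main()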

View File

@@ -18,68 +18,73 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl.testing import parameterized
from tensorflow.python import tf2
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import combinations
from tensorflow.python.platform import _pywrap_tf2
from tensorflow.python.platform import test
def set_environ():
os.environ['TF2_BEHAVIOR'] = '1'
def unset_environ():
os.environ['TF2_BEHAVIOR'] = '0'
class EnablingTF2Behavior(test.TestCase, parameterized.TestCase):
def setUp(self):
super(EnablingTF2Behavior, self).setUp()
tf2._force_enable = None
if 'TF2_BEHAVIOR' in os.environ:
del os.environ['TF2_BEHAVIOR']
actions = [tf2.enable, tf2.disable, set_environ, unset_environ]
@combinations.generate(
combinations.combine(
action_0=actions, action_1=actions,
action_2=actions, action_3=actions))
def test_scenarios(self, action_0, action_1, action_2, action_3):
def state(action, enabled, disabled):
"""Returns bool tuple (tf2_enabled, force_enabled, force_disabled)."""
if action is tf2.enable:
return True, True, False
elif action is tf2.disable:
return False, False, True
elif action is set_environ:
return not disabled, enabled, disabled
elif action is unset_environ:
return enabled, enabled, disabled
else:
raise ValueError('Unexpected action {}. {} are supported'.format(
action, EnablingTF2Behavior.actions))
action_0()
expected, enabled, disabled = state(action_0, False, False)
self.assertEqual(tf2.enabled(), expected)
action_1()
expected, enabled, disabled = state(action_1, enabled, disabled)
self.assertEqual(tf2.enabled(), expected)
action_2()
expected, enabled, disabled = state(action_2, enabled, disabled)
self.assertEqual(tf2.enabled(), expected)
action_3()
expected, enabled, disabled = state(action_3, enabled, disabled)
self.assertEqual(tf2.enabled(), expected)
def __init__(self, methodName):
super().__init__(methodName)
self._set_default_seed = False
@combinations.generate(test_base.v1_only_combinations())
def test_tf1_enable_tf2_behaviour(self):
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())
v2_compat.enable_v2_behavior()
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())
v2_compat.disable_v2_behavior()
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())
@combinations.generate(test_base.v1_only_combinations())
def test_tf1_disable_tf2_behaviour(self):
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())
v2_compat.disable_v2_behavior()
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())
v2_compat.enable_v2_behavior()
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())
@combinations.generate(test_base.v2_only_combinations())
def test_tf2_enable_tf2_behaviour(self):
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())
v2_compat.enable_v2_behavior()
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())
v2_compat.disable_v2_behavior()
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())
@combinations.generate(test_base.v2_only_combinations())
def test_tf2_disable_tf2_behaviour(self):
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())
v2_compat.disable_v2_behavior()
self.assertFalse(tf2.enabled())
self.assertFalse(_pywrap_tf2.is_enabled())
v2_compat.enable_v2_behavior()
self.assertTrue(tf2.enabled())
self.assertTrue(_pywrap_tf2.is_enabled())
if __name__ == '__main__':
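The rewritten tests above toggle TF2 behavior through v2_compat rather than by simulating TF2_BEHAVIOR environment-variable combinations. A rough sketch of the switch being exercised (internal modules; the expected values mirror what the tests assert):

from tensorflow.python import tf2
from tensorflow.python.compat import v2_compat
from tensorflow.python.platform import _pywrap_tf2

v2_compat.enable_v2_behavior()
print(tf2.enabled(), _pywrap_tf2.is_enabled())   # expected: True True

v2_compat.disable_v2_behavior()
print(tf2.enabled(), _pywrap_tf2.is_enabled())   # expected: False False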

View File

@@ -41,6 +41,7 @@ def index_directory(directory,
directory: The target directory (string).
labels: Either "inferred"
(labels are generated from the directory structure),
None (no labels),
or a list/tuple of integer labels of the same size as the number of
valid files found in the directory. Labels should be sorted according
to the alphanumeric order of the image file paths
@@ -61,19 +62,24 @@ def index_directory(directory,
labels: list of matching integer labels (same length as file_paths)
class_names: names of the classes corresponding to these labels, in order.
"""
inferred_class_names = []
for subdir in sorted(os.listdir(directory)):
if os.path.isdir(os.path.join(directory, subdir)):
inferred_class_names.append(subdir)
if not class_names:
class_names = inferred_class_names
else:
if set(class_names) != set(inferred_class_names):
raise ValueError(
'The `class_names` passed did not match the '
'names of the subdirectories of the target directory. '
'Expected: %s, but received: %s' %
(inferred_class_names, class_names))
if labels is None:
# in the no-label case, index from the parent directory down.
subdirs = ['']
class_names = subdirs
else:
subdirs = []
for subdir in sorted(os.listdir(directory)):
if os.path.isdir(os.path.join(directory, subdir)):
subdirs.append(subdir)
if not class_names:
class_names = subdirs
else:
if set(class_names) != set(subdirs):
raise ValueError(
'The `class_names` passed did not match the '
'names of the subdirectories of the target directory. '
'Expected: %s, but received: %s' %
(subdirs, class_names))
class_indices = dict(zip(class_names, range(len(class_names))))
# Build an index of the files
@@ -81,7 +87,8 @@ def index_directory(directory,
pool = multiprocessing.pool.ThreadPool()
results = []
filenames = []
for dirpath in (os.path.join(directory, subdir) for subdir in class_names):
for dirpath in (os.path.join(directory, subdir) for subdir in subdirs):
results.append(
pool.apply_async(index_subdirectory,
(dirpath, class_indices, follow_links, formats)))
@@ -90,7 +97,7 @@ def index_directory(directory,
partial_filenames, partial_labels = res.get()
labels_list.append(partial_labels)
filenames += partial_filenames
if labels != 'inferred':
if labels not in ('inferred', None):
if len(labels) != len(filenames):
raise ValueError('Expected the lengths of `labels` to match the number '
'of files in the target directory. len(labels) is %s '
@@ -103,8 +110,11 @@ def index_directory(directory,
labels[i:i + len(partial_labels)] = partial_labels
i += len(partial_labels)
print('Found %d files belonging to %d classes.' %
(len(filenames), len(class_names)))
if labels is None:
print('Found %d files.' % (len(filenames),))
else:
print('Found %d files belonging to %d classes.' %
(len(filenames), len(class_names)))
pool.close()
pool.join()
file_paths = [os.path.join(directory, fname) for fname in filenames]
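With labels=None, index_directory now indexes from the target directory itself (subdirs = ['']) instead of per class subdirectory, and reports a plain file count. A standalone sketch of that no-label indexing in plain Python, not the Keras helper itself; the function name and formats default are illustrative:

import os


def index_files_no_labels(directory, formats=('.jpg', '.jpeg', '.png')):
  """Collects file paths under `directory` and its subdirs, with no labels."""
  file_paths = []
  for dirpath, _, filenames in os.walk(directory):
    for fname in sorted(filenames):
      if fname.lower().endswith(formats):
        file_paths.append(os.path.join(dirpath, fname))
  print('Found %d files.' % (len(file_paths),))
  return file_paths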

View File

@@ -74,6 +74,7 @@ def image_dataset_from_directory(directory,
Otherwise, the directory structure is ignored.
labels: Either "inferred"
(labels are generated from the directory structure),
None (no labels),
or a list/tuple of integer labels of the same size as the number of
image files found in the directory. Labels should be sorted according
to the alphanumeric order of the image file paths
@@ -139,7 +140,7 @@ def image_dataset_from_directory(directory,
- if `color_mode` is `rgba`,
there are 4 channels in the image tensors.
"""
if labels != 'inferred':
if labels not in ('inferred', None):
if not isinstance(labels, (list, tuple)):
raise ValueError(
'`labels` argument should be a list/tuple of integer labels, of '
@@ -156,6 +157,9 @@ def image_dataset_from_directory(directory,
raise ValueError(
'`label_mode` argument must be one of "int", "categorical", "binary", '
'or None. Received: %s' % (label_mode,))
if labels is None or label_mode is None:
labels = None
label_mode = None
if color_mode == 'rgb':
num_channels = 3
elif color_mode == 'rgba':
@@ -188,6 +192,8 @@ def image_dataset_from_directory(directory,
image_paths, labels = dataset_utils.get_training_or_validation_split(
image_paths, labels, validation_split, subset)
if not image_paths:
raise ValueError('No images found.')
dataset = paths_and_labels_to_dataset(
image_paths=image_paths,
View File

@@ -82,7 +82,7 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
# Save images to the paths
i = 0
for img in self._get_images(color_mode=color_mode, count=count):
path = paths[count % len(paths)]
path = paths[i % len(paths)]
if color_mode == 'rgb':
ext = 'jpg'
else:
@@ -92,6 +92,32 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
i += 1
return temp_dir
def test_image_dataset_from_directory_standalone(self):
# Test retrieving images without labels from a directory and its subdirs.
if PIL is None:
return # Skip test if PIL is not available.
# Save a few extra images in the parent directory.
directory = self._prepare_directory(count=7, num_classes=2)
for i, img in enumerate(self._get_images(3)):
filename = 'image_%s.jpg' % (i,)
img.save(os.path.join(directory, filename))
dataset = image_dataset.image_dataset_from_directory(
directory, batch_size=5, image_size=(18, 18), labels=None)
batch = next(iter(dataset))
# We return plain images
self.assertEqual(batch.shape, (5, 18, 18, 3))
self.assertEqual(batch.dtype.name, 'float32')
# Count samples
batch_count = 0
sample_count = 0
for batch in dataset:
batch_count += 1
sample_count += batch.shape[0]
self.assertEqual(batch_count, 2)
self.assertEqual(sample_count, 10)
def test_image_dataset_from_directory_binary(self):
if PIL is None:
return  # Skip test if PIL is not available.
@@ -253,6 +279,11 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
sample_count += batch.shape[0]
self.assertEqual(sample_count, 25)
def test_image_dataset_from_directory_no_images(self):
directory = self._prepare_directory(num_classes=2, count=0)
with self.assertRaisesRegex(ValueError, 'No images found.'):
_ = image_dataset.image_dataset_from_directory(directory)
def test_image_dataset_from_directory_errors(self):
if PIL is None:
return  # Skip test if PIL is not available.
@@ -261,7 +292,7 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
with self.assertRaisesRegex(ValueError, '`labels` argument should be'):
_ = image_dataset.image_dataset_from_directory(
directory, labels=None)
directory, labels='other')
with self.assertRaisesRegex(ValueError, '`label_mode` argument must be'):
_ = image_dataset.image_dataset_from_directory(

View File

@@ -66,6 +66,7 @@ def text_dataset_from_directory(directory,
Otherwise, the directory structure is ignored.
labels: Either "inferred"
(labels are generated from the directory structure),
None (no labels),
or a list/tuple of integer labels of the same size as the number of
text files found in the directory. Labels should be sorted according
to the alphanumeric order of the text file paths
@@ -114,7 +115,7 @@ def text_dataset_from_directory(directory,
of shape `(batch_size, num_classes)`, representing a one-hot
encoding of the class index.
"""
if labels != 'inferred':
if labels not in ('inferred', None):
if not isinstance(labels, (list, tuple)):
raise ValueError(
'`labels` argument should be a list/tuple of integer labels, of '
@@ -131,6 +132,9 @@ def text_dataset_from_directory(directory,
raise ValueError(
'`label_mode` argument must be one of "int", "categorical", "binary", '
'or None. Received: %s' % (label_mode,))
if labels is None or label_mode is None:
labels = None
label_mode = None
dataset_utils.check_validation_split_arg(
validation_split, subset, shuffle, seed)
@@ -152,6 +156,8 @@ def text_dataset_from_directory(directory,
file_paths, labels = dataset_utils.get_training_or_validation_split(
file_paths, labels, validation_split, subset)
if not file_paths:
raise ValueError('No text files found.')
dataset = paths_and_labels_to_dataset(
file_paths=file_paths,
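text_dataset_from_directory gets the same treatment: labels=None (or label_mode=None) yields plain string batches, and an empty match set raises 'No text files found.'. A hedged usage sketch with a hypothetical path, assuming a build with this change:

import tensorflow as tf

text_ds = tf.keras.preprocessing.text_dataset_from_directory(
    '/tmp/my_texts',       # hypothetical folder of .txt files
    labels=None,           # unlabeled: each batch is a tf.string tensor
    batch_size=16)

for texts in text_ds.take(1):
  print(texts.shape, texts.dtype)   # e.g. (16,) <dtype: 'string'>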

View File

@@ -58,7 +58,7 @@ class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):
paths += class_paths
for i in range(count):
path = paths[count % len(paths)]
path = paths[i % len(paths)]
filename = os.path.join(path, 'text_%s.txt' % (i,))
f = open(os.path.join(temp_dir, filename), 'w')
text = ''.join([random.choice(string.printable) for _ in range(length)])
@@ -66,6 +66,32 @@ class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):
f.close()
return temp_dir
def test_text_dataset_from_directory_standalone(self):
# Test retrieving txt files without labels from a directory and its subdirs.
# Save a few extra files in the parent directory.
directory = self._prepare_directory(count=7, num_classes=2)
for i in range(3):
filename = 'text_%s.txt' % (i,)
f = open(os.path.join(directory, filename), 'w')
text = ''.join([random.choice(string.printable) for _ in range(20)])
f.write(text)
f.close()
dataset = text_dataset.text_dataset_from_directory(
directory, batch_size=5, label_mode=None, max_length=10)
batch = next(iter(dataset))
# We just return the texts, no labels
self.assertEqual(batch.shape, (5,))
self.assertEqual(batch.dtype.name, 'string')
# Count samples
batch_count = 0
sample_count = 0
for batch in dataset:
batch_count += 1
sample_count += batch.shape[0]
self.assertEqual(batch_count, 2)
self.assertEqual(sample_count, 10)
def test_text_dataset_from_directory_binary(self):
directory = self._prepare_directory(num_classes=2)
dataset = text_dataset.text_dataset_from_directory(
@@ -172,12 +198,17 @@ class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):
sample_count += batch.shape[0]
self.assertEqual(sample_count, 25)
def test_text_dataset_from_directory_no_files(self):
directory = self._prepare_directory(num_classes=2, count=0)
with self.assertRaisesRegex(ValueError, 'No text files found.'):
_ = text_dataset.text_dataset_from_directory(directory)
def test_text_dataset_from_directory_errors(self):
directory = self._prepare_directory(num_classes=3, count=5)
with self.assertRaisesRegex(ValueError, '`labels` argument should be'):
_ = text_dataset.text_dataset_from_directory(
directory, labels=None)
directory, labels='other')
with self.assertRaisesRegex(ValueError, '`label_mode` argument must be'):
_ = text_dataset.text_dataset_from_directory(

View File

@@ -3246,10 +3246,12 @@ cuda_py_test(
srcs = ["extract_volume_patches_grad_test.py"],
shard_count = 50,
tags = [
"no_gpu",  # b/171837334
"no_oss",  # Test times out on oss-nightly cpu builds
"no_pip",
"nogpu",  # http://b/171837334
"nomac",  # http://b/139946976
"notap",  # http://b/31080670
"nogpu",  # b/171837334
"nomac",  # b/139946976
"notap",  # b/31080670
],
deps = [
"//tensorflow/python:array_ops",

View File

@@ -1557,6 +1557,14 @@ class AssertTypeTest(test.TestCase):
with self.assertRaisesRegexp(TypeError, "must be of type.*float32"):
check_ops.assert_type(sparse_float16, dtypes.float32)
def test_raise_when_tf_type_is_not_dtype(self):
# Test case for GitHub issue:
# https://github.com/tensorflow/tensorflow/issues/45975
value = constant_op.constant(0.0)
with self.assertRaisesRegexp(TypeError,
"Cannot convert.*to a TensorFlow DType"):
check_ops.assert_type(value, (dtypes.float32,))
class AssertShapesTest(test.TestCase):
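The new regression test above pins down the error raised when assert_type is handed something that is not a single DType (a tuple of dtypes, per the linked issue). A small reproduction sketch through the public API; behavior assumed to match what the test asserts:

import tensorflow as tf

value = tf.constant(0.0)
try:
  tf.debugging.assert_type(value, (tf.float32,))  # a tuple is not a valid tf_type
except TypeError as e:
  print(e)  # "Cannot convert ... to a TensorFlow DType"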

Some files were not shown because too many files have changed in this diff.