Merge branch 'master'
Merge branch 'master' of https://github.com/tensorflow/tensorflow into feature-micro-add-op-depth-to-space-pr1
commit c20ac67cb1

RELEASE.md: 139 lines changed
@@ -114,6 +114,143 @@ This release contains contributions from many people at Google, as well as:

 <INSERT>, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>

+# Release 2.3.2
+
+## Bug Fixes and Other Changes
+
+* Fixes an access to uninitialized memory in Eigen code
+  ([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266))
+* Fixes a security vulnerability caused by lack of validation in
+  `tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap`
+  ([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
+* Fixes a vulnerability caused by attempting to write to immutable memory region in
+  `tf.raw_ops.ImmutableConst`
+  ([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
+* Fixes a `CHECK`-fail in LSTM with zero-length input
+  ([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270))
+* Fixes a security vulnerability caused by accessing heap data outside of bounds
+  when loading a specially crafted `SavedModel`
+  ([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271))
+* Solves an OOM issue on TPUs when XLA contexts use fused average updates
+* Updates `libjpeg-turbo` to `2.0.5` to handle
+  [CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790).
+* Updates `junit` to `4.13.1` to handle
+  [CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250).
+* Updates `PCRE` to `8.44` to handle
+  [CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838)
+  and
+  [CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155).
+* Updates `sqlite3` to `3.44.0` to keep in sync with master branch.
+
+# Release 2.2.2
+
+## Bug Fixes and Other Changes
+
+* Fixes an access to uninitialized memory in Eigen code
+  ([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266))
+* Fixes a security vulnerability caused by lack of validation in
+  `tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap`
+  ([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
+* Fixes a vulnerability caused by attempting to write to immutable memory region in
+  `tf.raw_ops.ImmutableConst`
+  ([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
+* Fixes a `CHECK`-fail in LSTM with zero-length input
+  ([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270))
+* Fixes a security vulnerability caused by accessing heap data outside of bounds
+  when loading a specially crafted `SavedModel`
+  ([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271))
+* Prevents memory leaks in loading `SavedModel`s that import functions
+* Updates `libjpeg-turbo` to `2.0.5` to handle
+  [CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790).
+* Updates `junit` to `4.13.1` to handle
+  [CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250).
+* Updates `PCRE` to `8.44` to handle
+  [CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838)
+  and
+  [CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155).
+* Updates `sqlite3` to `3.44.0` to keep in sync with master branch.
+
+# Release 2.1.3
+
+## Bug Fixes and Other Changes
+
+* Fixes an access to uninitialized memory in Eigen code
+  ([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266))
+* Fixes a security vulnerability caused by lack of validation in
+  `tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap`
+  ([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
+* Fixes a vulnerability caused by attempting to write to immutable memory region in
+  `tf.raw_ops.ImmutableConst`
+  ([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
+* Fixes a `CHECK`-fail in LSTM with zero-length input
+  ([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270))
+* Fixes a security vulnerability caused by accessing heap data outside of bounds
+  when loading a specially crafted `SavedModel`
+  ([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271))
+* Updates `libjpeg-turbo` to `2.0.5` to handle
+  [CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790).
+* Updates `junit` to `4.13.1` to handle
+  [CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250).
+* Updates `PCRE` to `8.44` to handle
+  [CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838)
+  and
+  [CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155).
+* Updates `sqlite3` to `3.44.0` to keep in sync with master branch.
+* Newer ROCm versions are supported on the 2.1 branch.
+
+# Release 2.0.4
+
+Note that this is the last patch release for the TensorFlow 2.0.x series.
+
+## Bug Fixes and Other Changes
+
+* Fixes an access to uninitialized memory in Eigen code
+  ([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266))
+* Fixes a security vulnerability caused by lack of validation in
+  `tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap`
+  ([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
+* Fixes a vulnerability caused by attempting to write to immutable memory region in
+  `tf.raw_ops.ImmutableConst`
+  ([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
+* Fixes a `CHECK`-fail in LSTM with zero-length input
+  ([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270))
+* Fixes a security vulnerability caused by accessing heap data outside of bounds
+  when loading a specially crafted `SavedModel`
+  ([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271))
+* Updates `libjpeg-turbo` to `2.0.5` to handle
+  [CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790).
+* Updates `junit` to `4.13.1` to handle
+  [CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250).
+* Updates `PCRE` to `8.44` to handle
+  [CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838)
+  and
+  [CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155).
+* Updates `sqlite3` to `3.44.0` to keep in sync with master branch.
+
+# Release 1.15.5
+
+Note that this is the last patch release for the TensorFlow 1.x series.
+
+## Bug Fixes and Other Changes
+
+* Fixes an access to uninitialized memory in Eigen code
+  ([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266))
+* Fixes a security vulnerability caused by lack of validation in
+  `tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap`
+  ([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267))
+* Fixes a vulnerability caused by attempting to write to immutable memory region in
+  `tf.raw_ops.ImmutableConst`
+  ([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268))
+* Fixes a `CHECK`-fail in LSTM with zero-length input
+  ([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270))
+* Fixes a security vulnerability caused by accessing heap data outside of bounds
+  when loading a specially crafted `SavedModel`
+  ([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271))
+* Updates `libjpeg-turbo` to `2.0.5` to handle
+  [CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790).
+* Updates `junit` to `4.13.1` to handle
+  [CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250).
+* Updates `PCRE` to `8.44` to handle
+  [CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838)
+  and
+  [CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155).
+* Updates `sqlite3` to `3.44.0` to keep in sync with master branch.
+
 # Release 2.4.0

 ## Major Features and Improvements
@@ -163,7 +300,7 @@ This release contains contributions from many people at Google, as well as:

 ## Breaking Changes

 * TF Core:
-  * Certain float32 ops run in lower precsion on Ampere based GPUs, including
+  * Certain float32 ops run in lower precision on Ampere based GPUs, including
     matmuls and convolutions, due to the use of [TensorFloat-32]
     (https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/).
     Specifically, inputs to such ops are rounded from 23 bits of precision to 10
@@ -1282,7 +1282,8 @@ class DynamicReshapeOpNotActuallyDynamic

 void DynamicReshapeOp::getCanonicalizationPatterns(
     OwningRewritePatternList& results, MLIRContext* context) {
   results.insert<DynamicReshapeOpNotActuallyDynamic,
-                 RemoveRedundantDynamicReshape, ShapeOfDynamicReshape>(context);
+                 RemoveRedundantDynamicBroadcast, RemoveRedundantDynamicReshape,
+                 ShapeOfDynamicReshape>(context);
 }

 //===----------------------------------------------------------------------===//
@@ -33,3 +33,10 @@ def UnaryEinsumToEinsum : Pat<

 def RemoveRedundantDynamicReshape : Pat<
   (HLO_DynamicReshapeOp (HLO_DynamicReshapeOp $operand, $shape1), $shape2),
   (HLO_DynamicReshapeOp $operand, $shape2)>;
+
+// A dynamic broadcast of a dynamic reshape with the same shape operand
+// is a dynamic reshape.
+def RemoveRedundantDynamicBroadcast : Pat<
+  (HLO_DynamicBroadcastInDimOp
+      (HLO_DynamicReshapeOp $operand, $shape), $shape, $dims),
+  (HLO_DynamicReshapeOp $operand, $shape)>;
@@ -1540,3 +1540,14 @@ func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf3
   return %1 : tensor<128xf32>
   // CHECK: return %arg0 : tensor<128xf32>
 }
+
+// CHECK-LABEL: @broadcast_of_reshape
+func @broadcast_of_reshape(%arg: tensor<?xf32>, %shape: tensor<2xindex>) -> tensor<?x?xf32> {
+  %0 = "mhlo.dynamic_reshape"(%arg, %shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+  %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) {
+    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
+  } : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK: [[RESHAPE:%.*]] = "mhlo.dynamic_reshape"
+// CHECK: return [[RESHAPE]]
@@ -133,7 +133,7 @@ class TFL_OperandsHaveSameShapesOrBroadcastableShape<
   TFL_RuntimePredOpTrait<"operands do not have the same shape or "
       "broadcastable shapes within the rank " # max_bcast_rank,
     CPred<"TFL::VerifyOperandsHaveSameShapesOrBroadcastableShape("
-          "$_op, llvm::ArrayRef<unsigned>({" # StrJoinInt<indices>.result #
+          "$_op, llvm::ArrayRef<unsigned>({" # !interleave(indices, ", ") #
           "}), " # max_bcast_rank # ")">>;

 // These additional types/type constraints here are used to decouple the ops
@@ -3453,10 +3453,10 @@ def TFL_CastOp : TFL_Op<"cast", [
   }];

   let arguments = (ins
-    TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$input
+    TFL_TensorOf<[F32, I1, I16, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$input
   );

-  let results = (outs TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$output);
+  let results = (outs TFL_TensorOf<[F32, I1, I16, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$output);

   // TFLite's cast op does not utilize CastOptions, instead derives types
   // from the TfLiteTensors.
@@ -34,7 +34,7 @@ class QuantizedType<string n, list<int> params, bit signed>
     "Q" # !if (signed, "I", "UI") # !head(params) # " type"> {
   string name = n;
   string asTraitArgsStr =
-    StrJoinInt<params>.result # !if(signed, ", true", ", false");
+    !interleave(params, ", ") # !if(signed, ", true", ", false");
 }

 // Uniform quantized types. Two integers "smantissa" and "sexp" are used to
@@ -134,7 +134,7 @@ class FixedResultScale<QuantizedType qt> : NativeOpTrait<!strconcat(
 // needs a scale based on the scales of op1 and op2.
 class AccumulatorUniformScale<int bias, int op1, int op2> : NativeOpTrait<
   !strconcat("quant::AccumulatorUniformScale<",
-             StrJoinInt<[bias, op1, op2]>.result,
+             !interleave([bias, op1, op2], ", "),
              ">::Impl")>;

 // Specify the operand index of the coefficient operand for an affine op
@@ -142,7 +142,7 @@ class AccumulatorUniformScale<int bias, int op1, int op2> : NativeOpTrait<
 // If the quantization dimension is -1, per-axis quantization isn't supported.
 class AffineOpCoefficient<int dim, int index> : NativeOpTrait<
   !strconcat("quant::AffineOpCoefficient<",
-             StrJoinInt<[dim, index]>.result,
+             !interleave([dim, index], ", "),
              ">::Impl")>;

 // Specify this trait if the op doesn't have quantizable output. We shouldn't
@@ -1320,6 +1320,22 @@ func @cast(%arg0: tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32> {
   // CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32>
 }

+func @castFloat32ToI16(%arg0: tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xi16> {
+  %0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xi16>
+  return %0 : tensor<1x2x2x5xi16>
+
+  // CHECK-LABEL: castFloat32ToI16
+  // CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xi16>
+}
+
+func @castI16ToFloat32(%arg0: tensor<1x2x2x5xi16>) -> tensor<1x2x2x5xf32> {
+  %0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xi16>) -> tensor<1x2x2x5xf32>
+  return %0 : tensor<1x2x2x5xf32>
+
+  // CHECK-LABEL: castI16ToFloat32
+  // CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi16>) -> tensor<1x2x2x5xf32>
+}
+
 func @castComplex(%arg0: tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xcomplex<f32>> {
   %0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xcomplex<f32>>
   return %0 : tensor<1x2x2x5xcomplex<f32>>
@@ -98,11 +98,13 @@ class TF_OpIsBroadcastableToRes<int opId, int resId> : And<[


 class TF_AllTypesMatchPred<list<string> values> :
-  CPred<"TF::AreCastCompatible(llvm::makeArrayRef({"# StrJoin<values>.result #"}))">;
+  CPred<"TF::AreCastCompatible(llvm::makeArrayRef({" #
+        !interleave(values, ", ") # "}))">;

 class TF_AllTypesMatch<list<string> names> :
   PredOpTrait<
-    "all of {" # StrJoin<names>.result # "} have dynamically equal types ",
+    "all of {" # !interleave(names, ", ") #
+    "} have dynamically equal types ",
     TF_AllTypesMatchPred<
       !foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;

@@ -147,6 +147,7 @@ cc_library(
         "//tensorflow/compiler/mlir/hlo:lhlo",
         "//tensorflow/compiler/mlir/hlo:lhlo_gpu",
        "//tensorflow/compiler/xla:debug_options_flags",
+        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/service:backend",
@@ -384,7 +384,7 @@ ENTRY main {
 HloModule BatchNormForwardInference

 // CHECK: func @main
-// CHECK: lmhlo_gpu.batch_norm_inference"
+// CHECK: "lmhlo_gpu.batch_norm_inference"
 // CHECK-SAME: epsilon = 1.000000e-03 : f32
 // CHECK-SAME: feature_index = 0 : i64
 // CHECK-SAME: (memref<2x2x2x2xf32>, memref<2xf32>, memref<2xf32>, memref<2xf32>, memref<2xf32>, memref<2x2x2x2xf32>) -> ()
@@ -399,4 +399,16 @@ ENTRY main {
   ROOT %custom-call = f32[2,2,2,2]{3,2,1,0}
     custom-call(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[] %constant, s64[] %constant_1),
     custom_call_target="__cudnn$batchNormalizationForwardInference"
 }
+
+// -----
+
+HloModule Infeed
+
+// CHECK: func @main
+// CHECK: "lmhlo.infeed"
+// CHECK-SAME: (memref<3xf32>) -> ()
+ENTRY main {
+  %tok = token[] parameter(0)
+  ROOT %infeed = (f32[3]{0}, token[]) infeed(token[] %tok)
+}
@@ -27,6 +27,7 @@ limitations under the License.
 #include "mlir/IR/AffineMap.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Dialect.h"  // from @llvm-project
@@ -60,6 +61,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/util.h"

@@ -67,6 +69,7 @@ using xla::BufferAllocation;
 using xla::BufferAssignment;
 using xla::HloComputation;
 using xla::HloCustomCallInstruction;
+using xla::HloInfeedInstruction;
 using xla::HloInstruction;
 using xla::HloModule;
 using xla::HloModuleProto;
@@ -199,14 +202,16 @@ class XlaHloToLhloPass
 }  // namespace

 // Creates MLIR operands corresponding to operands and results of the XLA HLO
-// instruction. If `num_operands` is not -1, then only the first `num_operands`
+// instruction. If `num_operands` is valid, then only the first `num_operands`
 // operands of the HLO instruction will be considered.
 Status LhloDialectEmitter::CreateOperands(
-    HloInstruction* instr, llvm::SmallVectorImpl<Value>& operands,
-    size_t& num_arguments, size_t& num_results,
-    absl::optional<xla::int64> num_operands) {
+    HloInstruction* instr, absl::optional<xla::int64> num_operands,
+    llvm::SmallVectorImpl<Value>& operands, size_t& num_arguments,
+    size_t& num_results) {
+  if (num_operands.value_or(0) > instr->operand_count())
+    return xla::InvalidArgument("num_operands must be <= operand count");
   for (xla::int64 i = 0; i < num_operands.value_or(instr->operand_count());
-       i++) {
+       ++i) {
     TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(i), &operands));
   }
   num_arguments = operands.size();
@@ -215,19 +220,23 @@ Status LhloDialectEmitter::CreateOperands(
   return Status::OK();
 }

+template <typename OpType>
+OpType LhloDialectEmitter::CreateOpWithoutAttrs(HloInstruction* instr,
+                                                ValueRange operands) {
+  Location loc = getLocation(instr);
+  NamedAttribute attrs[] = {{Identifier::get("name", builder_.getContext()),
+                             builder_.getStringAttr(instr->name())}};
+  return builder_.create<OpType>(loc, llvm::None, operands, attrs);
+}
+
 template <typename OpType>
 StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(
     HloInstruction* instr, size_t& num_arguments, size_t& num_results,
     absl::optional<xla::int64> num_operands) {
-  Location loc = getLocation(instr);
-  std::pair<Identifier, Attribute> attrs[] = {
-      {Identifier::get("name", builder_.getContext()),
-       builder_.getStringAttr(instr->name())},
-  };
   llvm::SmallVector<Value, 4> operands;
-  TF_RETURN_IF_ERROR(CreateOperands(instr, operands, num_arguments, num_results,
-                                    num_operands));
-  return builder_.create<OpType>(loc, llvm::None, operands, attrs);
+  TF_RETURN_IF_ERROR(CreateOperands(instr, num_operands, operands,
+                                    num_arguments, num_results));
+  return CreateOpWithoutAttrs<OpType>(instr, operands);
 }

 StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
@@ -273,6 +282,8 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
       return CreateOpWithoutAttrs<lmhlo::FloorOp>(instr);
     case HloOpcode::kImag:
       return CreateOpWithoutAttrs<lmhlo::ImagOp>(instr);
+    case HloOpcode::kInfeed:
+      return EmitInfeedOp(instr);
     case HloOpcode::kIsFinite:
       return CreateOpWithoutAttrs<lmhlo::IsFiniteOp>(instr);
     case HloOpcode::kLog:
@@ -387,7 +398,7 @@ StatusOr<Value> LhloDialectEmitter::RewriteFusionOperand(
     ::xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) {
   if (shape.IsTuple()) {
     llvm::SmallVector<Value, 4> values;
-    for (int i = 0; i < shape.tuple_shapes_size(); i++) {
+    for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
       shape_index->push_back(i);
       TF_ASSIGN_OR_RETURN(
           auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index,
@@ -423,7 +434,7 @@ StatusOr<lmhlo::FusionOp> LhloDialectEmitter::EmitFusionOp(
   auto region_builder = OpBuilder::atBlockBegin(&fusion.region().front());

   llvm::SmallVector<Value, 8> arguments;
-  for (int i = 0; i < instr->operands().size(); i++) {
+  for (int i = 0; i < instr->operands().size(); ++i) {
     const HloInstruction* operand = instr->operand(i);
     xla::ShapeIndex shape_index;
     TF_ASSIGN_OR_RETURN(
@@ -982,6 +993,19 @@ StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
   return all_reduce_op;
 }

+StatusOr<lmhlo::InfeedOp> LhloDialectEmitter::EmitInfeedOp(
+    HloInstruction* instr) {
+  HloInfeedInstruction* infeed = ::xla::Cast<HloInfeedInstruction>(instr);
+  // HLO Infeed instruction has a single operand of token type and a tuple
+  // with buffers and a token as its output. LMHLO Infeed operation does not
+  // need the token operand or result, so drop it.
+  SmallVector<Value, 2> operands;
+  TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands, /*result_subset=*/{0}));
+  auto infeed_op = CreateOpWithoutAttrs<lmhlo::InfeedOp>(instr, operands);
+  infeed_op.configAttr(builder_.getStringAttr(infeed->infeed_config()));
+  return infeed_op;
+}
+
 StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
     const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
     const ::xla::ShapeIndex& shape_index) {
@@ -1055,7 +1079,7 @@ Status LhloDialectEmitter::GetOrCreateViewImpl(
     const HloInstruction* instr, const Shape& current_shape,
     ::xla::ShapeIndex* current_shape_index, SmallVectorImpl<Value>* values) {
   if (current_shape.IsTuple()) {
-    for (int i = 0; i < current_shape.tuple_shapes().size(); i++) {
+    for (int i = 0; i < current_shape.tuple_shapes().size(); ++i) {
       current_shape_index->push_back(i);
       TF_RETURN_IF_ERROR(GetOrCreateViewImpl(
           instr, current_shape.tuple_shapes(i), current_shape_index, values));
@@ -1063,19 +1087,26 @@ Status LhloDialectEmitter::GetOrCreateViewImpl(
     }
     return Status::OK();
   }
-  TF_ASSIGN_OR_RETURN(
-      auto v, GetOrCreateArrayView(instr, current_shape, *current_shape_index));
-  values->push_back(v);
-  return Status::OK();
+  if (current_shape.IsArray()) {
+    TF_ASSIGN_OR_RETURN(auto v, GetOrCreateArrayView(instr, current_shape,
+                                                     *current_shape_index));
+    values->push_back(v);
+    return Status::OK();
+  }
+  return xla::InternalError("Unexpected shape kind for %s and shape index %s",
+                            instr->ToString(), current_shape_index->ToString());
 }

 // Returns a view for the result of an instruction.
 // We first get a view for the slice in the allocation, and then may need to
 // create another view to adjust the slice for the shape of the instruction.
-Status LhloDialectEmitter::GetOrCreateView(const HloInstruction* instr,
-                                           SmallVectorImpl<Value>* values) {
-  ::xla::ShapeIndex shape_index;
-  return GetOrCreateViewImpl(instr, instr->shape(), &shape_index, values);
+Status LhloDialectEmitter::GetOrCreateView(
+    const HloInstruction* instr, SmallVectorImpl<Value>* values,
+    const xla::ShapeIndex& result_subset) {
+  ::xla::ShapeIndex shape_index = result_subset;
+  const Shape& sub_shape =
+      ::xla::ShapeUtil::GetSubshape(instr->shape(), shape_index);
+  return GetOrCreateViewImpl(instr, sub_shape, &shape_index, values);
 }

 Status LhloDialectEmitter::Initialize() {
@@ -27,6 +27,8 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"

 namespace mlir {

@@ -79,6 +81,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {

   ::xla::StatusOr<lmhlo::CompareOp> EmitCompareOp(::xla::HloInstruction* instr);

+  ::xla::StatusOr<lmhlo::InfeedOp> EmitInfeedOp(::xla::HloInstruction* instr);
   ::xla::StatusOr<lmhlo::MapOp> EmitMapOp(::xla::HloInstruction* instr);

   ::xla::StatusOr<lmhlo::ReducePrecisionOp> EmitReducePrecisionOp(
@@ -87,10 +90,16 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
   ::xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
       ::xla::HloInstruction* instr);

-  ::xla::Status CreateOperands(
-      ::xla::HloInstruction* instr, SmallVectorImpl<Value>& operands,
-      size_t& num_arguments, size_t& num_results,
-      absl::optional<xla::int64> num_operands = absl::nullopt);
+  // Create LHLO operation operands given an XLA HLO instruction. By default,
+  // all XLA HLO operands and results are converted to MLIR and appended to
+  // `operands`. If `num_operands` is specified, only the first `num_operands`
+  // operands of the instruction are converted to MLIR. The function returns the
+  // actual number of operands and results generated for MLIR in `num_arguments`
+  // and `num_results`.
+  ::xla::Status CreateOperands(::xla::HloInstruction* instr,
+                               absl::optional<xla::int64> num_operands,
+                               SmallVectorImpl<Value>& operands,
+                               size_t& num_arguments, size_t& num_results);

   template <typename OpType>
   ::xla::StatusOr<OpType> CreateOpWithoutAttrs(
@@ -105,6 +114,10 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
       ::xla::HloInstruction* instr, size_t& num_arguments, size_t& num_results,
       absl::optional<xla::int64> num_operands = absl::nullopt);

+  template <typename OpType>
+  OpType CreateOpWithoutAttrs(::xla::HloInstruction* instr,
+                              ValueRange operands);
+
   template <typename T>
   DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) {
     return builder_.getI64TensorAttr(
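For orientation, here is a minimal sketch of how the reordered `CreateOperands` signature is meant to be called. It is not part of the commit; the surrounding emitter method and the `instr` pointer are assumed. `num_operands` limits how many HLO operands are converted, and the converted MLIR values are appended to `operands`.

  // Hypothetical call site inside an LhloDialectEmitter emit method (sketch).
  // Convert only the first HLO operand of `instr`; the number of converted
  // arguments and results is reported back through the size_t out-parameters.
  llvm::SmallVector<mlir::Value, 4> operands;
  size_t num_arguments = 0;
  size_t num_results = 0;
  TF_RETURN_IF_ERROR(CreateOperands(instr, /*num_operands=*/1, operands,
                                    num_arguments, num_results));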
@@ -140,9 +153,14 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
       SmallVectorImpl<Value>* values);

   // Helper function to create view/tuple of views to a buffer for a given
-  // instruction result.
+  // instruction result. `result_subset` can be used for instructions that
+  // have a tuple result and MLIR conversion needs to convert only one of the
+  // tuple elements. Note that if needed, this can be extended to take a list of
+  // ShapeIndex values in case we need finer control on what elements of the
+  // output tuple to be converted to MLIR.
   tensorflow::Status GetOrCreateView(const ::xla::HloInstruction* instr,
-                                     SmallVectorImpl<Value>* values);
+                                     SmallVectorImpl<Value>* values,
+                                     const xla::ShapeIndex& result_subset = {});

   ::xla::StatusOr<Value> GetOrCreateArrayView(
       const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
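Similarly, a hedged illustration (again not from the commit; `infeed` here is an assumed `HloInstruction*`) of the new `result_subset` parameter: it lets a caller build views for just one element of a tuple-shaped result, which is how the infeed emitter above drops the trailing token element.

  // Hypothetical usage: materialize views only for tuple element {0} of a
  // tuple-shaped result (e.g. the infeed data buffers, skipping the token),
  // appending them to `slices`.
  llvm::SmallVector<mlir::Value, 2> slices;
  TF_RETURN_IF_ERROR(
      GetOrCreateView(infeed, &slices, /*result_subset=*/{0}));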
@@ -51,10 +51,10 @@
         "  \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/xla/tutorials/compile\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
         "  \u003c/td\u003e\n",
         "  \u003ctd\u003e\n",
-        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
+        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/jit_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
         "  \u003c/td\u003e\n",
         "  \u003ctd\u003e\n",
-        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
+        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/jit_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
         "  \u003c/td\u003e\n",
         "\u003c/table\u003e"
       ]
|
@ -89,7 +89,7 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<PjRtBuffer>> PyClient::PjRtBufferFromPyval(
|
StatusOr<std::unique_ptr<PjRtBuffer>> PyClient::PjRtBufferFromPyval(
|
||||||
const pybind11::object& argument, PjRtDevice* device, bool force_copy,
|
pybind11::handle argument, PjRtDevice* device, bool force_copy,
|
||||||
PjRtClient::HostBufferSemantics host_buffer_semantics) {
|
PjRtClient::HostBufferSemantics host_buffer_semantics) {
|
||||||
if (device == nullptr) {
|
if (device == nullptr) {
|
||||||
TF_RET_CHECK(!pjrt_client_->local_devices().empty());
|
TF_RET_CHECK(!pjrt_client_->local_devices().empty());
|
||||||
@ -123,7 +123,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PyClient::PjRtBufferFromPyval(
|
|||||||
return buffer;
|
return buffer;
|
||||||
}
|
}
|
||||||
StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
|
StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
|
||||||
const pybind11::object& argument, PjRtDevice* device, bool force_copy,
|
pybind11::handle argument, PjRtDevice* device, bool force_copy,
|
||||||
PjRtClient::HostBufferSemantics host_buffer_semantics) {
|
PjRtClient::HostBufferSemantics host_buffer_semantics) {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::unique_ptr<PjRtBuffer> buffer,
|
std::unique_ptr<PjRtBuffer> buffer,
|
||||||
|
@ -124,10 +124,10 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBufferFromPyval(
|
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBufferFromPyval(
|
||||||
const pybind11::object& argument, PjRtDevice* device, bool force_copy,
|
pybind11::handle argument, PjRtDevice* device, bool force_copy,
|
||||||
PjRtClient::HostBufferSemantics host_buffer_semantics);
|
PjRtClient::HostBufferSemantics host_buffer_semantics);
|
||||||
StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval(
|
StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval(
|
||||||
const pybind11::object& argument, PjRtDevice* device, bool force_copy,
|
pybind11::handle argument, PjRtDevice* device, bool force_copy,
|
||||||
PjRtClient::HostBufferSemantics host_buffer_semantics);
|
PjRtClient::HostBufferSemantics host_buffer_semantics);
|
||||||
|
|
||||||
StatusOr<std::shared_ptr<PyExecutable>> Compile(
|
StatusOr<std::shared_ptr<PyExecutable>> Compile(
|
||||||
|
@@ -71,6 +71,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
+#include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
 #include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
@@ -3168,7 +3169,22 @@ Status IrEmitterUnnested::HandleAllToAll(HloInstruction* hlo) {
 }

 Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) {
-  return ThunkEmitter(this).HandleInfeed(xla_infeed);
+  TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(xla_infeed));
+
+  auto infeed_op = mlir::dyn_cast<mlir::lmhlo::InfeedOp>(input.op);
+
+  std::vector<InfeedThunk::ShapedSlice> dest_slices;
+  dest_slices.reserve(infeed_op.outputs().size());
+
+  for (mlir::Value output : infeed_op.outputs()) {
+    TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(output));
+    const Shape& shape = TypeToShape(output.getType());
+    dest_slices.push_back(InfeedThunk::ShapedSlice{slice, shape});
+  }
+
+  AddThunkToThunkSequence(
+      absl::make_unique<InfeedThunk>(input.thunk_info, std::move(dest_slices)));
+  return Status::OK();
 }

 Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) {
@@ -800,8 +800,16 @@ Status AMDGPUTargetModuleLinker(llvm::Module* module, GpuVersion gpu_version,
 std::unique_ptr<llvm::TargetMachine> AMDGPUGetTargetMachine(
     llvm::Triple target_triple, int amdgpu_version,
     const HloModuleConfig& hlo_module_config) {
+  string feature_str = "+code-object-v3";
+#if TF_ROCM_VERSION >= 30900
+  // code-object-v3 is the default, so there is no need to explicitly specify
+  // it in the feature string. Also, starting with ROCm 4.0, this feature
+  // string is deprecated, and we get a warning to that effect, so remove it.
+  feature_str = "";
+#endif
   return GetTargetMachine(target_triple, absl::StrCat("gfx", amdgpu_version),
-                          hlo_module_config, "+code-object-v3");
+                          hlo_module_config, feature_str);
 }

 void AMDGPUBackendInit(const HloModuleConfig& hlo_module_config) {
@@ -115,34 +115,6 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
       /*implements_whole_instruction=*/true);
 }

-std::unique_ptr<Thunk> ThunkEmitter::BuildInfeedThunk(
-    const HloInstruction* inst) {
-  CHECK_EQ(HloOpcode::kInfeed, inst->opcode());
-
-  std::vector<ShapeUtil::IndexedShape> leaf_shapes =
-      ShapeUtil::GetLeafShapes(inst->shape());
-
-  // For an infeed HLO, the output is a 2 element tuple where the first element
-  // of the tuple is all the infeed buffers and the second element is a token.
-  // The infeed thunk does not need to handle this token output, so just drop
-  // it.
-  leaf_shapes.pop_back();
-
-  std::vector<InfeedThunk::ShapedSlice> dest_slices;
-  dest_slices.reserve(leaf_shapes.size());
-
-  for (ShapeUtil::IndexedShape& indexed_shape : leaf_shapes) {
-    BufferAllocation::Slice slice =
-        GetAllocationSlice(*inst, indexed_shape.index);
-    const Shape& shape =
-        ShapeUtil::GetSubshape(inst->shape(), indexed_shape.index);
-    dest_slices.emplace_back(InfeedThunk::ShapedSlice{slice, shape});
-  }
-
-  return absl::make_unique<InfeedThunk>(context_->GetThunkInfo(inst),
-                                        std::move(dest_slices));
-}
-
 std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
     const HloInstruction* inst) {
   CHECK_EQ(HloOpcode::kOutfeed, inst->opcode());
@@ -258,11 +230,6 @@ Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) {
   return Status::OK();
 }

-Status ThunkEmitter::HandleInfeed(HloInstruction* infeed) {
-  AddThunkToThunkSequence(BuildInfeedThunk(infeed));
-  return Status::OK();
-}
-
 Status ThunkEmitter::HandleOutfeed(HloInstruction* outfeed) {
   AddThunkToThunkSequence(BuildOutfeedThunk(outfeed));
   return Status::OK();
@@ -46,7 +46,6 @@ class ThunkEmitter {
   Status HandleCustomCall(HloInstruction* custom_call);
   Status HandleFft(HloInstruction* fft);
   Status HandleTriangularSolve(HloInstruction* hlo);
-  Status HandleInfeed(HloInstruction* xla_infeed);
   Status HandleOutfeed(HloInstruction* outfeed);

  private:
@@ -367,7 +367,6 @@ xla_test(
         "conv_depthwise_test.cc",
     ],
     shard_count = 50,
-    tags = ["no_rocm"],  # ROCm 3.9 regression
     deps = [
         ":conv_depthwise_common",
         ":test_macros_header",
@@ -389,7 +388,6 @@ xla_test(
     timeout = "long",
     srcs = ["conv_depthwise_backprop_filter_test.cc"],
     shard_count = 40,
-    tags = ["no_rocm"],  # ROCm 3.9 regression
     deps = [
         ":test_macros_header",
         "//tensorflow/compiler/xla:execution_options_util",
@@ -414,7 +412,6 @@ xla_test(
         "cpu",
     ],
     shard_count = 50,
-    tags = ["no_rocm"],  # ROCm 3.9 regression
     deps = [
         ":client_library_test_base",
         ":hlo_test_base",
@@ -924,7 +921,6 @@ xla_test(
     srcs = ["dot_operation_test.cc"],
     shard_count = 20,
     tags = [
-        "no_rocm",  # ROCm 3.9 regression
         "optonly",
     ],
     deps = [
@@ -958,7 +954,6 @@ xla_test(
     backends = ["gpu"],
     shard_count = 20,
     tags = [
-        "no_rocm",  # ROCm 3.9 regression
         "optonly",
         # TODO(b/151340488): Timed out on 2020-03-12.
         "nozapfhahn",
@@ -1025,7 +1020,6 @@ xla_test(
     },
     shard_count = 20,
     tags = [
-        "no_rocm",  # ROCm 3.9 regression
         "optonly",
     ],
     deps = [
@@ -1253,7 +1247,6 @@ xla_test(
         "cpu": ["nomsan"],
     },
     shard_count = 30,
-    tags = ["no_rocm"],  # ROCm 3.9 regression
     deps = [
         ":test_macros_header",
         "//tensorflow/compiler/xla:array3d",
@@ -1278,7 +1271,6 @@ xla_test(
     timeout = "long",
     srcs = ["convolution_dimension_numbers_test.cc"],
    shard_count = 20,
-    tags = ["no_rocm"],  # ROCm 3.9 regression
     deps = [
         ":test_macros_header",
         "//tensorflow/compiler/xla:array4d",
@@ -2322,7 +2314,6 @@ xla_test(
     name = "multioutput_fusion_test",
     srcs = ["multioutput_fusion_test.cc"],
     backends = ["gpu"],
-    tags = ["no_rocm"],  # ROCm 3.9 regression
     deps = [
         ":test_macros_header",
         "//tensorflow/compiler/xla:literal",
|
@ -480,13 +480,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
|||||||
{csinfo_.fused_batch_norm_grad_v3,
|
{csinfo_.fused_batch_norm_grad_v3,
|
||||||
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
|
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
|
||||||
CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
|
CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
|
||||||
rinfo_.push_back({csinfo_.fused_batch_norm_ex,
|
rinfo_.push_back({csinfo_.fused_batch_norm_ex,
|
||||||
native_fmt ? csinfo_.mkl_native_fused_batch_norm_ex
|
native_fmt ? csinfo_.mkl_native_fused_batch_norm_ex
|
||||||
: csinfo_.mkl_fused_batch_norm_ex,
|
: csinfo_.mkl_fused_batch_norm_ex,
|
||||||
CopyAttrsAll, FusedBatchNormExRewrite,
|
CopyAttrsAll, FusedBatchNormExRewrite,
|
||||||
GetRewriteCause()});
|
GetRewriteCause()});
|
||||||
#endif
|
|
||||||
rinfo_.push_back({csinfo_.fused_conv2d,
|
rinfo_.push_back({csinfo_.fused_conv2d,
|
||||||
native_fmt ? csinfo_.mkl_native_fused_conv2d
|
native_fmt ? csinfo_.mkl_native_fused_conv2d
|
||||||
: csinfo_.mkl_fused_conv2d,
|
: csinfo_.mkl_fused_conv2d,
|
||||||
@ -672,14 +670,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
|||||||
rinfo_.push_back({csinfo_.requantize,
|
rinfo_.push_back({csinfo_.requantize,
|
||||||
mkl_op_registry::GetMklOpName(csinfo_.requantize),
|
mkl_op_registry::GetMklOpName(csinfo_.requantize),
|
||||||
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
|
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
|
||||||
// Optimized TanhGrad support exists only in DNNL 1.x.
|
// Optimized TanhGrad support exists only in DNNL 1.x.
|
||||||
rinfo_.push_back({csinfo_.tanh, mkl_op_registry::GetMklOpName(csinfo_.tanh),
|
rinfo_.push_back({csinfo_.tanh, mkl_op_registry::GetMklOpName(csinfo_.tanh),
|
||||||
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
|
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
|
||||||
rinfo_.push_back({csinfo_.tanh_grad,
|
rinfo_.push_back({csinfo_.tanh_grad,
|
||||||
mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
|
mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
|
||||||
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
|
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
|
||||||
#endif // ENABLE_MKLDNN_V1
|
|
||||||
rinfo_.push_back({csinfo_.reshape,
|
rinfo_.push_back({csinfo_.reshape,
|
||||||
mkl_op_registry::GetMklOpName(csinfo_.reshape),
|
mkl_op_registry::GetMklOpName(csinfo_.reshape),
|
||||||
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
|
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
|
||||||
|
@@ -53,7 +53,6 @@ static void InitGraph(const string& s, Graph* graph,
   GraphDef graph_def;

   auto parser = protobuf::TextFormat::Parser();
-  // parser.AllowRelaxedWhitespace(true);
   CHECK(parser.MergeFromString(s, &graph_def)) << s;
   GraphConstructorOptions opts;
   TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
@@ -66,7 +65,6 @@ static void InitGraph(const string& s, Graph* graph,
 class MklLayoutPassTest : public ::testing::Test {
  public:
   MklLayoutPassTest() : graph_(OpRegistry::Global()) {}
-  // Ashraf added
   Node* FindNode(const string& name) {
     for (Node* node : graph_.nodes()) {
       if (node->name() == name) return node;
@@ -3087,8 +3085,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_LeakyReluGrad_Negative);
 REGISTER_TEST_ALL_TYPES(NodeRewrite_LeakyReluLeakyReluGrad_Positive);
 #undef REGISTER_TEST

-#ifdef ENABLE_MKLDNN_V1
-
 #define REGISTER_TEST(NAME, T, INPUT)                                         \
   TEST_F(MklLayoutPassTest, NAME##_##T) {                                     \
     DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);      \
@@ -3146,7 +3142,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_TanhGrad_Positive);
 }
 REGISTER_TEST_ALL_TYPES(NodeRewrite_TanhTanhGrad_Positive);
 #undef REGISTER_TEST
-#endif  // ENABLE_MKLDNN_V1

 #define REGISTER_TEST(NAME, T, INPUT)                                    \
   TEST_F(MklLayoutPassTest, NAME##_##T) {                                \
@@ -3513,7 +3508,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormGradV3_5D_Negative_2);
 #undef DATA_FORMAT
 #undef REGISTER_TEST

-#ifdef ENABLE_MKLDNN_V1
 #define REGISTER_TEST(NAME, T, INPUT)                                      \
   TEST_F(MklLayoutPassTest, NAME##_##T) {                                  \
     InitGraph("node { name: 'A' op: '" #INPUT "'}"                         \
@@ -3603,7 +3597,6 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative1);
 }
 REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative2);
 #undef REGISTER_TEST
-#endif  // ENABLE_MKLDNN_V1

 TEST_F(MklLayoutPassTest, NodeRewrite_QuantizedDepthwiseConv2D_Positive) {
   InitGraph(
@@ -5184,8 +5177,8 @@ static void BM_MklLayoutRewritePass(int iters, int op_nodes) {

   bool first = true;
   while (iters > 0) {
-    Graph* graph = new Graph(OpRegistry::Global());
-    InitGraph(s, graph);
+    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+    InitGraph(s, graph.get());
     int N = graph->num_node_ids();
     if (first) {
       testing::SetLabel(strings::StrCat("Per graph node. Nodes: ", N));
@@ -5193,13 +5186,12 @@ static void BM_MklLayoutRewritePass(int iters, int op_nodes) {
     }
     {
       testing::StartTiming();
-      std::unique_ptr<Graph> ug(graph);
+      std::unique_ptr<Graph> ug(graph.get());
       RunMklLayoutRewritePass(&ug);
       testing::StopTiming();
     }
     iters -= N;  // Our benchmark units are individual graph nodes,
                  // not whole graphs
-    // delete graph;
   }
 }
 BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);
|
@@ -37,6 +37,7 @@ load(
 package(
 default_visibility = [
 "//tensorflow/core:__subpackages__",
+"//tensorflow/security/fuzzing:__subpackages__",
 ],
 licenses = ["notice"], # Apache 2.0
 )
@@ -622,7 +623,10 @@ cc_library(
 name = "bfloat16",
 srcs = ["bfloat16.cc"],
 hdrs = ["bfloat16.h"],
-visibility = ["//tensorflow/core:__subpackages__"],
+visibility = [
+"//tensorflow/core:__subpackages__",
+"//tensorflow/security/fuzzing:__subpackages__",
+],
 deps = [
 ":numeric_types",
 "//tensorflow/core/platform:byte_order",
@@ -269,6 +269,9 @@ Status MetaOptimizer::InitializeOptimizers(
 if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
 optimizers->push_back(MakeUnique<PinToHostOptimizer>());
 }
+if (cfg_.remapping() != RewriterConfig::OFF) {
+optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
+}
 if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
 optimizers->push_back(
 MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
@@ -278,9 +281,6 @@ Status MetaOptimizer::InitializeOptimizers(
 /*optimization level*/ cfg_.layout_optimizer(),
 /*CPU layout conversion*/ cfg_.cpu_layout_conversion()));
 }
-if (cfg_.remapping() != RewriterConfig::OFF) {
-optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
-}
 if (cfg_.loop_optimization() != RewriterConfig::OFF) {
 optimizers->push_back(
 MakeUnique<LoopOptimizer>(cfg_.loop_optimization(), cpu_device_));
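The hunk above moves the Remapper registration ahead of the arithmetic and layout optimizers, so op remapping now runs earlier in the meta-optimizer pipeline. As a rough illustration of the knob being reordered here, remapping can still be toggled through the RewriterConfig proto; the helper below is a minimal sketch assuming the standard generated proto setters, not code from this change:

    #include "tensorflow/core/protobuf/rewriter_config.pb.h"

    // Keep remapping on (the default) but turn the arithmetic optimizer off,
    // which isolates the effect of the Remapper pass that was moved earlier above.
    tensorflow::RewriterConfig MakeRemapOnlyConfig() {
      tensorflow::RewriterConfig cfg;
      cfg.set_remapping(tensorflow::RewriterConfig::ON);
      cfg.set_arithmetic_optimization(tensorflow::RewriterConfig::OFF);
      return cfg;
    }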
@@ -29,8 +29,14 @@ REGISTER5(BinaryOp, CPU, "DivNoNan", functor::div_no_nan, Eigen::half, float,

 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // ROCM TODO: re-enable complex64 / complex128 after compiler fix
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
 uint16, int16, int64, complex64, complex128);
+#else
+REGISTER4(BinaryOp, GPU, "Div", functor::div, uint8, uint16, complex64,
+complex128);
+#endif
 REGISTER4(BinaryOp, GPU, "TruncateDiv", functor::div, uint8, uint16, int16,
 int64);
 REGISTER5(BinaryOp, GPU, "RealDiv", functor::div, float, Eigen::half, double,
@@ -30,8 +30,13 @@ REGISTER(BinaryOp, CPU, "Sub", functor::sub, int32);
 #endif // __ANDROID_TYPES_SLIM__

 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER7(BinaryOp, GPU, "Sub", functor::sub, float, Eigen::half, double, int64,
 complex64, complex128, uint32);
+#else
+REGISTER3(BinaryOp, GPU, "Sub", functor::sub, complex64, complex128, uint32);
+#endif

 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -85,7 +85,7 @@ void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
 // clang-format off
 absl::flat_hash_map<string, uint64> live_experiments = {
 {"enable_gradient_descent", 0},
-{"map_parallelization", 0}
+{"map_parallelization", 1}
 };
 // clang-format on
 auto hash_func = [](const string& str) { return Hash64(str); };
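The functional change in this hunk is flipping the `map_parallelization` entry in the live-experiments table from 0 to 1, i.e. widening that experiment's rollout. A minimal sketch of the kind of hash-bucket rollout check such a table typically feeds; the hash and helper below are illustrative stand-ins, not the implementation in this file:

    #include <cstdint>
    #include <functional>
    #include <string>

    // Returns true if `job_name` falls into the experiment's rollout bucket.
    // `rollout_pct` plays the role of the value stored in `live_experiments`.
    bool InRollout(const std::string& experiment, const std::string& job_name,
                   uint64_t rollout_pct) {
      // Stand-in for a stable 64-bit string hash such as Hash64.
      uint64_t h = std::hash<std::string>{}(experiment + "_" + job_name);
      return (h % 100) < rollout_pct;
    }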
@@ -43,7 +43,6 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/util/mkl_types.h"
 #include "tensorflow/core/util/mkl_util.h"
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"
@@ -65,7 +64,7 @@ struct MklConvFwdParams {
 memory::dims dilations;
 memory::dims padding_left;
 memory::dims padding_right;
-MKL_TENSOR_FORMAT tf_fmt;
+MklTensorFormat tf_fmt;
 bool native_format;
 string dtypes = string("");
 struct PostOpParam {
@@ -80,7 +79,7 @@ struct MklConvFwdParams {
 memory::dims bias_dims, memory::dims dst_dims,
 memory::dims strides, memory::dims dilations,
 memory::dims padding_left, memory::dims padding_right,
-MKL_TENSOR_FORMAT tf_fmt, bool native_format)
+MklTensorFormat tf_fmt, bool native_format)
 : src_dims(src_dims),
 filter_dims(filter_dims),
 bias_dims(bias_dims),
@@ -99,7 +98,7 @@ template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput>
 class MklConvFwdPrimitive : public MklPrimitive {
 public:
 explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
-: MklPrimitive(engine(ENGINE_CPU, 0)) {
+: MklPrimitive(engine(engine::kind::cpu, 0)) {
 // Create convolution primitive
 if (context_.conv_fwd == nullptr) {
 Setup(convFwdDims);
@@ -115,8 +114,8 @@ class MklConvFwdPrimitive : public MklPrimitive {
 void Execute(const Tinput* src_data, const Tfilter* filter_data,
 const Tbias* bias_data, const Toutput* dst_data,
 std::shared_ptr<stream> fwd_stream) {
-// TODO: Create a common function and avoid the duplicate code
 #ifdef ENABLE_MKLDNN_THREADPOOL
+// TODO: Create a common function and avoid the duplicate code
 context_.src_mem->set_data_handle(
 static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
 context_.filter_mem->set_data_handle(
@@ -139,16 +138,13 @@ class MklConvFwdPrimitive : public MklPrimitive {
 context_.dst_mem->set_data_handle(
 static_cast<void*>(const_cast<Toutput*>(dst_data)));
 #endif // ENABLE_MKLDNN_THREADPOOL
-#ifdef ENABLE_MKLDNN_V1
 DCHECK_EQ(context_.fwd_primitives.size(),
 context_.fwd_primitives_args.size());
 for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
 context_.fwd_primitives.at(i).execute(*fwd_stream,
 context_.fwd_primitives_args.at(i));
 }
-#else
-fwd_stream->submit(context_.fwd_primitives);
-#endif // ENABLE_MKLDNN_V1

 // After execution, set data handle back
 context_.src_mem->set_data_handle(DummyData);
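In DNNL/oneDNN 1.x a primitive no longer carries its memory objects; each stored primitive is paired with an argument map and executed on a stream, which is why the context keeps `fwd_primitives` and `fwd_primitives_args` side by side. A minimal standalone sketch of that execution pattern, assuming the oneDNN 1.x headers this file uses (mkldnn.hpp / namespace mkldnn); the names and the 1x1x1x1 size are illustrative only:

    #include "mkldnn.hpp"  // oneDNN 1.x (a.k.a. MKL-DNN / DNNL)
    #include <unordered_map>
    #include <vector>

    using namespace mkldnn;

    void RunEltwiseRelu() {
      engine eng(engine::kind::cpu, 0);
      stream strm(eng);

      // One f32 value in NCHW layout, just to have something to execute.
      memory::desc md({1, 1, 1, 1}, memory::data_type::f32, memory::format_tag::nchw);
      memory src(md, eng), dst(md, eng);
      *static_cast<float*>(src.get_data_handle()) = -1.0f;

      eltwise_forward::primitive_desc pd(
          {prop_kind::forward_inference, algorithm::eltwise_relu, md, 0.0f}, eng);

      // The v1.x pattern mirrored by fwd_primitives / fwd_primitives_args above:
      std::vector<primitive> prims{eltwise_forward(pd)};
      std::vector<std::unordered_map<int, memory>> args{
          {{MKLDNN_ARG_SRC, src}, {MKLDNN_ARG_DST, dst}}};
      for (size_t i = 0; i < prims.size(); ++i) prims[i].execute(strm, args[i]);
      strm.wait();
    }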
@@ -168,13 +164,6 @@ class MklConvFwdPrimitive : public MklPrimitive {
 Execute(src_data, filter_data, nullptr, dst_data, fwd_stream);
 }

-#ifndef ENABLE_MKLDNN_V1
-// In MKL-DNN v1.x, memory format tags only provide a partial description
-// of the memory layout. Hence, these functions are disabled for v1.x.
-memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
-memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; }
-#endif // !ENABLE_MKLDNN_V1

 std::shared_ptr<ConvFwdPd> GetPrimitiveDesc() const {
 return context_.fwd_pd;
 }
@@ -182,12 +171,6 @@ class MklConvFwdPrimitive : public MklPrimitive {
 private:
 // Primitive reuse context for Conv2D Fwd op
 struct ConvFwdContext {
-#ifndef ENABLE_MKLDNN_V1
-// Expected memory format for this primitive instance
-memory::format src_fmt;
-memory::format filter_fmt;
-#endif // !ENABLE_MKLDNN_V1

 // MKL-DNN memory
 std::shared_ptr<mkldnn::memory> src_mem;
 std::shared_ptr<mkldnn::memory> filter_mem;
@@ -208,18 +191,10 @@ class MklConvFwdPrimitive : public MklPrimitive {
 std::shared_ptr<mkldnn::primitive> conv_fwd;

 std::vector<mkldnn::primitive> fwd_primitives;

-#ifdef ENABLE_MKLDNN_V1
 std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
-#endif // ENABLE_MKLDNN_V1

 ConvFwdContext()
-:
-#ifndef ENABLE_MKLDNN_V1
-src_fmt(memory::format::any),
-filter_fmt(memory::format::any),
-#endif // !ENABLE_MKLDNN_V1
-src_mem(nullptr),
+: src_mem(nullptr),
 filter_mem(nullptr),
 bias_mem(nullptr),
 dst_mem(nullptr),
@@ -228,52 +203,45 @@ class MklConvFwdPrimitive : public MklPrimitive {
 filter_md(nullptr),
 bias_md(nullptr),
 fwd_pd(nullptr),
-conv_fwd(nullptr) {
-}
+conv_fwd(nullptr) {}
 };

 void Setup(const MklConvFwdParams& convFwdDims) {
-MEMORY_FORMAT user_data_fmt;
+memory::format_tag user_data_fmt;
 if (convFwdDims.native_format) {
 user_data_fmt = MklTensorFormatToMklDnnDataFormat(convFwdDims.tf_fmt);
 } else {
 // Create memory descriptors for convolution data w/ no specified format
-user_data_fmt = MEMORY_FORMAT::any;
+user_data_fmt = memory::format_tag::any;
 }
 context_.src_md.reset(new memory::desc(
 {convFwdDims.src_dims}, MklDnnType<Tinput>(), user_data_fmt));

-context_.filter_md.reset(new memory::desc(
-{convFwdDims.filter_dims}, MklDnnType<Tfilter>(), MEMORY_FORMAT::any));
+context_.filter_md.reset(new memory::desc({convFwdDims.filter_dims},
+MklDnnType<Tfilter>(),
+memory::format_tag::any));

 context_.dst_md.reset(new memory::desc(
 {convFwdDims.dst_dims}, MklDnnType<Toutput>(), user_data_fmt));

 if (!convFwdDims.bias_dims.empty())
-context_.bias_md.reset(new memory::desc(
-{convFwdDims.bias_dims}, MklDnnType<Tbias>(), MEMORY_FORMAT::any));
+context_.bias_md.reset(new memory::desc({convFwdDims.bias_dims},
+MklDnnType<Tbias>(),
+memory::format_tag::any));

 // Create a convolution descriptor
 if (!convFwdDims.bias_dims.empty()) {
 context_.fwd_desc.reset(new convolution_forward::desc(
-prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
-*context_.filter_md, *context_.bias_md, *context_.dst_md,
-convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
-#ifndef ENABLE_MKLDNN_V1
-convFwdDims.padding_right, padding_kind::zero));
-#else
-convFwdDims.padding_right));
-#endif // !ENABLE_MKLDNN_V1
+prop_kind::forward, mkldnn::algorithm::convolution_direct,
+*context_.src_md, *context_.filter_md, *context_.bias_md,
+*context_.dst_md, convFwdDims.strides, convFwdDims.dilations,
+convFwdDims.padding_left, convFwdDims.padding_right));
 } else {
 context_.fwd_desc.reset(new convolution_forward::desc(
-prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
-*context_.filter_md, *context_.dst_md, convFwdDims.strides,
-convFwdDims.dilations, convFwdDims.padding_left,
+prop_kind::forward, mkldnn::algorithm::convolution_direct,
+*context_.src_md, *context_.filter_md, *context_.dst_md,
+convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
-#ifndef ENABLE_MKLDNN_V1
-convFwdDims.padding_right, padding_kind::zero));
-#else
 convFwdDims.padding_right));
-#endif // !ENABLE_MKLDNN_V1
 }

 context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
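With the v0.x macros gone, memory descriptors are built directly from dims, an MklDnnType<T>() data type, and a memory::format_tag (format_tag::any lets the library pick a blocked layout), and the convolution descriptor takes the descriptors and geometry directly. A small self-contained sketch of the same construction pattern, assuming oneDNN 1.x; the shapes are purely illustrative:

    #include "mkldnn.hpp"

    using namespace mkldnn;

    convolution_forward::primitive_desc MakeConv1x1(const engine& eng) {
      // NCHW activations, OIHW weights; `any` lets oneDNN choose blocked layouts.
      memory::desc src({1, 8, 14, 14}, memory::data_type::f32, memory::format_tag::any);
      memory::desc wei({16, 8, 1, 1}, memory::data_type::f32, memory::format_tag::any);
      memory::desc dst({1, 16, 14, 14}, memory::data_type::f32, memory::format_tag::any);
      convolution_forward::desc d(prop_kind::forward_inference,
                                  algorithm::convolution_direct, src, wei, dst,
                                  /*strides=*/{1, 1}, /*padding_l=*/{0, 0},
                                  /*padding_r=*/{0, 0});
      return convolution_forward::primitive_desc(d, eng);
    }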
@@ -314,54 +282,32 @@ class MklConvFwdPrimitive : public MklPrimitive {
 context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
 }

-#ifndef ENABLE_MKLDNN_V1
-// Store the expected memory format
-context_.src_fmt = static_cast<mkldnn::memory::format>(
-context_.fwd_pd.get()->src_primitive_desc().desc().data.format);
-
-context_.filter_fmt = static_cast<mkldnn::memory::format>(
-context_.fwd_pd.get()->weights_primitive_desc().desc().data.format);
-#endif // !ENABLE_MKLDNN_V1

 // Create memory primitive based on dummy data
-context_.src_mem.reset(new MEMORY_CONSTRUCTOR(
-context_.fwd_pd.get()->PRIMITIVE_DESC_SRC, cpu_engine_, DummyData));
+context_.src_mem.reset(
+new memory(context_.fwd_pd.get()->src_desc(), cpu_engine_, DummyData));
-context_.filter_mem.reset(new MEMORY_CONSTRUCTOR(
-context_.fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_, DummyData));
+context_.filter_mem.reset(new memory(context_.fwd_pd.get()->weights_desc(),
+cpu_engine_, DummyData));
-context_.dst_mem.reset(new MEMORY_CONSTRUCTOR(
-context_.fwd_pd.get()->PRIMITIVE_DESC_DST, cpu_engine_, DummyData));
+context_.dst_mem.reset(
+new memory(context_.fwd_pd.get()->dst_desc(), cpu_engine_, DummyData));

 // Create convolution primitive and add it to net
 if (!convFwdDims.bias_dims.empty()) {
-context_.bias_mem.reset(new MEMORY_CONSTRUCTOR_USING_MEM_PD(
-convFwdDims.bias_dims, Tbias, MEMORY_FORMAT::x, cpu_engine_,
-DummyData));
+context_.bias_mem.reset(new memory(
+{{convFwdDims.bias_dims}, MklDnnType<Tbias>(), memory::format_tag::x},
+cpu_engine_, DummyData));
-#ifdef ENABLE_MKLDNN_V1
 context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
 context_.fwd_primitives_args.push_back(
 {{MKLDNN_ARG_SRC, *context_.src_mem},
 {MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
 {MKLDNN_ARG_BIAS, *context_.bias_mem},
-{ MKLDNN_ARG_DST,
-*context_.dst_mem }});
+{MKLDNN_ARG_DST, *context_.dst_mem}});
 } else {
 context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
 context_.fwd_primitives_args.push_back(
 {{MKLDNN_ARG_SRC, *context_.src_mem},
 {MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
-{ MKLDNN_ARG_DST,
-*context_.dst_mem }});
+{MKLDNN_ARG_DST, *context_.dst_mem}});
 }
-#else
-context_.conv_fwd.reset(new convolution_forward(
-*context_.fwd_pd, *context_.src_mem, *context_.filter_mem,
-*context_.bias_mem, *context_.dst_mem));
-} else {
-context_.conv_fwd.reset(
-new convolution_forward(*context_.fwd_pd, *context_.src_mem,
-*context_.filter_mem, *context_.dst_mem));
-}
-#endif // ENABLE_MKLDNN_V1
 context_.fwd_primitives.push_back(*context_.conv_fwd);
 }

@@ -650,12 +596,10 @@ class MklConvOp : public OpKernel {
 auto tf_fmt = is_conv2d ? TFDataFormatToMklDnnDataFormat(data_format_)
 : TFDataFormatToMklDnn3DDataFormat(data_format_);

-#ifdef ENABLE_MKLDNN_V1
 auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt);
 // NOTE: `mkl_fmt_tag` will be `format_tag::undef` for ReLU
 OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef,
 errors::InvalidArgument("Invalid data format"));
-#endif // ENABLE_MKLDNN_V1

 // If input is in MKL layout, then simply grab the layout; otherwise,
 // construct TF layout for input.
@@ -667,19 +611,15 @@ class MklConvOp : public OpKernel {
 auto src_md =
 src_mkl_shape.IsMklTensor()
 ? src_mkl_shape.GetMklLayout()
-#ifdef ENABLE_MKLDNN_V1
 : memory::desc(src_dims, MklDnnType<Tinput>(), mkl_fmt_tag);
-#else
-: memory::desc(src_dims, MklDnnType<Tinput>(), tf_fmt);
-#endif // ENABLE_MKLDNN_V1
 src.SetUsrMem(src_md, &src_tensor);

 // Although filter shape (filter_dims) required is in MKL-DNN order,
 // the layout is Tensorflow's layout (HWIO) and (HWIGO) for
 // depthwise/group convolutions.
-auto filter_format = is_conv2d ? (is_depthwise ? MEMORY_FORMAT::hwigo
-: MEMORY_FORMAT::hwio)
-: MEMORY_FORMAT::dhwio;
+auto filter_format = is_conv2d ? (is_depthwise ? memory::format_tag::hwigo
+: memory::format_tag::hwio)
+: memory::format_tag::dhwio;

 DCHECK(!filter_mkl_shape.IsMklTensor());
 auto filter_md =
@@ -738,12 +678,9 @@ class MklConvOp : public OpKernel {

 // Check whether src and filter need to be reordered.
 Tinput* src_data = nullptr;
-if (IS_SRC_REORDER_NEEDED(src_md, conv_fwd_pd, conv_fwd)) {
+if (src_md != conv_fwd_pd->src_desc()) {
 src.SetUsrMem(src_md, &src_tensor);
-src.CheckReorderToOpMem(
-MEMORY_PD_WITHOUT_DATA(GET_SRC_DESC_FROM_OP_PD(conv_fwd_pd),
-cpu_engine_),
-context);
+src.CheckReorderToOpMem(conv_fwd_pd->src_desc(), cpu_engine_, context);
 src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
 } else {
 src_data = static_cast<Tinput*>(
@@ -751,7 +688,7 @@ class MklConvOp : public OpKernel {
 }

 Tfilter* filter_data = nullptr;
-if (IS_FILTER_REORDER_NEEDED(filter_md, conv_fwd_pd, conv_fwd)) {
+if (filter_md != conv_fwd_pd->weights_desc()) {
 bool is_filter_cached = false;
 // If filter is a constant, we can avoid the conversion of filter from
 // Tensorflow format to MKL format by caching the filter when it is
@@ -761,28 +698,20 @@ class MklConvOp : public OpKernel {
 if (IsFilterCacheEmpty(context)) {
 // Cache filter if it is not already cached.
 CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor,
-#ifdef ENABLE_MKLDNN_V1
 filter, filter_md, filter_mkl_shape);
-#else
-filter, filter_md);
-#endif // ENABLE_MKLDNN_V1
 }
-filter_data = GetCachedFilter(
-context, GET_WEIGHTS_FORMAT_FROM_OP_PD(conv_fwd_pd, conv_fwd));
+filter_data = GetCachedFilter(context, conv_fwd_pd->weights_desc());
 is_filter_cached = (filter_data != nullptr);
 }
 if (!is_filter_cached) {
 filter.SetUsrMem(filter_md, &filter_tensor);
 if (filter_out_tensor == nullptr) {
-filter.CheckReorderToOpMem(
-MEMORY_PD_WITHOUT_DATA(GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd),
-cpu_engine_),
-context);
+filter.CheckReorderToOpMem(conv_fwd_pd->weights_desc(), cpu_engine_,
+context);
 } else {
 filter.CheckReorderToOpMem(
-GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd),
-DATA_WITH_ENGINE(filter.GetTensorBuffer(filter_out_tensor),
-cpu_engine_),
+conv_fwd_pd->weights_desc(),
+filter.GetTensorBuffer(filter_out_tensor), cpu_engine_,
 context);
 }
 filter_data =
@@ -897,7 +826,8 @@ class MklConvOp : public OpKernel {
 // NOTE: Fusion of BiasAdd is handled directly inside MklConvOp by
 // checking `fuse_biasadd_` flag.
 if (fuse_add_) {
-params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}, ""});
+params.post_op_params.push_back(
+{"sum", mkldnn::algorithm::undef, {1.0}, ""});
 }
 if (fuse_activation_) {
 params.post_op_params.push_back(
@@ -918,35 +848,27 @@ class MklConvOp : public OpKernel {
 virtual void AllocateOutputTensor(OpKernelContext* context,
 const ConvFwdPd& conv_prim_desc,
 const memory::dims& output_dims_mkl_order,
-MKL_TENSOR_FORMAT output_tf_format,
+MklTensorFormat output_tf_format,
 MklDnnShape* output_mkl_shape,
 Tensor** output_tensor) {
 DCHECK(output_tensor);
-#ifdef ENABLE_MKLDNN_V1
 auto dst_md = conv_prim_desc.dst_desc();
-#else
-auto dst_pd = conv_prim_desc.dst_primitive_desc();
-auto dst_md = dst_pd.desc();
-#endif // ENABLE_MKLDNN_V1

 if (!std::is_same<Ttemp_output, Toutput>::value) {
 dst_md.data.data_type =
 static_cast<mkldnn_data_type_t>(MklDnnType<Toutput>());
-#ifndef ENABLE_MKLDNN_V1
-dst_pd = memory::primitive_desc(dst_md, cpu_engine_);
-#endif // !ENABLE_MKLDNN_V1
 }

 // Allocate shape of MKL tensor
 output_mkl_shape->SetMklTensor(true);
-output_mkl_shape->SetMklLayout(&DST_MD);
+output_mkl_shape->SetMklLayout(&dst_md);
 output_mkl_shape->SetElemType(MklDnnType<Toutput>());
 output_mkl_shape->SetTfLayout(output_dims_mkl_order.size(),
 output_dims_mkl_order, output_tf_format);

 // Allocate shape of TF tensor
 TensorShape output_tf_shape;
-output_tf_shape.AddDim((DST_MD.get_size() / sizeof(Toutput)));
+output_tf_shape.AddDim((dst_md.get_size() / sizeof(Toutput)));
 if (native_format) {
 output_tf_shape = output_mkl_shape->GetTfShape();
 }
@@ -972,23 +894,16 @@ class MklConvOp : public OpKernel {
 AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
 output_tf_shape, *output_mkl_shape,
 native_format);
-#ifdef ENABLE_MKLDNN_V1
 auto output_format_tag = MklTensorFormatToMklDnnDataFormat(
 output_mkl_shape->GetTfDataFormat());
 OP_REQUIRES(context, output_format_tag != memory::format_tag::undef,
 errors::InvalidArgument(
 "MklConvOp: AddN fusion: Invalid data format"));
-#endif // ENABLE_MKLDNN_V1
 auto add_md =
 add_mkl_shape.IsMklTensor()
 ? add_mkl_shape.GetMklLayout()
 : memory::desc(output_dims_mkl_order, MklDnnType<Toutput>(),
-#ifdef ENABLE_MKLDNN_V1
 output_format_tag);
-#else
-output_mkl_shape->GetTfDataFormat());
-auto add_pd = memory::primitive_desc(add_md, this->cpu_engine_);
-#endif // ENABLE_MKLDNN_V1
 void* add_buf = static_cast<void*>(
 const_cast<Toutput*>(add_tensor.flat<Toutput>().data()));
 void* dst_buf =
@@ -996,16 +911,14 @@ class MklConvOp : public OpKernel {
 if (native_format) {
 // We are simply deep copying the add_tensor to output_tensor without
 // changing memory layout, hence using same memory descriptor.
-ADD_MD = DST_MD =
+add_md = dst_md =
 memory::desc({add_tensor.NumElements()}, MklDnnType<Toutput>(),
 mkldnn::memory::format_tag::x);
 }
-fuse_add_src_.reset(
-new MEMORY_CONSTRUCTOR(ADD_MD, this->cpu_engine_, add_buf));
-fuse_add_dst_.reset(
-new MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf));
+fuse_add_src_.reset(new memory(add_md, this->cpu_engine_, add_buf));
+fuse_add_dst_.reset(new memory(dst_md, this->cpu_engine_, dst_buf));
 auto reorder_desc =
-REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
+ReorderPd(this->cpu_engine_, add_md, this->cpu_engine_, dst_md);

 CreateAndExecuteReorder(reorder_desc, *fuse_add_src_, *fuse_add_dst_,
 this->cpu_engine_, context);
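The AddN-fusion path above now copies the fused-add input into the output buffer with a plain oneDNN reorder between two memory descriptors. A minimal sketch of that reorder pattern on its own, assuming oneDNN 1.x; the nchw/nhwc pair and buffer contents are only illustrative:

    #include "mkldnn.hpp"
    #include <vector>

    using namespace mkldnn;

    void CopyNchwToNhwc() {
      engine eng(engine::kind::cpu, 0);
      stream strm(eng);

      memory::dims dims = {1, 2, 2, 2};  // N, C, H, W
      memory::desc src_md(dims, memory::data_type::f32, memory::format_tag::nchw);
      memory::desc dst_md(dims, memory::data_type::f32, memory::format_tag::nhwc);

      std::vector<float> src_buf(8, 1.0f), dst_buf(8, 0.0f);
      memory src(src_md, eng, src_buf.data());
      memory dst(dst_md, eng, dst_buf.data());

      // Same roles as ReorderPd / CreateAndExecuteReorder in the kernel above.
      reorder::primitive_desc rpd(eng, src_md, eng, dst_md);
      reorder(rpd).execute(strm, src, dst);
      strm.wait();
    }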
@@ -1017,7 +930,7 @@ class MklConvOp : public OpKernel {
 }
 }

-engine cpu_engine_ = engine(ENGINE_CPU, 0);
+engine cpu_engine_ = engine(engine::kind::cpu, 0);

 private:
 std::shared_ptr<mkldnn::memory> fuse_add_src_;
@@ -1041,7 +954,7 @@ class MklConvOp : public OpKernel {
 // This variable is used for alpha in leakyrelu or upper bound in relu6
 // depending on the context
 float alpha_or_upbound_ = 0.0;
-mkldnn::algorithm activation_alg_ = ALGORITHM_UNDEF;
+mkldnn::algorithm activation_alg_ = mkldnn::algorithm::undef;

 int input_index_pad_ = 2;

@@ -1050,15 +963,10 @@ class MklConvOp : public OpKernel {
 const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1;
 const int kDilationH = 0, kDilationW = 1;

-MKL_TENSOR_FORMAT_IN_C GetFilterTfDataFormat(
-const MklDnnShape* filter_mkl_shape,
-const ConvFwdPd& conv_prim_desc) const {
-#ifdef ENABLE_MKLDNN_V1
+MklTensorFormat GetFilterTfDataFormat(const MklDnnShape* filter_mkl_shape,
+const ConvFwdPd& conv_prim_desc) const {
 DCHECK(filter_mkl_shape);
 return filter_mkl_shape->GetTfDataFormat();
-#else
-return conv_prim_desc.weights_primitive_desc().desc().data.format;
-#endif // ENABLE_MKLDNN_V1
 }

 // Allocate persistent tensors for cached filter data and
@@ -1070,23 +978,13 @@ class MklConvOp : public OpKernel {
 DCHECK(filter_tensor);
 TensorShape filter_tf_shape;
 filter_tf_shape.AddDim(
-(conv_prim_desc.PRIMITIVE_DESC_WEIGHTS.get_size() / sizeof(Tfilter)));
+(conv_prim_desc.weights_desc().get_size() / sizeof(Tfilter)));
 OP_REQUIRES_OK(context, context->allocate_persistent(
 DataTypeToEnum<Tfilter>::value, filter_tf_shape,
 &cached_filter_data_ptensor_, filter_tensor));

 Tensor* second_tensor = nullptr;
-#ifndef ENABLE_MKLDNN_V1
-TensorShape filter_mkl_format;
-filter_mkl_format.AddDim(
-sizeof(GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc)) /
-sizeof(DT_INT32));
-OP_REQUIRES_OK(context, context->allocate_persistent(
-DT_INT32, filter_mkl_format,
-&cached_filter_md_ptensor_, &second_tensor));
-second_tensor->scalar<int32>()() = static_cast<int32>(
-GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc));
-#else
 // There is no tensor format in DNNL 1.x. So we cache the complete filter
 // descriptor as flat byte array.
 TensorShape cached_filter_md_shape;
@@ -1100,7 +998,6 @@ class MklConvOp : public OpKernel {
 &cached_filter_md_ptensor_, &second_tensor));
 *reinterpret_cast<memory::desc*>(second_tensor->flat<uint8>().data()) =
 weights_desc;
-#endif // !ENABLE_MKLDNN_V1
 }

 void AllocatePersistentTensor(OpKernelContext* context,
@@ -1114,7 +1011,7 @@ class MklConvOp : public OpKernel {
 const memory::dims& filter_dims_tf_order,
 Tensor** filter_tensor) {
 DCHECK(filter_tensor);
-auto filter_md = conv_prim_desc.PRIMITIVE_DESC_WEIGHTS;
+auto filter_md = conv_prim_desc.weights_desc();

 // Allocate shape of MKL tensor
 MklDnnShape filter_mkl_shape;
@@ -1127,7 +1024,7 @@ class MklConvOp : public OpKernel {
 // is stored in the MKL data.
 filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(),
 filter_dims_tf_order,
-MKL_TENSOR_FORMAT_BLOCKED);
+MklTensorFormat::FORMAT_BLOCKED);

 // Allocate the data space for the filter to propagate as TF tensor.
 TensorShape filter_tf_shape;
@@ -1150,17 +1047,15 @@ class MklConvOp : public OpKernel {
 // Create reorders between user layout and MKL layout if it is needed and
 // add it to the net before convolution. No need to check for output
 // reorder as we propagate output layout to the next layer.
-src->CheckReorderToOpMem(
-MEMORY_PD_WITHOUT_DATA(conv_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_));
+src->CheckReorderToOpMem(conv_prim_desc.src_desc(), cpu_engine_);

 // Rather than re-ordering to a temp buffer, reorder directly to the
 // filter output tensor
-filter->CheckReorderToOpMem(conv_prim_desc.PRIMITIVE_DESC_WEIGHTS,
+filter->CheckReorderToOpMem(conv_prim_desc.weights_desc(),
 filter->GetTensorBuffer(filter_out_tensor));

 // Create convolution primitive and add it to net.
 std::vector<primitive> net;
-#ifdef ENABLE_MKLDNN_V1
 std::vector<std::unordered_map<int, memory>> net_args;
 if (bias) {
 DCHECK(fuse_biasadd_);
@@ -1168,31 +1063,15 @@ class MklConvOp : public OpKernel {
 net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()},
 {MKLDNN_ARG_WEIGHTS, filter->GetOpMem()},
 {MKLDNN_ARG_BIAS, bias->GetOpMem()},
-{ MKLDNN_ARG_DST,
-output->GetOpMem() }});
+{MKLDNN_ARG_DST, output->GetOpMem()}});
 } else {
 DCHECK(!fuse_biasadd_);
 net.push_back(convolution_forward(conv_prim_desc));
 net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()},
 {MKLDNN_ARG_WEIGHTS, filter->GetOpMem()},
-{ MKLDNN_ARG_DST,
-output->GetOpMem() }});
+{MKLDNN_ARG_DST, output->GetOpMem()}});
 }
 ExecutePrimitive(net, &net_args, cpu_engine_);
-#else
-if (bias) {
-DCHECK(fuse_biasadd_);
-net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
-filter->GetOpMem(), bias->GetOpMem(),
-output->GetOpMem()));
-} else {
-DCHECK(!fuse_biasadd_);
-net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
-filter->GetOpMem(),
-output->GetOpMem()));
-}
-ExecutePrimitive(net, nullptr, cpu_engine_);
-#endif // ENABLE_MKLDNN_V1
 }

 // TF_LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot
@@ -1206,9 +1085,8 @@ class MklConvOp : public OpKernel {
 return (cached_filter_data_tensor.NumElements() == 0);
 }

 // Cache the converted filter in a persistent tensor.
 // Only one thread can execute this method at any given time.
-#ifdef ENABLE_MKLDNN_V1
 void CacheFilter(OpKernelContext* context,
 const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
 Tfilter* filter_data, const Tensor& filter_tensor,
@@ -1254,37 +1132,8 @@ class MklConvOp : public OpKernel {
 return true;
 }

-#else
-void CacheFilter(OpKernelContext* context,
-const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
-Tfilter* filter_data, const Tensor& filter_tensor,
-MklDnnData<Tfilter>& filter, const memory::desc& filter_md)
-TF_LOCKS_EXCLUDED(mu_) {
-mutex_lock lock(mu_);
-const Tensor& cached_filter_data_tensor =
-*cached_filter_data_ptensor_.AccessTensor(context);
-
-// If filter is already cached, there's nothing to do.
-if (cached_filter_data_tensor.NumElements() > 0) {
-return;
-}
-
-// Otherwise, cache filter
-filter.SetUsrMem(filter_md, &filter_tensor);
-filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_primitive_desc());
-filter_data = static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
-
-Tensor* filter_tensor_ptr = nullptr;
-AllocatePersistentTensor(context, *conv_fwd_pd, &filter_tensor_ptr);
-void* cached_filter_data = filter.GetTensorBuffer(filter_tensor_ptr);
-size_t cached_filter_data_size =
-filter.GetOpMem().get_primitive_desc().get_size();
-memcpy(cached_filter_data, filter_data, cached_filter_data_size);
-}
-#endif // ENABLE_MKLDNN_V1

 Tfilter* GetCachedFilter(OpKernelContext* context,
-const MEMORY_DESC& filter_md)
+const memory::desc& filter_md)
 TF_LOCKS_EXCLUDED(mu_) {
 tf_shared_lock lock(mu_);
 const Tensor& cached_filter_data =
@@ -1292,15 +1141,10 @@ class MklConvOp : public OpKernel {
 const Tensor& cached_filter_md =
 *cached_filter_md_ptensor_.AccessTensor(context);

-// Check if the memory descriptor of the cached weights is same as
+// Check if the memory descriptor of the cached weights is the same as
 // filter_md. If so, we can use the cached weights; otherwise
 // return nullptr.
-#ifdef ENABLE_MKLDNN_V1
 if (filter_md == *static_cast<memory::desc*>(cached_filter_md.data())) {
-#else
-if (cached_filter_md.scalar<int32>().size() &&
-cached_filter_md.scalar<int32>()() == filter_md) {
-#endif // ENABLE_MKLDNN_V1
 return static_cast<Tfilter*>(
 const_cast<Tfilter*>(cached_filter_data.flat<Tfilter>().data()));
 }
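Because DNNL 1.x has no small integer "format" enum, the kernel now caches the full memory::desc of the reordered weights as a flat byte tensor and later compares it against the descriptor the current primitive expects. A minimal sketch of that round-trip, assuming oneDNN 1.x and using a plain byte vector as a stand-in for the persistent tensor:

    #include "mkldnn.hpp"
    #include <cstring>
    #include <vector>

    using namespace mkldnn;

    // Serialize a filter descriptor into a flat byte buffer (the persistent
    // tensor in the kernel above plays this role).
    std::vector<unsigned char> SerializeFilterDesc(const memory::desc& md) {
      std::vector<unsigned char> bytes(sizeof(memory::desc));
      std::memcpy(bytes.data(), &md, sizeof(memory::desc));
      return bytes;
    }

    // Returns true if the cached descriptor matches `expected`, i.e. the cached
    // weights can be reused without another reorder.
    bool CachedFilterMatches(const std::vector<unsigned char>& cache,
                             const memory::desc& expected) {
      if (cache.size() < sizeof(memory::desc)) return false;
      memory::desc cached;
      std::memcpy(&cached, cache.data(), sizeof(memory::desc));
      return cached == expected;
    }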
@@ -1336,31 +1180,34 @@ class MklFusedConvOp
 errors::InvalidArgument(
 "Fused Conv2D must have one extra argument: bias."));
 } else if (fused_ops == std::vector<string>{"Relu"}) {
-this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
+this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
 } else if (fused_ops == std::vector<string>{"Relu6"}) {
-this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
+this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
+6.0);
 } else if (fused_ops == std::vector<string>{"Elu"}) {
-this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
+this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
 } else if (fused_ops == std::vector<string>{"LeakyRelu"}) {
 float leakyrelu_alpha;
 OP_REQUIRES_OK(context,
 context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
-this->set_fuse_activation(true, ALGORITHM::eltwise_relu, leakyrelu_alpha);
+this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
+leakyrelu_alpha);
 } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
 this->set_fuse_biasadd(true);
-this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
+this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
 OP_REQUIRES(context, num_args == 1,
 errors::InvalidArgument(
 "Fused Conv2D must have one extra argument: bias."));
 } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
 this->set_fuse_biasadd(true);
-this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
+this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
+6.0);
 OP_REQUIRES(context, num_args == 1,
 errors::InvalidArgument(
 "Fused Conv2D must have one extra argument: bias."));
 } else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
 this->set_fuse_biasadd(true);
-this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
+this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
 OP_REQUIRES(context, num_args == 1,
 errors::InvalidArgument(
 "Fused Conv2D must have one extra argument: bias."));
@@ -1369,7 +1216,8 @@ class MklFusedConvOp
 float leakyrelu_alpha;
 OP_REQUIRES_OK(context,
 context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
-this->set_fuse_activation(true, ALGORITHM::eltwise_relu, leakyrelu_alpha);
+this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
+leakyrelu_alpha);
 OP_REQUIRES(context, num_args == 1,
 errors::InvalidArgument(
 "Fused Conv2D must have one extra argument: bias."));
@@ -1383,7 +1231,7 @@ class MklFusedConvOp
 } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"}) {
 this->set_fuse_biasadd(true);
 this->set_fuse_add(true);
-this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
+this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
 OP_REQUIRES(
 context, num_args == 2,
 errors::InvalidArgument(
@@ -1391,7 +1239,8 @@ class MklFusedConvOp
 } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"}) {
 this->set_fuse_biasadd(true);
 this->set_fuse_add(true);
-this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
+this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
+6.0);
 OP_REQUIRES(
 context, num_args == 2,
 errors::InvalidArgument(
@@ -1399,7 +1248,7 @@ class MklFusedConvOp
 } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"}) {
 this->set_fuse_biasadd(true);
 this->set_fuse_add(true);
-this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
+this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
 OP_REQUIRES(
 context, num_args == 2,
 errors::InvalidArgument(
@@ -1411,7 +1260,8 @@ class MklFusedConvOp
 float leakyrelu_alpha;
 OP_REQUIRES_OK(context,
 context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
-this->set_fuse_activation(true, ALGORITHM::eltwise_relu, leakyrelu_alpha);
+this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
+leakyrelu_alpha);
 OP_REQUIRES(
 context, num_args == 2,
 errors::InvalidArgument(
@@ -1459,13 +1309,14 @@ class MklFusedDepthwiseConvOp
 this->set_fuse_biasadd(true);
 } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
 this->set_fuse_biasadd(true);
-this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
+this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
 } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
 this->set_fuse_biasadd(true);
-this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
+this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
+6.0);
 } else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
 this->set_fuse_biasadd(true);
-this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
+this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
 } else {
 OP_REQUIRES(context, false,
 errors::Unimplemented("Fusion is not implemented: [",
@ -1642,8 +1493,8 @@ class MklQuantizedConv2DOp
|
|||||||
param_key.AddAsKey<float>(max_freezed_output);
|
param_key.AddAsKey<float>(max_freezed_output);
|
||||||
param_key.AddAsKey<const float*>(min_filter);
|
param_key.AddAsKey<const float*>(min_filter);
|
||||||
param_key.AddAsKey<const float*>(max_filter);
|
param_key.AddAsKey<const float*>(max_filter);
|
||||||
params.post_op_params.push_back(
|
params.post_op_params.push_back({"output_scale", mkldnn::algorithm::undef,
|
||||||
{"output_scale", ALGORITHM_UNDEF, scales, param_key.GetKey()});
|
scales, param_key.GetKey()});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1696,31 +1547,27 @@ class MklQuantizedConv2DOp
|
|||||||
bias_attr.set_output_scales(1, scales_);
|
bias_attr.set_output_scales(1, scales_);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto bias_md =
|
auto bias_md = memory::desc({static_cast<int>(bias_tensor.NumElements())},
|
||||||
MEMORY_PD_CONSTRUCTOR(static_cast<int>(bias_tensor.NumElements()),
|
MklDnnType<Tbias>(), memory::format_tag::x);
|
||||||
Tbias, MEMORY_FORMAT::x, this->cpu_engine_);
|
|
||||||
void* bias_buf = static_cast<void*>(
|
void* bias_buf = static_cast<void*>(
|
||||||
const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
|
const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
|
||||||
if (!input_bias_) {
|
if (!input_bias_) {
|
||||||
input_bias_ =
|
input_bias_ = new memory(bias_md, this->cpu_engine_, bias_buf);
|
||||||
new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_, bias_buf);
|
|
||||||
} else {
|
} else {
|
||||||
input_bias_->set_data_handle(bias_buf);
|
input_bias_->set_data_handle(bias_buf);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!scaled_bias_buf_)
|
if (!scaled_bias_buf_)
|
||||||
AllocTmpBuffer<Tbias>(context, &scaled_bias_tensor_,
|
AllocTmpBuffer<Tbias>(context, &scaled_bias_tensor_,
|
||||||
GET_BIAS_DESC_FROM_OP_PD(conv_fwd_pd),
|
conv_fwd_pd->bias_desc(), &scaled_bias_buf_);
|
||||||
&scaled_bias_buf_);
|
|
||||||
if (!scaled_bias_) {
|
if (!scaled_bias_) {
|
||||||
scaled_bias_ = new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_,
|
scaled_bias_ = new memory(bias_md, this->cpu_engine_, scaled_bias_buf_);
|
||||||
scaled_bias_buf_);
|
|
||||||
} else {
|
} else {
|
||||||
scaled_bias_->set_data_handle(scaled_bias_buf_);
|
scaled_bias_->set_data_handle(scaled_bias_buf_);
|
||||||
}
|
}
|
||||||
auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
|
auto reorder_desc =
|
||||||
input_bias_->GET_DESC, scaled_bias_->GET_DESC, this->cpu_engine_,
|
ReorderPd(this->cpu_engine_, input_bias_->get_desc(),
|
||||||
bias_attr);
|
this->cpu_engine_, scaled_bias_->get_desc(), bias_attr);
|
||||||
CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_,
|
CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_,
|
||||||
this->cpu_engine_, context);
|
this->cpu_engine_, context);
|
||||||
|
|
||||||
@ -1754,7 +1601,7 @@ class MklQuantizedConv2DOp
|
|||||||
DCHECK(bias_tensor);
|
DCHECK(bias_tensor);
|
||||||
TensorShape bias_tf_shape;
|
TensorShape bias_tf_shape;
|
||||||
bias_tf_shape.AddDim(
|
bias_tf_shape.AddDim(
|
||||||
(conv_prim_desc.PRIMITIVE_DESC_BIAS.get_size() / sizeof(Tbias)));
|
(conv_prim_desc.bias_desc().get_size() / sizeof(Tbias)));
|
||||||
OP_REQUIRES_OK(context, context->allocate_persistent(
|
OP_REQUIRES_OK(context, context->allocate_persistent(
|
||||||
DataTypeToEnum<Tbias>::value, bias_tf_shape,
|
DataTypeToEnum<Tbias>::value, bias_tf_shape,
|
||||||
&cached_bias_data_ptensor_, bias_tensor));
|
&cached_bias_data_ptensor_, bias_tensor));
|
||||||
@ -1787,7 +1634,7 @@ class MklQuantizedConv2DOp
|
|||||||
AllocatePersistentTensor(context, *conv_fwd_pd, &bias_tensor_ptr);
|
AllocatePersistentTensor(context, *conv_fwd_pd, &bias_tensor_ptr);
|
||||||
void* cached_bias_data = const_cast<void*>(
|
void* cached_bias_data = const_cast<void*>(
|
||||||
static_cast<const void*>(bias_tensor_ptr->flat<Tbias>().data()));
|
static_cast<const void*>(bias_tensor_ptr->flat<Tbias>().data()));
|
||||||
size_t cached_bias_data_size = scaled_bias->GET_DESC.get_size();
|
size_t cached_bias_data_size = scaled_bias->get_desc().get_size();
|
||||||
memcpy(cached_bias_data, bias_data, cached_bias_data_size);
|
memcpy(cached_bias_data, bias_data, cached_bias_data_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1822,7 +1669,7 @@ class MklQuantizedConv2DReluOp
|
|||||||
is_depthwise>::ExtendConvFwdParams(context, params);
|
is_depthwise>::ExtendConvFwdParams(context, params);
|
||||||
|
|
||||||
params.post_op_params.push_back(
|
params.post_op_params.push_back(
|
||||||
{"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}, ""});
|
{"activation", mkldnn::algorithm::eltwise_relu, {1.0, 0.0, 0.0}, ""});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1868,26 +1715,30 @@ class MklQuantizedConv2DSumReluOp
|
|||||||
// if summand_type is also DT_QUINT8 as the scale_output,
|
// if summand_type is also DT_QUINT8 as the scale_output,
|
||||||
// the scaling factor of 255.0f cancels each other and thus is avoided.
|
// the scaling factor of 255.0f cancels each other and thus is avoided.
|
||||||
// If it is not then it is DT_INT8 and is scaled appropriately.
|
// If it is not then it is DT_INT8 and is scaled appropriately.
|
||||||
if (summand_type == DT_QUINT8)
|
if (summand_type == DT_QUINT8) {
|
||||||
params.post_op_params.push_back(
|
params.post_op_params.push_back({"sum",
|
||||||
{"sum", ALGORITHM_UNDEF, {scale_summand / scale_output}, ""});
|
mkldnn::algorithm::undef,
|
||||||
else
|
{scale_summand / scale_output},
|
||||||
|
""});
|
||||||
|
} else {
|
||||||
params.post_op_params.push_back(
|
params.post_op_params.push_back(
|
||||||
{"sum",
|
{"sum",
|
||||||
ALGORITHM_UNDEF,
|
mkldnn::algorithm::undef,
|
||||||
{255.0f * scale_summand / (scale_output * 127.0f)},
|
{255.0f * scale_summand / (scale_output * 127.0f)},
|
||||||
""});
|
""});
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}, ""});
|
params.post_op_params.push_back(
|
||||||
|
{"sum", mkldnn::algorithm::undef, {1.0}, ""});
|
||||||
}
|
}
|
||||||
params.post_op_params.push_back(
|
params.post_op_params.push_back(
|
||||||
{"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}, ""});
|
{"activation", mkldnn::algorithm::eltwise_relu, {1.0, 0.0, 0.0}, ""});
|
||||||
}
|
}
|
||||||
|
|
||||||
void AllocateOutputTensor(OpKernelContext* context,
|
void AllocateOutputTensor(OpKernelContext* context,
|
||||||
const ConvFwdPd& conv_prim_desc,
|
const ConvFwdPd& conv_prim_desc,
|
||||||
const memory::dims& output_dims_mkl_order,
|
const memory::dims& output_dims_mkl_order,
|
||||||
MKL_TENSOR_FORMAT output_tf_format,
|
MklTensorFormat output_tf_format,
|
||||||
MklDnnShape* output_mkl_shape,
|
MklDnnShape* output_mkl_shape,
|
||||||
Tensor** output_tensor) override {
|
Tensor** output_tensor) override {
|
||||||
int summand_idx = context->num_inputs() / 2 - 1;
|
int summand_idx = context->num_inputs() / 2 - 1;
|
||||||
@ -1966,21 +1817,17 @@ class MklQuantizedConv2DSumReluOp
|
|||||||
summand_mkl_shape.IsMklTensor()
|
summand_mkl_shape.IsMklTensor()
|
||||||
? summand_mkl_shape.GetMklLayout()
|
? summand_mkl_shape.GetMklLayout()
|
||||||
: memory::desc(output_dims_mkl_order, MklDnnType<Tbias>(),
|
: memory::desc(output_dims_mkl_order, MklDnnType<Tbias>(),
|
||||||
MEMORY_FORMAT::nhwc);
|
memory::format_tag::nhwc);
|
||||||
#ifndef ENABLE_MKLDNN_V1
|
|
||||||
auto summand_pd = memory::primitive_desc(summand_md, this->cpu_engine_);
|
|
||||||
#endif // !ENABLE_MKLDNN_V1
|
|
||||||
void* summand_buf =
|
void* summand_buf =
|
||||||
static_cast<void*>(const_cast<Tbias*>(summand.flat<Tbias>().data()));
|
static_cast<void*>(const_cast<Tbias*>(summand.flat<Tbias>().data()));
|
||||||
void* dst_buf =
|
void* dst_buf =
|
||||||
static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
|
static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
|
||||||
summand_.reset(
|
summand_.reset(new memory(summand_md, this->cpu_engine_, summand_buf));
|
||||||
new MEMORY_CONSTRUCTOR(SUMMAND_MD, this->cpu_engine_, summand_buf));
|
dst_.reset(
|
||||||
dst_.reset(new MEMORY_CONSTRUCTOR(conv_prim_desc.PRIMITIVE_DESC_DST,
|
new memory(conv_prim_desc.dst_desc(), this->cpu_engine_, dst_buf));
|
||||||
this->cpu_engine_, dst_buf));
|
auto reorder_desc =
|
||||||
auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
|
ReorderPd(this->cpu_engine_, summand_md, this->cpu_engine_,
|
||||||
SUMMAND_MD, conv_prim_desc.PRIMITIVE_DESC_DST, this->cpu_engine_,
|
conv_prim_desc.dst_desc(), reorder_attr);
|
||||||
reorder_attr);
|
|
||||||
CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_,
|
CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_,
|
||||||
context);
|
context);
|
||||||
}
|
}
|
||||||
|
@@ -42,20 +42,13 @@ limitations under the License.
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"

-#ifndef ENABLE_MKLDNN_V1
-using mkldnn::convolution_direct;
-#endif  // !ENABLE_MKLDNN_V1
 using mkldnn::convolution_forward;
 using mkldnn::prop_kind;
 using mkldnn::stream;

 namespace tensorflow {

-#ifdef ENABLE_MKLDNN_V1
 #define MKLDNN_SIZE_DTYPE memory::dim
-#else
-#define MKLDNN_SIZE_DTYPE int
-#endif  // ENABLE_MKLDNN_V1

 using ConvFwdDesc = mkldnn::convolution_forward::desc;
 using ConvFwdPd = mkldnn::convolution_forward::primitive_desc;
@@ -134,6 +134,7 @@ tf_kernel_library(
        "gpu_op_bitwise_and.cc",
        "gpu_op_bitwise_or.cc",
        "gpu_op_bitwise_xor.cc",
+        "gpu_op_div.cc",
        "gpu_op_equal.cc",
        "gpu_op_floor_div.cc",
        "gpu_op_greater.cc",
@@ -146,6 +147,7 @@ tf_kernel_library(
        "gpu_op_mul.cc",
        "gpu_op_not_equal.cc",
        "gpu_op_right_shift.cc",
+        "gpu_op_sub.cc",
    ],
    tags = [
        "manual",
@@ -155,6 +157,7 @@ tf_kernel_library(
        ":bitwise_and_kernels",
        ":bitwise_or_kernels",
        ":bitwise_xor_kernels",
+        ":div_kernels",
        ":equal_kernels",
        ":floor_div_kernels",
        ":gpu_ops_base",
@@ -170,6 +173,7 @@ tf_kernel_library(
        ":mul_kernels",
        ":not_equal_kernels",
        ":right_shift_kernels",
+        ":sub_kernels",
        "//third_party/eigen3",
    ],
 )
@@ -366,18 +370,24 @@ gen_kernel_library(
    unroll_factors = "4",
 )

-gen_kernel_library(
-    name = "add_v2",
-    tile_size = "256,1,1",
-    types = [
-        "f16",
-        "f32",
-        "f64",
-        "i64",
-    ],
-    # TODO(b/174543802): Enable once fusion heuristics is better.
-    # unroll_factors = "4",
-)
+[
+    gen_kernel_library(
+        name = name,
+        tile_size = "256,1,1",
+        types = [
+            "f16",
+            "f32",
+            "f64",
+            "i64",
+        ],
+        # TODO(b/174543802): Enable once fusion heuristics is better.
+        # unroll_factors = "4",
+    )
+    for name in [
+        "add_v2",
+        "sub",
+    ]
+]

 gen_kernel_library(
    name = "complex",
@@ -390,6 +400,20 @@ gen_kernel_library(
    # unroll_factors = "2",
 )

+gen_kernel_library(
+    name = "div",
+    tile_size = "256,1,1",
+    types = [
+        "f16",
+        "f32",
+        "f64",
+        "i16",
+        "i64",
+    ],
+    # TODO(b/174543802): Enable once fusion heuristics is better.
+    # unroll_factors = "4",
+)
+
 gen_kernel_library(
    name = "mul",
    tile_size = "256,1,1",
@@ -48,14 +48,17 @@ class GpuBinaryOpTest : public OpsTestBase {
  void SetOpKernel(const std::string& op_name, const TensorShape& lhs_shape,
                   const absl::InlinedVector<T, 10>& lhs_input,
                   const TensorShape& rhs_shape,
-                   const absl::InlinedVector<T, 10>& rhs_input,
-                   bool use_constraint) {
+                   const absl::InlinedVector<T, 10>& rhs_input, bool add_t,
+                   bool add_tout) {
    auto builder = NodeDefBuilder("some_name", op_name)
                       .Input(FakeInput(DataTypeToEnum<T>::v()))
                       .Input(FakeInput(DataTypeToEnum<T>::v()));
-    if (use_constraint) {
+    if (add_t) {
      builder.Attr("T", DataTypeToEnum<T>::v());
    }
+    if (add_tout) {
+      builder.Attr("Tout", DataTypeToEnum<OutT>::v());
+    }
    TF_ASSERT_OK(builder.Finalize(node_def()));

    TF_ASSERT_OK(InitOp());
@@ -73,16 +76,20 @@ class GpuBinaryOpTest : public OpsTestBase {
                          const absl::InlinedVector<T, 10>& rhs_input,
                          const TensorShape& expected_shape,
                          const absl::InlinedVector<OutT, 10>& expected_output,
-                          bool use_constraint = true) {
+                          const test::GpuOpsTestConfig& config) {
    SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input,
-                         use_constraint);
+                         config.add_t, config.add_tout);
    TF_ASSERT_OK(RunOpKernel());

    // Compare output to expectation.
    Tensor expected_tensor(allocator(), DataTypeToEnum<OutT>::value,
                           expected_shape);
    test::FillValues<OutT>(&expected_tensor, expected_output);
-    test::ExpectEqual(expected_tensor, *GetOutput(0));
+    if (config.expect_strictly_equal) {
+      test::ExpectEqual(expected_tensor, *GetOutput(0));
+    } else {
+      test::ExpectClose(expected_tensor, *GetOutput(0));
+    }
  }

  template <typename T, typename OutT>
@@ -91,9 +98,9 @@ class GpuBinaryOpTest : public OpsTestBase {
                                   const absl::InlinedVector<T, 10>& lhs_input,
                                   const TensorShape& rhs_shape,
                                   const absl::InlinedVector<T, 10>& rhs_input,
-                                   bool use_constraint = true) {
+                                   const test::GpuOpsTestConfig& config) {
    SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input,
-                         use_constraint);
+                         config.add_t, config.add_tout);
    auto status = RunOpKernel();
    EXPECT_FALSE(status.ok());
    EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@@ -105,7 +112,7 @@ class GpuBinaryOpTest : public OpsTestBase {
  void TestIncompatibleShapes(const std::string& op_name,
                              const absl::InlinedVector<T, 10>& lhs_input,
                              const absl::InlinedVector<T, 10>& rhs_input,
-                              bool use_constraint = true) {
+                              const test::GpuOpsTestConfig& config) {
    // Prepare incompatibly shaped inputs.
    TensorShape lhs_shape{3};
    TensorShape rhs_shape{2};
@@ -115,8 +122,7 @@ class GpuBinaryOpTest : public OpsTestBase {
    test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements());

    RunAndExpectInvalidArgument<T, OutT>(op_name, lhs_shape, repeated_lhs_input,
-                                         rhs_shape, repeated_rhs_input,
-                                         use_constraint);
+                                         rhs_shape, repeated_rhs_input, config);
  }

  template <typename T, typename BaselineT, typename OutT,
@@ -125,7 +131,7 @@ class GpuBinaryOpTest : public OpsTestBase {
                       const absl::InlinedVector<T, 10>& lhs_input,
                       const absl::InlinedVector<T, 10>& rhs_input,
                       BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
-                       bool use_constraint = true) {
+                       const test::GpuOpsTestConfig& config) {
    // Prepare inputs.
    int input_size = shape.num_elements();
    auto repeated_lhs_input =
@@ -147,7 +153,7 @@ class GpuBinaryOpTest : public OpsTestBase {

    RunAndExpectResult<T, OutT>(op_name, shape, repeated_lhs_input, shape,
                                repeated_rhs_input, shape, expected_output,
-                                use_constraint);
+                                config);
  }

  template <typename T, typename BaselineT, typename OutT,
@@ -156,7 +162,7 @@ class GpuBinaryOpTest : public OpsTestBase {
                     const TensorShape& other_shape,
                     const absl::InlinedVector<T, 10>& other_input,
                     BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
-                     bool use_constraint = true) {
+                     const test::GpuOpsTestConfig& config) {
    // Prepare inputs.
    TensorShape scalar_shape{};
    auto repeated_other_input =
@@ -177,7 +183,7 @@ class GpuBinaryOpTest : public OpsTestBase {
    RunAndExpectResult<T, OutT>(op_name, scalar_shape, scalar_input_vector,
                                other_shape, repeated_other_input,
                                /*expected_shape=*/other_shape, expected_output,
-                                use_constraint);
+                                config);
  }

  template <typename T, typename BaselineT, typename OutT,
@@ -187,7 +193,7 @@ class GpuBinaryOpTest : public OpsTestBase {
                              const absl::InlinedVector<T, 10>& rhs_input,
                              BaselineOutT (*baseline_callback)(BaselineT,
                                                                BaselineT),
-                              bool use_constraint = true) {
+                              const test::GpuOpsTestConfig& config) {
    // Prepare inputs.
    TensorShape lhs_shape{1};
    TensorShape rhs_shape{6};
@@ -206,7 +212,7 @@ class GpuBinaryOpTest : public OpsTestBase {

    RunAndExpectResult<T, OutT>(
        op_name, lhs_shape, repeated_lhs_input, rhs_shape, repeated_rhs_input,
-        /*expected_shape=*/rhs_shape, expected_output, use_constraint);
+        /*expected_shape=*/rhs_shape, expected_output, config);
  }

  template <typename T, typename BaselineT, typename OutT,
@@ -216,7 +222,7 @@ class GpuBinaryOpTest : public OpsTestBase {
                             const absl::InlinedVector<T, 10>& rhs_input,
                             BaselineOutT (*baseline_callback)(BaselineT,
                                                               BaselineT),
-                             bool use_constraint = true) {
+                             const test::GpuOpsTestConfig& config) {
    // Prepare inputs.
    TensorShape lhs_shape{3};
    TensorShape rhs_shape{2, 3};
@@ -235,7 +241,7 @@ class GpuBinaryOpTest : public OpsTestBase {

    RunAndExpectResult<T, OutT>(
        op_name, lhs_shape, repeated_lhs_input, rhs_shape, repeated_rhs_input,
-        /*expected_shape=*/rhs_shape, expected_output, use_constraint);
+        /*expected_shape=*/rhs_shape, expected_output, config);
  }

  template <typename T, typename BaselineT, typename OutT,
@@ -244,7 +250,7 @@ class GpuBinaryOpTest : public OpsTestBase {
                        const absl::InlinedVector<T, 10>& lhs_input,
                        const absl::InlinedVector<T, 10>& rhs_input,
                        BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
-                        bool use_constraint = true) {
+                        const test::GpuOpsTestConfig& config) {
    // Prepare inputs.
    TensorShape lhs_shape{2, 1};
    TensorShape rhs_shape{3};
@@ -264,7 +270,7 @@ class GpuBinaryOpTest : public OpsTestBase {

    RunAndExpectResult<T, OutT>(op_name, lhs_shape, repeated_lhs_input,
                                rhs_shape, repeated_rhs_input, expected_shape,
-                                expected_output, use_constraint);
+                                expected_output, config);
  }

  template <typename T, typename BaselineT, typename OutT,
@@ -272,7 +278,7 @@ class GpuBinaryOpTest : public OpsTestBase {
  void TestEmptyShapeBroadcasting(const std::string& op_name,
                                  const absl::InlinedVector<T, 10>& lhs_input,
                                  const absl::InlinedVector<T, 10>& rhs_input,
-                                  bool use_constraint = true) {
+                                  const test::GpuOpsTestConfig& config) {
    // Prepare inputs.
    TensorShape lhs_shape{2, 0, 1};
    TensorShape rhs_shape{2, 0, 5};
@@ -284,7 +290,7 @@ class GpuBinaryOpTest : public OpsTestBase {

    RunAndExpectResult<T, OutT>(op_name, lhs_shape, empty_input, rhs_shape,
                                empty_input, expected_shape, expected_output,
-                                use_constraint);
+                                config);
  }

 private:
@@ -309,60 +315,60 @@ class GpuBinaryOpTest : public OpsTestBase {
 // define your own test fixtures.

 #define GENERATE_DEFAULT_TESTS_2(op_name, test_name, T, BaselineT, OutT, \
-                                 BaselineOutT, baseline_callback, \
-                                 use_constraint) \
+                                 BaselineOutT, baseline_callback, config) \
  TEST_F(GpuBinaryOpTest, op_name##EqShapes##test_name) { \
    TestEqualShapes<T, BaselineT, OutT, BaselineOutT>( \
        #op_name, /*shape=*/test::DefaultInputShape(), \
        /*lhs_input=*/test::DefaultInput<T>(#op_name), \
        /*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
-        use_constraint); \
+        config); \
  } \
 \
  TEST_F(GpuBinaryOpTest, op_name##OneScalar##test_name) { \
    TestOneScalar<T, BaselineT, OutT, BaselineOutT>( \
-        #op_name, /*scalar_input=*/test::DefaultScalarInput<T>(), \
+        #op_name, /*scalar_input=*/test::DefaultInput<T>(#op_name).front(), \
        /*other_shape=*/test::DefaultInputShape(), \
        /*other_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
-        use_constraint); \
+        config); \
  } \
 \
  TEST_F(GpuBinaryOpTest, op_name##IncompatibleShapes##test_name) { \
    TestIncompatibleShapes<T, OutT>( \
        #op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \
-        /*rhs_input=*/test::DefaultInput<T>(#op_name), use_constraint); \
+        /*rhs_input=*/test::DefaultInput<T>(#op_name), config); \
  } \
 \
  TEST_F(GpuBinaryOpTest, op_name##BroadcastingExpand##test_name) { \
    TestBroadcastingExpand<T, BaselineT, OutT, BaselineOutT>( \
        #op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \
        /*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
-        use_constraint); \
+        config); \
  } \
 \
  TEST_F(GpuBinaryOpTest, op_name##BroadcastingInDim##test_name) { \
    TestBroadcastingInDim<T, BaselineT, OutT, BaselineOutT>( \
        #op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \
        /*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
-        use_constraint); \
+        config); \
  } \
 \
  TEST_F(GpuBinaryOpTest, op_name##Broadcasting##test_name) { \
    TestBroadcasting<T, BaselineT, OutT, BaselineOutT>( \
        #op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \
        /*rhs_input=*/test::DefaultInput<T>(#op_name), baseline_callback, \
-        use_constraint); \
+        config); \
  } \
 \
  TEST_F(GpuBinaryOpTest, op_name##EmptyShapeBroadcasting##test_name) { \
    TestEmptyShapeBroadcasting<T, BaselineT, OutT, BaselineOutT>( \
        #op_name, /*lhs_input=*/test::DefaultInput<T>(#op_name), \
-        /*rhs_input=*/test::DefaultInput<T>(#op_name), use_constraint); \
+        /*rhs_input=*/test::DefaultInput<T>(#op_name), config); \
  }

 #define GENERATE_DEFAULT_TESTS(op_name, test_name, T, OutT, baseline_callback) \
  GENERATE_DEFAULT_TESTS_2(op_name, test_name, T, T, OutT, OutT, \
-                           baseline_callback, /*use_constraint=*/false)
+                           baseline_callback, \
+                           test::GpuOpsTestConfig().ExpectStrictlyEqual())

 #define GENERATE_DEFAULT_TESTS_SAME_INPUT_AND_OUTPUT_TYPE( \
    op_name, test_name, T, baseline_callback) \
@@ -433,37 +439,23 @@ GENERATE_DEFAULT_TESTS(BitwiseXor,
 GENERATE_DEFAULT_TESTS(BitwiseXor,
                        /*test_name=*/Int64, int64, int64, baseline_bitwise_xor)

-/// Test `tf.LeftShift`.
+/// Test `tf.Div`.

 template <typename T>
-T baseline_left_shift(T lhs, T rhs) {
-  return lhs << rhs;
+T baseline_div(T lhs, T rhs) {
+  return lhs / rhs;
 }

-GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int8, int8, int8,
-                       baseline_left_shift)
-GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int16, int16, int16,
-                       baseline_left_shift)
-GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int32, int32, int32,
-                       baseline_left_shift)
-GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int64, int64, int64,
-                       baseline_left_shift)
-
-/// Test `tf.RightShift`.
-
-template <typename T>
-T baseline_right_shift(T lhs, T rhs) {
-  return lhs >> rhs;
-}
-
-GENERATE_DEFAULT_TESTS(RightShift,
-                       /*test_name=*/Int8, int8, int8, baseline_right_shift)
-GENERATE_DEFAULT_TESTS(RightShift,
-                       /*test_name=*/Int16, int16, int16, baseline_right_shift)
-GENERATE_DEFAULT_TESTS(RightShift,
-                       /*test_name=*/Int32, int32, int32, baseline_right_shift)
-GENERATE_DEFAULT_TESTS(RightShift,
-                       /*test_name=*/Int64, int64, int64, baseline_right_shift)
+GENERATE_DEFAULT_TESTS(Div,
+                       /*test_name=*/Half, Eigen::half, Eigen::half,
+                       baseline_div);
+GENERATE_DEFAULT_TESTS(Div,
+                       /*test_name=*/Float, float, float, baseline_div);
+GENERATE_DEFAULT_TESTS(Div,
+                       /*test_name=*/Double, double, double, baseline_div);
+GENERATE_DEFAULT_TESTS(Div,
+                       /*test_name=*/Int16, int16, int16, baseline_div);
+GENERATE_DEFAULT_TESTS(Div,
+                       /*test_name=*/Int64, int64, int64, baseline_div);

 /// Test `tf.Equal`.

@@ -482,27 +474,25 @@ GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int8, int8, bool, baseline_equal)
 GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int16, int16, bool, baseline_equal)
 GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int64, int64, bool, baseline_equal)

-/// Test `tf.NotEqual`.
+/// Test `tf.FloorDiv`.

 template <typename T>
-bool baseline_not_equal(T lhs, T rhs) {
-  return lhs != rhs;
+T baseline_floor_div(T lhs, T rhs) {
+  return std::floor(lhs / rhs);
 }

-GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Half, Eigen::half, bool,
-                       baseline_not_equal)
-GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Float, float, bool,
-                       baseline_not_equal)
-GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Double, double, bool,
-                       baseline_not_equal)
-GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Bool, bool, bool,
-                       baseline_not_equal)
-GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int8, int8, bool,
-                       baseline_not_equal)
-GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int16, int16, bool,
-                       baseline_not_equal)
-GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int64, int64, bool,
-                       baseline_not_equal)
+template <>
+Eigen::half baseline_floor_div(Eigen::half lhs, Eigen::half rhs) {
+  return static_cast<Eigen::half>(std::floor(static_cast<float>(lhs / rhs)));
+}
+
+GENERATE_DEFAULT_TESTS(FloorDiv,
+                       /*test_name=*/Half, Eigen::half, Eigen::half,
+                       baseline_floor_div)
+GENERATE_DEFAULT_TESTS(FloorDiv,
+                       /*test_name=*/Float, float, float, baseline_floor_div)
+GENERATE_DEFAULT_TESTS(FloorDiv,
+                       /*test_name=*/Double, double, double, baseline_floor_div)

 /// Test `tf.Greater`.
@@ -544,6 +534,22 @@ GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Int16, int16, bool,
 GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Int64, int64, bool,
                        baseline_greater_equal)

+/// Test `tf.LeftShift`.
+
+template <typename T>
+T baseline_left_shift(T lhs, T rhs) {
+  return lhs << rhs;
+}
+
+GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int8, int8, int8,
+                       baseline_left_shift)
+GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int16, int16, int16,
+                       baseline_left_shift)
+GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int32, int32, int32,
+                       baseline_left_shift)
+GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int64, int64, int64,
+                       baseline_left_shift)
+
 /// Test `tf.Less`.

 template <typename T>
@@ -586,7 +592,7 @@ bool baseline_logical_and(bool lhs, bool rhs) { return lhs && rhs; }
 GENERATE_DEFAULT_TESTS_2(LogicalAnd, /*test_name=*/Bool, /*T=*/bool,
                         /*BaselineT=*/bool, /*OutT=*/bool,
                         /*BaselineOutT=*/bool, baseline_logical_and,
-                         /*use_constraint=*/false)
+                         test::GpuOpsTestConfig().ExpectStrictlyEqual().NoT())

 /// Test `tf.LogicalOr`.

@@ -595,27 +601,77 @@ bool baseline_logical_or(bool lhs, bool rhs) { return lhs || rhs; }
 GENERATE_DEFAULT_TESTS_2(LogicalOr, /*test_name=*/Bool, /*T=*/bool,
                         /*BaselineT=*/bool, /*OutT=*/bool,
                         /*BaselineOutT=*/bool, baseline_logical_or,
-                         /*use_constraint=*/false)
+                         test::GpuOpsTestConfig().ExpectStrictlyEqual().NoT())

-/// Test `tf.FloorDiv`.
+/// Test `tf.Mul`.

 template <typename T>
-T baseline_floor_div(T lhs, T rhs) {
-  return std::floor(lhs / rhs);
+T baseline_mul(T lhs, T rhs) {
+  return lhs * rhs;
 }

-template <>
-Eigen::half baseline_floor_div(Eigen::half lhs, Eigen::half rhs) {
-  return static_cast<Eigen::half>(std::floor(static_cast<float>(lhs / rhs)));
+GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Half, Eigen::half, Eigen::half,
+                       baseline_mul)
+GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Float, float, float, baseline_mul)
+GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Double, double, double, baseline_mul)
+GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Int8, int8, int8, baseline_mul)
+GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Int16, int16, int16, baseline_mul)
+GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Int64, int64, int64, baseline_mul)
+
+/// Test `tf.NotEqual`.
+
+template <typename T>
+bool baseline_not_equal(T lhs, T rhs) {
+  return lhs != rhs;
 }

-GENERATE_DEFAULT_TESTS(FloorDiv,
+GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Half, Eigen::half, bool,
+                       baseline_not_equal)
+GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Float, float, bool,
+                       baseline_not_equal)
+GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Double, double, bool,
+                       baseline_not_equal)
+GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Bool, bool, bool,
+                       baseline_not_equal)
+GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int8, int8, bool,
+                       baseline_not_equal)
+GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int16, int16, bool,
+                       baseline_not_equal)
+GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int64, int64, bool,
+                       baseline_not_equal)
+
+/// Test `tf.RightShift`.
+
+template <typename T>
+T baseline_right_shift(T lhs, T rhs) {
+  return lhs >> rhs;
+}
+
+GENERATE_DEFAULT_TESTS(RightShift,
+                       /*test_name=*/Int8, int8, int8, baseline_right_shift)
+GENERATE_DEFAULT_TESTS(RightShift,
+                       /*test_name=*/Int16, int16, int16, baseline_right_shift)
+GENERATE_DEFAULT_TESTS(RightShift,
+                       /*test_name=*/Int32, int32, int32, baseline_right_shift)
+GENERATE_DEFAULT_TESTS(RightShift,
+                       /*test_name=*/Int64, int64, int64, baseline_right_shift)
+
+/// Test `tf.Sub`.
+
+template <typename T>
+T baseline_sub(T lhs, T rhs) {
+  return lhs - rhs;
+}
+
+GENERATE_DEFAULT_TESTS(Sub,
                        /*test_name=*/Half, Eigen::half, Eigen::half,
-                       baseline_floor_div);
-GENERATE_DEFAULT_TESTS(FloorDiv,
-                       /*test_name=*/Float, float, float, baseline_floor_div);
-GENERATE_DEFAULT_TESTS(FloorDiv,
-                       /*test_name=*/Double, double, double,
-                       baseline_floor_div);
+                       baseline_sub)
+GENERATE_DEFAULT_TESTS(Sub,
+                       /*test_name=*/Float, float, float, baseline_sub)
+GENERATE_DEFAULT_TESTS(Sub,
+                       /*test_name=*/Double, double, double, baseline_sub)
+GENERATE_DEFAULT_TESTS(Sub,
+                       /*test_name=*/Int64, int64, int64, baseline_sub)

 }  // namespace
 }  // end namespace tensorflow
tensorflow/core/kernels/mlir_generated/gpu_op_div.cc (new file, 27 lines)
@@ -0,0 +1,27 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
+
+namespace tensorflow {
+
+GENERATE_AND_REGISTER_BINARY_KERNEL(Div, f16, DT_HALF, Eigen::half);
+GENERATE_AND_REGISTER_BINARY_KERNEL(Div, f32, DT_FLOAT, float);
+GENERATE_AND_REGISTER_BINARY_KERNEL(Div, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_BINARY_KERNEL(Div, i16, DT_INT16, int16);
+GENERATE_AND_REGISTER_BINARY_KERNEL(Div, i64, DT_INT64, int64);
+
+}  // namespace tensorflow
tensorflow/core/kernels/mlir_generated/gpu_op_sub.cc (new file, 26 lines)
@@ -0,0 +1,26 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
+
+namespace tensorflow {
+
+GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f16, DT_HALF, Eigen::half);
+GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f32, DT_FLOAT, float);
+GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, f64, DT_DOUBLE, double);
+GENERATE_AND_REGISTER_BINARY_KERNEL(Sub, i64, DT_INT64, int64);
+
+}  // namespace tensorflow
@@ -57,6 +57,7 @@ TensorShape DefaultInputShape();
 struct GpuOpsTestConfig {
  bool add_t = true;
  bool add_tout = false;
+  // Only used for gpu_unary_ops_test.
  bool expect_buffer_reuse = true;
  bool expect_strictly_equal = false;
  GpuOpsTestConfig ExpectStrictlyEqual() {
@@ -119,33 +120,10 @@ absl::InlinedVector<T, 10> DefaultInputGreaterOrEqualToZero() {

 /// Helper functions to get default input data.

-template <typename T,
-          std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
-                           bool> = true>
-T DefaultScalarInput() {
-  return static_cast<T>(3);
-}
-
-template <typename T, std::enable_if_t<
-                          llvm::is_one_of<T, Eigen::half, float, double>::value,
-                          bool> = true>
-T DefaultScalarInput() {
-  return static_cast<T>(2.0);
-}
-
-template <typename T,
-          std::enable_if_t<llvm::is_one_of<T, bool>::value, bool> = true>
-T DefaultScalarInput() {
-  return static_cast<T>(true);
-}
-
 template <typename T,
          std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
                           bool> = true>
 absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
-  if (op_name == "Abs") {
-    return NearZeroAndExtremeInput<T>();
-  }
  // Only generate values less than the bitwidth of the data type.
  if (op_name == "LeftShift" || op_name == "RightShift") {
    auto max_shift = sizeof(T) * 8 - 1;
@@ -153,6 +131,9 @@ absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
    for (auto i = 0; i < max_shift; ++i) v.push_back(i);
    return v;
  }
+  if (op_name == "Div") {
+    return InputAsVector<T, int>({-18, -9, 9, 18});
+  }
  return InputAsVector<T, int>({-18, -9, -1, 0, 0, 1, 1, 2, 3, 5, 7, 9, 9, 18});
 }

@@ -160,16 +141,7 @@ template <typename T, std::enable_if_t<
                          llvm::is_one_of<T, Eigen::half, float, double>::value,
                          bool> = true>
 absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
-  if (op_name == "Abs") {
-    return NearZeroAndExtremeInput<T>();
-  }
-  if (op_name == "Log" || op_name == "Rsqrt") {
-    return DefaultInputGreaterThanZero<T>();
-  }
-  if (op_name == "Sqrt") {
-    return DefaultInputGreaterOrEqualToZero<T>();
-  }
-  if (op_name == "FloorDiv") {
+  if (op_name == "Div" || op_name == "FloorDiv") {
    return InputAsVector<T, double>({-18.0, -9.0, -1e-6, -0.1, 0.1, 1e-6, 0.1,
                                     0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
  }
@@ -125,37 +125,55 @@ class GpuUnaryOpTest : public OpsTestBase {
 // define your own test fixtures.

 #define GENERATE_DEFAULT_TEST(op_name, InT, OutT, baseline_callback, config) \
-  GENERATE_DEFAULT_TEST2(op_name, InT, InT, OutT, OutT, baseline_callback, \
-                         config)
+  GENERATE_DEFAULT_TEST_2(op_name, InT, InT, OutT, OutT, baseline_callback, \
+                          config)

-#define GENERATE_DEFAULT_TEST2(op_name, InT, BaselineT, OutT, BaselineOutT, \
-                               baseline_callback, config) \
-  TEST_F(GpuUnaryOpTest, op_name##InT) { \
-    using NativeT = EnumToDataType<InT>::Type; \
-    using NativeBaselineT = EnumToDataType<BaselineT>::Type; \
-    using NativeOutT = EnumToDataType<OutT>::Type; \
-    using NativeBaselineOutT = EnumToDataType<BaselineOutT>::Type; \
-    Test<NativeT, NativeBaselineT, NativeOutT, NativeBaselineOutT>( \
-        #op_name, test::DefaultInputShape(), \
-        test::DefaultInput<NativeT>(#op_name), baseline_callback, config); \
+#define GENERATE_DEFAULT_TEST_2(op_name, InT, BaselineT, OutT, BaselineOutT, \
+                                baseline_callback, config) \
+  GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \
+      op_name, InT, BaselineT, OutT, BaselineOutT, \
+      test::DefaultInput<NativeT>(#op_name), baseline_callback, config)
+
+#define GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES( \
+    op_name, InT, OutT, input_values, baseline_callback, config) \
+  GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \
+      op_name, InT, InT, OutT, OutT, input_values, baseline_callback, config)
+
+#define GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \
+    op_name, InT, BaselineT, OutT, BaselineOutT, input_values, \
+    baseline_callback, config) \
+  TEST_F(GpuUnaryOpTest, op_name##InT) { \
+    using NativeT = EnumToDataType<InT>::Type; \
+    using NativeBaselineT = EnumToDataType<BaselineT>::Type; \
+    using NativeOutT = EnumToDataType<OutT>::Type; \
+    using NativeBaselineOutT = EnumToDataType<BaselineOutT>::Type; \
+    Test<NativeT, NativeBaselineT, NativeOutT, NativeBaselineOutT>( \
+        #op_name, test::DefaultInputShape(), input_values, baseline_callback, \
+        config); \
  }

 /// Test `tf.Abs`.

-GENERATE_DEFAULT_TEST(Abs, DT_FLOAT, DT_FLOAT, std::abs,
-                      test::GpuOpsTestConfig().ExpectStrictlyEqual())
+GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
+    Abs, DT_FLOAT, DT_FLOAT, test::NearZeroAndExtremeInput<float>(), std::abs,
+    test::GpuOpsTestConfig().ExpectStrictlyEqual())

-GENERATE_DEFAULT_TEST(Abs, DT_DOUBLE, DT_DOUBLE, std::abs,
-                      test::GpuOpsTestConfig().ExpectStrictlyEqual())
+GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
+    Abs, DT_DOUBLE, DT_DOUBLE, test::NearZeroAndExtremeInput<double>(),
+    std::abs, test::GpuOpsTestConfig().ExpectStrictlyEqual())

-GENERATE_DEFAULT_TEST2(Abs, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::abs,
-                       test::GpuOpsTestConfig().ExpectStrictlyEqual())
+GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
+    Abs, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
+    test::NearZeroAndExtremeInput<Eigen::half>(), std::abs,
+    test::GpuOpsTestConfig().ExpectStrictlyEqual())

-GENERATE_DEFAULT_TEST(Abs, DT_INT32, DT_INT32, std::abs,
-                      test::GpuOpsTestConfig().ExpectStrictlyEqual())
+GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
+    Abs, DT_INT32, DT_INT32, test::NearZeroAndExtremeInput<int32>(), std::abs,
+    test::GpuOpsTestConfig().ExpectStrictlyEqual())

-GENERATE_DEFAULT_TEST(Abs, DT_INT64, DT_INT64, std::abs,
-                      test::GpuOpsTestConfig().ExpectStrictlyEqual())
+GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
+    Abs, DT_INT64, DT_INT64, test::NearZeroAndExtremeInput<int64>(), std::abs,
+    test::GpuOpsTestConfig().ExpectStrictlyEqual())

 /// Test `tf.Ceil`.

@@ -165,8 +183,8 @@ GENERATE_DEFAULT_TEST(Ceil, DT_FLOAT, DT_FLOAT, std::ceil,
 GENERATE_DEFAULT_TEST(Ceil, DT_DOUBLE, DT_DOUBLE, std::ceil,
                      test::GpuOpsTestConfig().ExpectStrictlyEqual())

-GENERATE_DEFAULT_TEST2(Ceil, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::ceil,
+GENERATE_DEFAULT_TEST_2(Ceil, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::ceil,
                       test::GpuOpsTestConfig().ExpectStrictlyEqual())

 /// Test `tf.Conj`.

@@ -189,8 +207,8 @@ GENERATE_DEFAULT_TEST(Cos, DT_FLOAT, DT_FLOAT, std::cos,
 GENERATE_DEFAULT_TEST(Cos, DT_DOUBLE, DT_DOUBLE, std::cos,
                      test::GpuOpsTestConfig())

-GENERATE_DEFAULT_TEST2(Cos, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::cos,
+GENERATE_DEFAULT_TEST_2(Cos, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::cos,
                       test::GpuOpsTestConfig())

 /// Test `tf.Exp`.

@@ -200,8 +218,8 @@ GENERATE_DEFAULT_TEST(Exp, DT_FLOAT, DT_FLOAT, std::exp,
 GENERATE_DEFAULT_TEST(Exp, DT_DOUBLE, DT_DOUBLE, std::exp,
                      test::GpuOpsTestConfig())

-GENERATE_DEFAULT_TEST2(Exp, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::exp,
+GENERATE_DEFAULT_TEST_2(Exp, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::exp,
                       test::GpuOpsTestConfig())

 /// Test `tf.Floor`.

@@ -211,8 +229,8 @@ GENERATE_DEFAULT_TEST(Floor, DT_FLOAT, DT_FLOAT, std::floor,
 GENERATE_DEFAULT_TEST(Floor, DT_DOUBLE, DT_DOUBLE, std::floor,
                      test::GpuOpsTestConfig().ExpectStrictlyEqual())

-GENERATE_DEFAULT_TEST2(Floor, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::floor,
+GENERATE_DEFAULT_TEST_2(Floor, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::floor,
                       test::GpuOpsTestConfig().ExpectStrictlyEqual())

 /// Test `tf.Imag`.

@@ -260,14 +278,18 @@ TEST_F(GpuUnaryOpTest, DISABLED_IsInfHalf) {

 /// Test `tf.Log`.

-GENERATE_DEFAULT_TEST(Log, DT_FLOAT, DT_FLOAT, std::log,
-                      test::GpuOpsTestConfig())
+GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
+    Log, DT_FLOAT, DT_FLOAT, test::DefaultInputGreaterThanZero<float>(),
+    std::log, test::GpuOpsTestConfig())

-GENERATE_DEFAULT_TEST(Log, DT_DOUBLE, DT_DOUBLE, std::log,
-                      test::GpuOpsTestConfig())
+GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
+    Log, DT_DOUBLE, DT_DOUBLE, test::DefaultInputGreaterThanZero<double>(),
+    std::log, test::GpuOpsTestConfig())

-GENERATE_DEFAULT_TEST2(Log, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::log,
-                       test::GpuOpsTestConfig())
+GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
+    Log, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
+    test::DefaultInputGreaterThanZero<Eigen::half>(), std::log,
+    test::GpuOpsTestConfig())

 /// Test `tf.LogicalNot`

@@ -290,8 +312,8 @@ GENERATE_DEFAULT_TEST(Neg, DT_FLOAT, DT_FLOAT, baseline_neg,
 GENERATE_DEFAULT_TEST(Neg, DT_DOUBLE, DT_DOUBLE, baseline_neg,
                      test::GpuOpsTestConfig().ExpectStrictlyEqual())

-GENERATE_DEFAULT_TEST2(Neg, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, baseline_neg,
+GENERATE_DEFAULT_TEST_2(Neg, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, baseline_neg,
                       test::GpuOpsTestConfig())

 GENERATE_DEFAULT_TEST(Neg, DT_INT8, DT_INT8, baseline_neg,
                      test::GpuOpsTestConfig().ExpectStrictlyEqual())
@@ -323,14 +345,18 @@ T baseline_rsqrt(T x) {
  return 1.0 / std::sqrt(x);
 }

-GENERATE_DEFAULT_TEST(Rsqrt, DT_FLOAT, DT_FLOAT, baseline_rsqrt,
-                      test::GpuOpsTestConfig())
+GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
+    Rsqrt, DT_FLOAT, DT_FLOAT, test::DefaultInputGreaterThanZero<float>(),
+    baseline_rsqrt, test::GpuOpsTestConfig())

-GENERATE_DEFAULT_TEST(Rsqrt, DT_DOUBLE, DT_DOUBLE, baseline_rsqrt,
-                      test::GpuOpsTestConfig())
+GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
+    Rsqrt, DT_DOUBLE, DT_DOUBLE, test::DefaultInputGreaterThanZero<double>(),
+    baseline_rsqrt, test::GpuOpsTestConfig())

-GENERATE_DEFAULT_TEST2(Rsqrt, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
-                       baseline_rsqrt, test::GpuOpsTestConfig())
+GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
+    Rsqrt, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
+    test::DefaultInputGreaterThanZero<Eigen::half>(), baseline_rsqrt,
+    test::GpuOpsTestConfig())

 /// Test `tf.Sign`.

@@ -350,8 +376,8 @@ GENERATE_DEFAULT_TEST(Sign, DT_DOUBLE, DT_DOUBLE, baseline_sign,

 // TODO(b/162577610): We should actually use ExpectStrictlyEqual()
 // here. This requires returning 0.0 for input -0.0.
-GENERATE_DEFAULT_TEST2(Sign, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
+GENERATE_DEFAULT_TEST_2(Sign, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
                       baseline_sign, test::GpuOpsTestConfig())

 GENERATE_DEFAULT_TEST(Sign, DT_INT64, DT_INT64, baseline_sign,
                      test::GpuOpsTestConfig().ExpectStrictlyEqual())
@@ -364,19 +390,24 @@ GENERATE_DEFAULT_TEST(Sin, DT_FLOAT, DT_FLOAT, std::sin,
 GENERATE_DEFAULT_TEST(Sin, DT_DOUBLE, DT_DOUBLE, std::sin,
                      test::GpuOpsTestConfig())

-GENERATE_DEFAULT_TEST2(Sin, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::sin,
+GENERATE_DEFAULT_TEST_2(Sin, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::sin,
                       test::GpuOpsTestConfig())

 /// Test `tf.Sqrt`.

-GENERATE_DEFAULT_TEST(Sqrt, DT_FLOAT, DT_FLOAT, std::sqrt,
-                      test::GpuOpsTestConfig())
+GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
+    Sqrt, DT_FLOAT, DT_FLOAT, test::DefaultInputGreaterOrEqualToZero<float>(),
+    std::sqrt, test::GpuOpsTestConfig())

-GENERATE_DEFAULT_TEST(Sqrt, DT_DOUBLE, DT_DOUBLE, std::sqrt,
-                      test::GpuOpsTestConfig())
+GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
+    Sqrt, DT_DOUBLE, DT_DOUBLE,
+    test::DefaultInputGreaterOrEqualToZero<double>(), std::sqrt,
+    test::GpuOpsTestConfig())

-GENERATE_DEFAULT_TEST2(Sqrt, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::sqrt,
-                       test::GpuOpsTestConfig())
+GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
+    Sqrt, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
+    test::DefaultInputGreaterOrEqualToZero<Eigen::half>(), std::sqrt,
+    test::GpuOpsTestConfig())

 /// Test `tf.Tanh`.
|
|
||||||
@ -386,8 +417,8 @@ GENERATE_DEFAULT_TEST(Tanh, DT_FLOAT, DT_FLOAT, std::tanh,
|
|||||||
GENERATE_DEFAULT_TEST(Tanh, DT_DOUBLE, DT_DOUBLE, std::tanh,
|
GENERATE_DEFAULT_TEST(Tanh, DT_DOUBLE, DT_DOUBLE, std::tanh,
|
||||||
test::GpuOpsTestConfig())
|
test::GpuOpsTestConfig())
|
||||||
|
|
||||||
GENERATE_DEFAULT_TEST2(Tanh, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::tanh,
|
GENERATE_DEFAULT_TEST_2(Tanh, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::tanh,
|
||||||
test::GpuOpsTestConfig())
|
test::GpuOpsTestConfig())
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
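The hunks above move Log, Rsqrt, and Sqrt to `GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES`, which restricts the generated inputs to the op's valid domain so the CPU baseline never sees arguments where the result would be NaN or infinite. A minimal standalone sketch of that invariant (names and values here are illustrative, not the actual test infrastructure):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

template <typename T>
T baseline_rsqrt(T x) {
  return static_cast<T>(1.0) / std::sqrt(x);
}

int main() {
  // Stand-in for DefaultInputGreaterThanZero<float>(): strictly positive inputs.
  std::vector<float> inputs = {0.1f, 1.0f, 4.0f, 100.0f};
  for (float x : inputs) {
    // On this domain the baseline stays finite, so comparisons against the
    // GPU kernel are well defined.
    assert(std::isfinite(baseline_rsqrt(x)));
  }
  return 0;
}
```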
@@ -0,0 +1,6 @@
+func @Div_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.Div"(%arg0, %arg1) {T = elem_type, device = ""}
+       : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
+}
@@ -0,0 +1,6 @@
+func @Sub_elem_type(%arg0: tensor<*xelem_type>, %arg1: tensor<*xelem_type>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.Sub"(%arg0, %arg1) {T = elem_type, device = ""}
+       : (tensor<*xelem_type>, tensor<*xelem_type>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
+}
@@ -137,4 +137,14 @@ REGISTER_KERNEL_BUILDER(
     Name("QuantizedMaxPool").Device(DEVICE_CPU).TypeConstraint<quint8>("T"),
     QuantizedMaxPoolingOp<CPUDevice, quint8>);

+#ifdef INTEL_MKL
+REGISTER_KERNEL_BUILDER(
+    Name("QuantizedAvgPool").Device(DEVICE_CPU).TypeConstraint<qint8>("T"),
+    QuantizedAvgPoolingOp<CPUDevice, qint8>);
+
+REGISTER_KERNEL_BUILDER(
+    Name("QuantizedMaxPool").Device(DEVICE_CPU).TypeConstraint<qint8>("T"),
+    QuantizedMaxPoolingOp<CPUDevice, qint8>);
+#endif
+
 }  // namespace tensorflow
@@ -232,11 +232,7 @@ REGISTER_OP("_FusedBatchNormEx")
     .Output("reserve_space_1: U")
     .Output("reserve_space_2: U")
     .Output("reserve_space_3: U")
-#ifdef ENABLE_MKLDNN_V1
     .Attr("T: {half, float, bfloat16}")
-#else
-    .Attr("T: {half, float}")
-#endif
     .Attr("U: {float}")
     .Attr("epsilon: float = 0.0001")
     .Attr("exponential_avg_factor: float = 1.0")
@@ -981,6 +981,17 @@ cc_library(
     alwayslink = 1,
 )

+cc_library(
+    name = "enable_tf2_utils",
+    srcs = ["enable_tf2_utils.cc"],
+    hdrs = ["enable_tf2_utils.h"],
+    copts = tf_copts(),
+    deps = [
+        "//tensorflow/core/util:env_var",
+    ],
+    alwayslink = 1,
+)
+
 alias(
     name = "profile_utils_cpu_utils",
     actual = "//tensorflow/core/platform/profile_utils:profile_utils_cpu_utils",

@@ -992,6 +1003,12 @@ filegroup(
     compatible_with = get_compatible_with_portable(),
 )

+filegroup(
+    name = "enable_tf2_hdr",
+    srcs = ["enable_tf2_utils.h"],
+    compatible_with = get_compatible_with_portable(),
+)
+
 tf_cc_tests(
     name = "low_level_library_tests",
     size = "small",

@@ -1047,6 +1064,20 @@ tf_cc_test(
     ],
 )

+tf_cc_test(
+    name = "enable_tf2_utils_test",
+    size = "small",
+    srcs = [
+        "enable_tf2_utils_test.cc",
+    ],
+    deps = [
+        ":enable_tf2_utils",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/util:env_var",
+    ],
+)
+
 tf_cc_tests(
     name = "stacktrace_handler_test",
     size = "small",
@@ -149,9 +149,7 @@ class PosixEnv : public Env {
       return true;
     }
   }
-#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
-    return false;
-#else
+#if defined(__GLIBC__) || defined(__FreeBSD__)
     char buf[100];
 #ifdef __FreeBSD__
     int res = 0;

@@ -164,6 +162,8 @@ class PosixEnv : public Env {
     }
     *name = buf;
     return true;
+#else
+    return false;
 #endif
   }

tensorflow/core/platform/enable_tf2_utils.cc (new file, 49 lines)
@@ -0,0 +1,49 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/enable_tf2_utils.h"
+
+#include <atomic>
+
+#include "tensorflow/core/util/env_var.h"
+
+namespace tensorflow {
+
+enum Enablement : uint8 { kFalse = 0, kTrue = 1, undefined = 2 };
+
+// If this flag is set, we will use it as a signal to decide on whether to
+// use the MLIR based TF-XLA bridge.
+static std::atomic<Enablement> tf2_enabled{undefined};
+
+// Determine whether or not the user has explicitly asked for tf2 execution.
+// Will be used to determine whether to use the MLIR based bridge.
+void set_tf2_execution(bool enabled) {
+  tf2_enabled = (enabled) ? Enablement::kTrue : Enablement::kFalse;
+}
+
+bool tf2_execution_enabled() {
+  if (tf2_enabled == Enablement::undefined) {
+    static bool tf2_behavior_env_enabled = [] {
+      string tf2_env;
+      TF_CHECK_OK(ReadStringFromEnvVar("TF2_BEHAVIOR", "0", &tf2_env));
+      return tf2_env != "0";
+    }();
+    tf2_enabled =
+        (tf2_behavior_env_enabled) ? Enablement::kTrue : Enablement::kFalse;
+  }
+  return tf2_enabled;
+}
+
+}  // namespace tensorflow
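The new file above uses a tri-state atomic: the flag starts as `undefined`, is lazily initialized from the `TF2_BEHAVIOR` environment variable on first query, and can be overridden explicitly at any time. A minimal standalone sketch of the same pattern, with illustrative names in place of the TensorFlow ones:

```cpp
#include <atomic>
#include <cstdlib>
#include <cstring>
#include <iostream>

enum class Tristate : unsigned char { kFalse = 0, kTrue = 1, kUndefined = 2 };
static std::atomic<Tristate> g_flag{Tristate::kUndefined};

// Explicit override, analogous to set_tf2_execution(bool).
void set_flag(bool enabled) {
  g_flag = enabled ? Tristate::kTrue : Tristate::kFalse;
}

// Lazy environment-variable fallback, analogous to tf2_execution_enabled().
bool flag_enabled() {
  if (g_flag == Tristate::kUndefined) {
    const char* env = std::getenv("EXAMPLE_BEHAVIOR");  // stand-in for TF2_BEHAVIOR
    bool env_enabled = (env != nullptr && std::strcmp(env, "0") != 0);
    g_flag = env_enabled ? Tristate::kTrue : Tristate::kFalse;
  }
  return g_flag == Tristate::kTrue;
}

int main() {
  std::cout << flag_enabled() << "\n";  // from environment, default false
  set_flag(true);
  std::cout << flag_enabled() << "\n";  // 1
  return 0;
}
```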
tensorflow/core/platform/enable_tf2_utils.h (new file, 31 lines)
@@ -0,0 +1,31 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TF_CORE_PLATFORM_TF2_UTILS_H_
+#define TF_CORE_PLATFORM_TF2_UTILS_H_
+
+namespace tensorflow {
+
+// Sets the tf2 execution state. This can be used to indicate whether the user
+// has explicitly asked for tf2 execution.
+void set_tf2_execution(bool enabled);
+
+// Returns true or false depending on whether the user flag for tf2 execution
+// has been set. The default is false.
+bool tf2_execution_enabled();
+
+}  // namespace tensorflow
+
+#endif  // TF_CORE_PLATFORM_TF2_UTILS_H_
tensorflow/core/platform/enable_tf2_utils_test.cc (new file, 35 lines)
@@ -0,0 +1,35 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Testing TF2 enablement.
+
+#include "tensorflow/core/platform/enable_tf2_utils.h"
+
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/env_var.h"
+
+namespace tensorflow {
+
+TEST(TF2EnabledTest, enabled_behavior) {
+  string tf2_env;
+  TF_CHECK_OK(ReadStringFromEnvVar("TF2_BEHAVIOR", "0", &tf2_env));
+  bool expected = (tf2_env != "0");
+  EXPECT_EQ(tensorflow::tf2_execution_enabled(), expected);
+  tensorflow::set_tf2_execution(true);
+  EXPECT_TRUE(tensorflow::tf2_execution_enabled());
+  tensorflow::set_tf2_execution(false);
+  EXPECT_FALSE(tensorflow::tf2_execution_enabled());
+}
+
+}  // namespace tensorflow
@@ -108,7 +108,7 @@ limitations under the License.

 #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 634  // Updated: 2021/1/2
+#define TF_GRAPH_DEF_VERSION 637  // Updated: 2021/1/5

 // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
 //
@@ -386,6 +386,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
       return ParseSoftmax(op, error_reporter, allocator, builtin_data);
     }

+    case BuiltinOperator_SPACE_TO_DEPTH: {
+      return ParseSpaceToDepth(op, error_reporter, allocator, builtin_data);
+    }
+
     case BuiltinOperator_SPLIT: {
       return ParseSplit(op, error_reporter, allocator, builtin_data);
     }

@@ -596,16 +600,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
       *builtin_data = params.release();
       return kTfLiteOk;
     }
-    case BuiltinOperator_SPACE_TO_DEPTH: {
-      auto params = safe_allocator.Allocate<TfLiteSpaceToDepthParams>();
-      TF_LITE_ENSURE(error_reporter, params != nullptr);
-      if (const auto* schema_params =
-              op->builtin_options_as_SpaceToDepthOptions()) {
-        params->block_size = schema_params->block_size();
-      }
-      *builtin_data = params.release();
-      return kTfLiteOk;
-    }

     case BuiltinOperator_GATHER: {
       auto params = safe_allocator.Allocate<TfLiteGatherParams>();

@@ -1684,6 +1678,31 @@ TfLiteStatus ParseSoftmax(const Operator* op, ErrorReporter* error_reporter,
   return kTfLiteOk;
 }

+TfLiteStatus ParseSpaceToDepth(const Operator* op,
+                               ErrorReporter* error_reporter,
+                               BuiltinDataAllocator* allocator,
+                               void** builtin_data) {
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+  SafeBuiltinDataAllocator safe_allocator(allocator);
+  std::unique_ptr<TfLiteSpaceToDepthParams,
+                  SafeBuiltinDataAllocator::BuiltinDataDeleter>
+      params = safe_allocator.Allocate<TfLiteSpaceToDepthParams>();
+  TF_LITE_ENSURE(error_reporter, params != nullptr);
+
+  const auto* schema_params = op->builtin_options_as_SpaceToDepthOptions();
+  if (schema_params != nullptr) {
+    params->block_size = schema_params->block_size();
+  } else {
+    // TODO(b/157480169): We should either return kTfLiteError or fill in some
+    // reasonable defaults in the params struct. We are not doing so until we
+    // better understand the ramifications of changing the legacy behavior.
+  }
+
+  *builtin_data = params.release();
+  return kTfLiteOk;
+}
+
 TfLiteStatus ParseSplit(const Operator* op, ErrorReporter* error_reporter,
                         BuiltinDataAllocator* allocator, void** builtin_data) {
   CheckParsePointerParams(op, error_reporter, allocator, builtin_data);

@@ -254,6 +254,11 @@ TfLiteStatus ParseSin(const Operator* op, ErrorReporter* error_reporter,
 TfLiteStatus ParseSoftmax(const Operator* op, ErrorReporter* error_reporter,
                           BuiltinDataAllocator* allocator, void** builtin_data);

+TfLiteStatus ParseSpaceToDepth(const Operator* op,
+                               ErrorReporter* error_reporter,
+                               BuiltinDataAllocator* allocator,
+                               void** builtin_data);
+
 TfLiteStatus ParseSplit(const Operator* op, ErrorReporter* error_reporter,
                         BuiltinDataAllocator* allocator, void** builtin_data);

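The hunks above move SPACE_TO_DEPTH parsing into a dedicated `ParseSpaceToDepth` helper that reads a single `block_size` field from the flatbuffer options. As a rough orientation for what that parameter controls (assumed example shapes, not TFLite code): an H x W x C input is rearranged into (H/b) x (W/b) x (C*b*b).

```cpp
#include <iostream>

int main() {
  const int H = 8, W = 8, C = 3;
  const int block_size = 2;  // what ParseSpaceToDepth reads from the model
  // Output shape after SpaceToDepth with this block size.
  std::cout << H / block_size << " x " << W / block_size << " x "
            << C * block_size * block_size << "\n";  // 4 x 4 x 12
  return 0;
}
```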
@@ -111,11 +111,15 @@ class Subgraph {
   inline TfLiteStatus SetTensorParametersReadWrite(
       int tensor_index, TfLiteType type, const char* name,
       const std::vector<int>& dims, TfLiteQuantization quantization,
-      bool is_variable = false, const size_t rank_dims_signature = 0,
-      const int* dims_signature = nullptr) {
-    return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
-                                        dims.data(), quantization, is_variable,
-                                        rank_dims_signature, dims_signature);
+      bool is_variable = false, const std::vector<int>& dims_signature = {}) {
+    if (dims_signature.empty()) {
+      return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
+                                          dims.data(), quantization,
+                                          is_variable);
+    }
+    return SetTensorParametersReadWrite(
+        tensor_index, type, name, dims.size(), dims.data(), quantization,
+        is_variable, dims_signature.size(), dims_signature.data());
   }
   TfLiteStatus SetTensorParametersReadWrite(
       int tensor_index, TfLiteType type, const char* name, const size_t rank,
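The convenience overload above now takes `dims_signature` as a `std::vector<int>` instead of a raw pointer plus rank, and dispatches on whether a signature was supplied. A minimal standalone sketch of that dispatch shape, with simplified stand-in functions rather than the real Subgraph API:

```cpp
#include <cstdio>
#include <vector>

// Stand-ins for the two lower-level overloads the convenience wrapper forwards to.
void SetParams(int rank, const int* dims) {
  std::printf("no signature, rank=%d\n", rank);
}
void SetParams(int rank, const int* dims, int sig_rank, const int* sig) {
  std::printf("with signature, sig_rank=%d\n", sig_rank);
}

void SetParamsConvenience(const std::vector<int>& dims,
                          const std::vector<int>& dims_signature = {}) {
  if (dims_signature.empty()) {
    SetParams(static_cast<int>(dims.size()), dims.data());
    return;
  }
  SetParams(static_cast<int>(dims.size()), dims.data(),
            static_cast<int>(dims_signature.size()), dims_signature.data());
}

int main() {
  SetParamsConvenience({1, 224, 224, 3});                     // no signature
  SetParamsConvenience({1, 224, 224, 3}, {-1, 224, 224, 3});  // dynamic batch
  return 0;
}
```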
@@ -45,6 +45,56 @@ int3 GetWorkGroupsCount(int grid_dimension, const int3& grid_size,
   }
   return work_groups_count;
 }

+std::string GetCommonOpenCLDefines(CalculationsPrecision precision) {
+  std::string result;
+
+  result += "#define GLOBAL_ID_0 get_global_id(0)\n";
+  result += "#define GLOBAL_ID_1 get_global_id(1)\n";
+  result += "#define GLOBAL_ID_2 get_global_id(2)\n";
+  result += "#define MAIN_FUNCTION __kernel void main_function\n";
+  switch (precision) {
+    case CalculationsPrecision::F32:
+      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
+      result += "#define ACCUM_FLT4 float4\n";
+      result += "#define FLT float\n";
+      result += "#define FLT2 float2\n";
+      result += "#define FLT3 float3\n";
+      result += "#define FLT4 float4\n";
+      result += "#define TO_FLT4 convert_float4\n";
+      result += "#define TO_ACCUM_TYPE convert_float4\n";
+      result += "#define TO_ACCUM_FLT convert_float\n";
+      result += "#define INIT_FLT4(value) (float4)(value)\n";
+      break;
+    case CalculationsPrecision::F16:
+      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
+      result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
+      result += "#define ACCUM_FLT4 half4\n";
+      result += "#define FLT half\n";
+      result += "#define FLT2 half2\n";
+      result += "#define FLT3 half3\n";
+      result += "#define FLT4 half4\n";
+      result += "#define TO_FLT4 convert_half4\n";
+      result += "#define TO_ACCUM_TYPE convert_half4\n";
+      result += "#define TO_ACCUM_FLT convert_half\n";
+      result += "#define INIT_FLT4(value) (half4)(value)\n";
+      break;
+    case CalculationsPrecision::F32_F16:
+      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
+      result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
+      result += "#define ACCUM_FLT4 float4\n";
+      result += "#define FLT half\n";
+      result += "#define FLT2 half2\n";
+      result += "#define FLT3 half3\n";
+      result += "#define FLT4 half4\n";
+      result += "#define TO_FLT4 convert_half4\n";
+      result += "#define TO_ACCUM_TYPE convert_float4\n";
+      result += "#define TO_ACCUM_FLT convert_float\n";
+      result += "#define INIT_FLT4(value) (half4)(value)\n";
+      break;
+  }
+  return result;
+}
 }  // namespace

 ClOperation::ClOperation(ClOperation&& operation)

@@ -94,6 +144,9 @@ absl::Status ClOperation::UpdateParams() {

 absl::Status ClOperation::Compile(const CreationContext& creation_context) {
   operation_->AssembleCode(creation_context.GetGpuInfo());
+  operation_->code_ =
+      GetCommonOpenCLDefines(operation_->definition_.precision) +
+      operation_->code_;
   RETURN_IF_ERROR(cl_args_.Init(
       creation_context.GetGpuInfo(),
       {{operation_->dst_tensors_names_[0], operation_->elementwise_code_}},
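With this change the backend-neutral generated code is written against `MAIN_FUNCTION`, `GLOBAL_ID_*`, and `INIT_FLT4`, and the OpenCL-specific expansions are prepended only at compile time. A standalone sketch of the assembly that `ClOperation::Compile` now performs (the strings are illustrative excerpts, not the full define set):

```cpp
#include <iostream>
#include <string>

int main() {
  // A subset of what GetCommonOpenCLDefines returns for F32 precision.
  std::string defines =
      "#define GLOBAL_ID_0 get_global_id(0)\n"
      "#define MAIN_FUNCTION __kernel void main_function\n";
  // Backend-neutral generated kernel source, as produced by AssembleCode.
  std::string generated =
      "MAIN_FUNCTION($0) {\n"
      "  int X = GLOBAL_ID_0;\n"
      "}\n";
  // Compile now concatenates defines + generated source before building.
  std::cout << defines + generated;
  return 0;
}
```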
@@ -336,6 +336,8 @@ absl::Status Tensor::GetGPUResources(const GPUObjectDescriptor* obj_ptr,
   if (!tensor_desc) {
     return absl::InvalidArgumentError("Expected TensorDescriptor on input.");
   }
+  resources->ints.push_back(
+      {"slice_stride", tensor_desc->GetSliceStrideSize(shape_)});
   if (descriptor_.HasAxis(Axis::WIDTH)) {
     resources->ints.push_back({"width", Width()});
     resources->ints.push_back({"width_div2", Width() / 2});
@@ -22,61 +22,18 @@ limitations under the License.
 namespace tflite {
 namespace gpu {
 namespace {
-std::string GetCommonDefines(CalculationsPrecision precision) {
-  std::string result;
-
-  switch (precision) {
-    case CalculationsPrecision::F32:
-      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
-      result += "#define ACCUM_FLT4 float4\n";
-      result += "#define FLT float\n";
-      result += "#define FLT2 float2\n";
-      result += "#define FLT3 float3\n";
-      result += "#define FLT4 float4\n";
-      result += "#define TO_FLT4 convert_float4\n";
-      result += "#define TO_ACCUM_TYPE convert_float4\n";
-      result += "#define TO_ACCUM_FLT convert_float\n";
-      break;
-    case CalculationsPrecision::F16:
-      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
-      result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
-      result += "#define ACCUM_FLT4 half4\n";
-      result += "#define FLT half\n";
-      result += "#define FLT2 half2\n";
-      result += "#define FLT3 half3\n";
-      result += "#define FLT4 half4\n";
-      result += "#define TO_FLT4 convert_half4\n";
-      result += "#define TO_ACCUM_TYPE convert_half4\n";
-      result += "#define TO_ACCUM_FLT convert_half\n";
-      break;
-    case CalculationsPrecision::F32_F16:
-      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
-      result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
-      result += "#define ACCUM_FLT4 float4\n";
-      result += "#define FLT half\n";
-      result += "#define FLT2 half2\n";
-      result += "#define FLT3 half3\n";
-      result += "#define FLT4 half4\n";
-      result += "#define TO_FLT4 convert_half4\n";
-      result += "#define TO_ACCUM_TYPE convert_float4\n";
-      result += "#define TO_ACCUM_FLT convert_float\n";
-      break;
-  }
-  return result;
-}
-
 std::string GetElementWiseCode(const OperationDef& op_def,
                                bool check_src_slices) {
   std::string c;
-  c += "__kernel void main_function(\n";
+  c += "MAIN_FUNCTION(\n";
   c += "$0) {\n";
-  c += "  int X = get_global_id(0);\n";
-  c += "  int Y = get_global_id(1);\n";
-  c += "  int Z = get_global_id(2);\n";
+  c += "  int X = GLOBAL_ID_0;\n";
+  c += "  int Y = GLOBAL_ID_1;\n";
+  c += "  int Z = GLOBAL_ID_2;\n";
   c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
        "Z >= args.dst_tensor.Slices()) return; \n";
   if (check_src_slices) {
-    c += "  FLT4 src = (FLT4)(0.0f);\n";
+    c += "  FLT4 src = INIT_FLT4(0.0f);\n";
     c += "  if (Z < args.src_tensor.Slices()) {\n";
     c += "    src = args.src_tensor.Read(X, Y, Z);\n";
     c += "  }\n";

@@ -240,7 +197,6 @@ void GPUOperation::AssembleCode(const GpuInfo& gpu_info) {
     elementwise_code_ = "{\n" + code_ + "\n}\n" + elementwise_code_;
     code_ = GetElementWiseCode(definition_, check_src_channels_size_);
   }
-  code_ = GetCommonDefines(definition_.precision) + code_;
 }

 void GPUOperation::GetPossibleKernelWorkGroups(
@@ -94,6 +94,7 @@ TensorDescriptor& TensorDescriptor::operator=(TensorDescriptor&& desc) {

 GPUResources TensorDescriptor::GetGPUResources() const {
   GPUResources resources;
+  resources.ints.push_back("slice_stride");
   if (HasAxis(Axis::WIDTH)) {
     resources.ints.push_back("width");
     resources.ints.push_back("width_div2");

@@ -175,7 +176,7 @@ absl::Status TensorDescriptor::PerformSelector(
     *result = "slices";
     return absl::OkStatus();
   } else if (selector == "SliceStride") {
-    *result = GetSliceStride();
+    *result = "slice_stride";
     return absl::OkStatus();
   } else if (selector == "Channels") {
     *result = "channels";

@@ -402,7 +403,7 @@ absl::Status TensorDescriptor::PerformGetPtrWithSliceOffsetSelector(
         "GetPtrWithSliceOffset require one argument(slice coordinate), but ",
         args.size(), " was passed"));
   }
-  *result = absl::StrCat("buffer + ", args[0], " * ", GetSliceStride());
+  *result = absl::StrCat("buffer + ", args[0], " * slice_stride");
   return absl::OkStatus();
 }

@@ -644,6 +645,35 @@ bool TensorDescriptor::HasAxis(Axis axis) const {
   return false;
 }

+int TensorDescriptor::GetWidthSize(BHWDC shape) const {
+  int width = shape.w;
+  auto it1 = state_vars_.find("ElementsX2");
+  if (it1 != state_vars_.end() && it1->second == "true") {
+    width /= 2;
+  }
+  auto it2 = state_vars_.find("ElementsX4");
+  if (it2 != state_vars_.end() && it2->second == "true") {
+    width /= 4;
+  }
+  auto it = state_vars_.find("BatchedWidth");
+  if (it != state_vars_.end() && it->second == "true") {
+    width *= shape.b;
+  }
+  return width;
+}
+
+int TensorDescriptor::GetSliceStrideSize(BHWDC shape) const {
+  if (IsBatchedWidth()) {
+    return GetWidthSize(shape) * shape.h;
+  } else {
+    if (HasAxis(Axis::BATCH)) {
+      return GetWidthSize(shape) * shape.h * shape.b;
+    } else {
+      return GetWidthSize(shape) * shape.h;
+    }
+  }
+}
+
 void TensorDescriptor::SetAddressMode(AddressMode mode) {
   if (mode == AddressMode::kZero) {
     state_vars_["TextureMode"] = "ZERO";

@@ -719,18 +749,6 @@ std::string TensorDescriptor::GetWidth() const {
   }
 }

-std::string TensorDescriptor::GetSliceStride() const {
-  if (IsBatchedWidth()) {
-    return GetWidth() + " * height";
-  } else {
-    if (HasAxis(Axis::BATCH)) {
-      return GetWidth() + " * height * batch";
-    } else {
-      return GetWidth() + " * height";
-    }
-  }
-}
-
 AddressMode TensorDescriptor::AddressModeFromState() const {
   auto it = state_vars_.find("TextureMode");
   if (it != state_vars_.end()) {
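The string-building `GetSliceStride()` is replaced by `GetSliceStrideSize(shape)`, which computes the stride as an integer on the host and passes it to the kernel as the `slice_stride` resource. A standalone numeric sketch of that computation under assumed shape values (not the TFLite types):

```cpp
#include <iostream>

int main() {
  const int w = 16, h = 8, b = 2;
  const bool batched_width = false;   // "BatchedWidth" state var
  const bool has_batch_axis = true;   // tensor layout includes a batch axis
  // GetWidthSize: width possibly folded with batch when width is batched.
  int width = batched_width ? w * b : w;
  // GetSliceStrideSize: width * height, times batch when batch is a separate axis.
  int slice_stride = batched_width
                         ? width * h
                         : (has_batch_axis ? width * h * b : width * h);
  std::cout << "slice_stride = " << slice_stride << "\n";  // 256
  return 0;
}
```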
@@ -70,6 +70,8 @@ struct TensorDescriptor : public GPUObjectDescriptor {
   bool HasAxis(Axis axis) const;
   void SetAddressMode(AddressMode mode);
+  int GetWidthSize(BHWDC shape) const;
+  int GetSliceStrideSize(BHWDC shape) const;

   absl::Status GetLinkingContextFromWriteSelector(
       const std::vector<std::string>& args, std::string* value_name,

@@ -136,7 +138,6 @@ struct TensorDescriptor : public GPUObjectDescriptor {
   bool IsBatchedWidth() const;

   std::string GetWidth() const;
-  std::string GetSliceStride() const;

   AddressMode AddressModeFromState() const;

@@ -677,6 +677,7 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common:util",
+        "//tensorflow/lite/delegates/gpu/common/task:util",
         "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
     ],
 )
@@ -32,59 +32,56 @@ namespace gpu {
 namespace metal {

 std::string GetResizeBilinearCode(const Resize2DAttributes& attr) {
-  std::string code = R"(
+  std::string c = R"(
 #include <metal_stdlib>
 using namespace metal;
 $0
 kernel void ComputeFunction(
                             $1
                             uint3 gid[[thread_position_in_grid]]) {
-  if (int(gid.x) >= size.z || int(gid.y) >= size.w) {
+  if (int(gid.x) >= args.dst_tensor.Width() || int(gid.y) >= args.dst_tensor.Height()) {
     return;
-  })";
-  if (attr.half_pixel_centers) {
-    code += "const float2 tex_coord = (float2(gid.xy) + 0.5f) * scale - 0.5f;";
-  } else {
-    code += "const float2 tex_coord = float2(gid.xy) * scale;";
   }
-  code += R"(
-  const float2 tex_coord_floor = floor(tex_coord);
-  const int2 itex_coord_floor = int2(tex_coord_floor);
-  const int2 borders = size.xy - int2(1, 1);
-  int4 st;
-  st.xy = max(itex_coord_floor, int2(0, 0));
-  st.zw = min(itex_coord_floor + int2(1, 1), borders);
-  const float2 t = tex_coord - tex_coord_floor; // interpolating factors
-  const int src_index0 = (gid.z * size.y + st.y) * size.x + st.x;
-  const int src_index1 = (gid.z * size.y + st.y) * size.x + st.z;
-  const int src_index2 = (gid.z * size.y + st.w) * size.x + st.x;
-  const int src_index3 = (gid.z * size.y + st.w) * size.x + st.z;
-  FLT4 tex11 = src_tensor[src_index0];
-  FLT4 tex21 = src_tensor[src_index1];
-  FLT4 tex12 = src_tensor[src_index2];
-  FLT4 tex22 = src_tensor[src_index3];
-  // bilinear interpolation
-  FLT4 value = mix(mix(tex11, tex21, static_cast<FLT>(t.x)),
-                   mix(tex12, tex22, static_cast<FLT>(t.x)), static_cast<FLT>(t.y));
-  const int linear_index = (gid.z * size.w + gid.y) * size.z + gid.x;
-  $2
-  dst_tensor[linear_index] = value;
-}
-)";
-  return code;
+)";
+  if (attr.half_pixel_centers) {
+    c += "  float2 tex_coord = (float2(gid.xy) + 0.5f) * scale - 0.5f;";
+  } else {
+    c += "  float2 tex_coord = float2(gid.xy) * scale;";
+  }
+  c += R"(
+  float2 tex_coord_floor = floor(tex_coord);
+  int2 itex_coord_floor = int2(tex_coord_floor);
+  int2 borders = int2(args.src_tensor.Width() - 1, args.src_tensor.Height() - 1);
+  int4 st;
+  st.xy = max(itex_coord_floor, int2(0, 0));
+  st.zw = min(itex_coord_floor + int2(1, 1), borders);
+  float2 t = tex_coord - tex_coord_floor; // interpolating factors
+  FLT4 tex11 = args.src_tensor.Read(st.x, st.y, gid.z);
+  FLT4 tex21 = args.src_tensor.Read(st.z, st.y, gid.z);
+  FLT4 tex12 = args.src_tensor.Read(st.x, st.w, gid.z);
+  FLT4 tex22 = args.src_tensor.Read(st.z, st.w, gid.z);
+  // bilinear interpolation
+  FLT4 value = mix(mix(tex11, tex21, static_cast<FLT>(t.x)),
+                   mix(tex12, tex22, static_cast<FLT>(t.x)), static_cast<FLT>(t.y));
+  args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, gid.z);
+  $2
+  args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
+}
+)";
+  return c;
 }

 std::string GetResizeNearestCode(const Resize2DAttributes& attr) {
-  std::string code = R"(
+  std::string c = R"(
 #include <metal_stdlib>
 using namespace metal;
 $0
 kernel void ComputeFunction(
                             $1
                             uint3 gid[[thread_position_in_grid]]) {
-  if (int(gid.x) >= size.z || int(gid.y) >= size.w) {
+  if (int(gid.x) >= args.dst_tensor.Width() || int(gid.y) >= args.dst_tensor.Height()) {
     return;
   }
 )";
   std::string fxc;
   std::string fyc;

@@ -99,27 +96,27 @@ std::string GetResizeNearestCode(const Resize2DAttributes& attr) {
     fxc += " + 0.5f";
     fyc += " + 0.5f";
   }
-  code += "  int2 coord;\n";
-  code += "  coord.x = static_cast<int>(" + fxc + ");\n";
-  code += "  coord.y = static_cast<int>(" + fyc + ");\n";
-  code += "  coord.x = max(0, coord.x);\n";
-  code += "  coord.y = max(0, coord.y);\n";
-  code += "  coord.x = min(coord.x, size.x - 1);\n";
-  code += "  coord.y = min(coord.y, size.y - 1);\n";
-  code += R"(
-  const int src_index = (gid.z * size.y + coord.y) * size.x + coord.x;
-  FLT4 value = src_tensor[src_index];
-  const int linear_index = (gid.z * size.w + gid.y) * size.z + gid.x;
-  $2
-  dst_tensor[linear_index] = value;
-}
-)";
-  return code;
+  c += "  int2 coord;\n";
+  c += "  coord.x = static_cast<int>(" + fxc + ");\n";
+  c += "  coord.y = static_cast<int>(" + fyc + ");\n";
+  c += "  coord.x = max(0, coord.x);\n";
+  c += "  coord.y = max(0, coord.y);\n";
+  c += "  coord.x = min(coord.x, args.src_tensor.Width() - 1);\n";
+  c += "  coord.y = min(coord.y, args.src_tensor.Height() - 1);\n";
+  c += R"(
+  FLT4 value = args.src_tensor.Read(coord.x, coord.y, gid.z);
+  args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, gid.z);
+  $2
+  args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
+}
+)";
+  return c;
 }

 ComputeTaskDescriptor Resize(const OperationDef& definition,
                              const Resize2DAttributes& attr) {
   ComputeTaskDescriptor desc(definition);
+  desc.tensors_as_args = true;
   switch (attr.type) {
     case SamplingType::BILINEAR:
       desc.shader_source = GetResizeBilinearCode(attr);

@@ -136,17 +133,6 @@ ComputeTaskDescriptor Resize(const OperationDef& definition,
   desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);

   desc.uniform_buffers = {
-      {"constant int4& size",
-       [](const std::vector<BHWC>& src_shapes,
-          const std::vector<BHWC>& dst_shapes) {
-         std::vector<int> sizes = {
-             src_shapes[0].w,
-             src_shapes[0].h,
-             dst_shapes[0].w,
-             dst_shapes[0].h,
-         };
-         return GetByteBuffer(sizes);
-       }},
       {"constant float2& scale",
        [attr](const std::vector<BHWC>& src_shapes,
               const std::vector<BHWC>& dst_shapes) {

@@ -160,10 +146,10 @@ ComputeTaskDescriptor Resize(const OperationDef& definition,

   desc.resize_function = [](const std::vector<BHWC>& src_shapes,
                             const std::vector<BHWC>& dst_shapes) {
-    const uint3 groups_size{16, 16, 1};
+    const uint3 groups_size{8, 8, 1};
+    const int dst_layers = DivideRoundUp(dst_shapes[0].c, 4);
     int groups_x = DivideRoundUp(dst_shapes[0].w, groups_size.x);
     int groups_y = DivideRoundUp(dst_shapes[0].h, groups_size.y);
-    const int dst_layers = DivideRoundUp(dst_shapes[0].c, 4);
     int groups_z = DivideRoundUp(dst_layers, groups_size.z);
     return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
   };
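The bilinear kernel is rewritten to read texels through `args.src_tensor.Read(...)`, but the blend itself is unchanged: `mix(mix(tex11, tex21, t.x), mix(tex12, tex22, t.x), t.y)`. A standalone scalar sketch of that blend (plain floats stand in for the FLT4 texels):

```cpp
#include <iostream>

// Same semantics as Metal's mix(a, b, t).
float mixf(float a, float b, float t) { return a + (b - a) * t; }

int main() {
  float tex11 = 0.f, tex21 = 1.f, tex12 = 2.f, tex22 = 3.f;
  float tx = 0.25f, ty = 0.5f;  // interpolating factors
  float value = mixf(mixf(tex11, tex21, tx), mixf(tex12, tex22, tx), ty);
  std::cout << value << "\n";  // 1.25
  return 0;
}
```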
@@ -35,123 +35,146 @@ namespace gpu {
 namespace metal {
 namespace {

-std::string GetSliceCode(const SliceAttributes& attr) {
-  std::stringstream code;
-
-  code << R"(
-#include <metal_stdlib>
-using namespace metal;
-
-struct uniforms {
-  int4 src_size;
-  int4 dst_size;
-};
-
-constant int4 width = int4($0, $1, $2, 0);
-constant int4 height = int4($3, $4, $5, 0);
-constant int4 channels = int4($6, $7, $8, 0);
-constant FLT4 null_vec = FLT4(0.0f, 0.0f, 0.0f, 0.0f);
-
-$$0
-kernel void ComputeFunction(
-                            $$1
-                            uint3 gid[[thread_position_in_grid]]) {
-  if (static_cast<int>(gid.x) >= params.dst_size.x ||
-      static_cast<int>(gid.y) >= params.dst_size.y) {
-    return;
-  }
-
-  FLT4 value;
-  short2 offset;
-)";
-  if (attr.strides.w > 0) {
-    code << "  offset.x = width.x;" << std::endl;
-  } else {
-    if (attr.ends.w > 0) {
-      code << "  offset.x = width.z;" << std::endl;
-    } else {
-      code << "  offset.x = params.src_size.x + width.z;" << std::endl;
-    }
-  }
-  if (attr.strides.h > 0) {
-    code << "  offset.y = height.x;" << std::endl;
-  } else {
-    if (attr.ends.h > 0) {
-      code << "  offset.y = height.z;" << std::endl;
-    } else {
-      code << "  offset.y = params.src_size.y + height.z;" << std::endl;
-    }
-  }
-  code << std::endl;
-  code << "  short2 stride = short2(width.y, height.y);" << std::endl;
-  code << "  const short2 s_c = offset + short2(gid.xy) * stride;"
-       << std::endl;
-  code << "  bool outside = false;" << std::endl;
-  code << "  int step = gid.z * 4;" << std::endl;
-  code << "  FLT4 tmp;" << std::endl;
-  code << "  int buffer_index = 0;" << std::endl;
-  code << "  int addr = 0;" << std::endl;
-  code << std::endl;
-  for (int i = 0; i < 4; i++) {
-    code << "  addr = step * channels.y;" << std::endl;
-    if (attr.strides.c > 0) {
-      code << "  addr += channels.x;" << std::endl;
-    } else {
-      if (attr.ends.c > 0) {
-        code << "  addr += channels.z;" << std::endl;
-      } else {
-        code << "  addr += params.src_size.z + channels.z;" << std::endl;
-      }
-    }
-    code << "  buffer_index = ((addr / 4) * params.src_size.y + s_c.y) * "
-            "params.src_size.x + "
-            "s_c.x;"
-         << std::endl;
-    code << "  outside = step >= params.dst_size.z;" << std::endl;
-    code << "  tmp = outside ? null_vec : src_tensor[buffer_index];"
-         << std::endl;
-    code << "  value[" << i << "] = tmp[addr % 4];" << std::endl;
-    if (i != 3) {
-      code << "  step++;" << std::endl;
-      code << std::endl;
-    }
-  }
-  code << R"(
-  int linear_index = (gid.z * params.dst_size.y + int(gid.y)) *
-                     params.dst_size.x + int(gid.x);
-  $$2
-  dst_tensor[linear_index] = value;
-})";
-  return absl::Substitute(
-      code.str(), attr.starts.w, attr.strides.w, attr.ends.w, attr.starts.h,
-      attr.strides.h, attr.ends.h, attr.starts.c, attr.strides.c, attr.ends.c);
+namespace {
+bool Is4Aligned(const SliceAttributes& attr) {
+  return attr.strides.c == 1 && attr.starts.c % 4 == 0;
+}
+
+int4 GetOffset(const SliceAttributes& attr, int src_width, int src_height,
+               int src_channels, int src_batch) {
+  int4 offset;
+  if (attr.strides.w > 0) {
+    offset.x = attr.starts.w;
+  } else {
+    if (attr.ends.w > 0) {
+      offset.x = attr.ends.w;
+    } else {
+      offset.x = src_width + attr.ends.w;
+    }
+  }
+  if (attr.strides.h > 0) {
+    offset.y = attr.starts.h;
+  } else {
+    if (attr.ends.h > 0) {
+      offset.y = attr.ends.h;
+    } else {
+      offset.y = src_height + attr.ends.h;
+    }
+  }
+  if (attr.strides.c > 0) {
+    offset.z = attr.starts.c;
+  } else {
+    if (attr.ends.c > 0) {
+      offset.z = attr.ends.c;
+    } else {
+      offset.z = src_channels + attr.ends.c;
+    }
+  }
+  if (Is4Aligned(attr)) {
+    offset.z /= 4;
+  }
+  if (attr.strides.b > 0) {
+    offset.w = attr.starts.b;
+  } else {
+    if (attr.ends.b > 0) {
+      offset.w = attr.ends.b;
+    } else {
+      offset.w = src_batch + attr.ends.b;
+    }
+  }
+  return offset;
+}
+
+}  // namespace
+
+std::string GetSliceCode(const OperationDef& op_def, bool alignedx4) {
+  const std::string batch_id =
+      op_def.dst_tensors[0].HasAxis(Axis::BATCH) ? "B" : "0";
+  std::string c = R"(
+#include <metal_stdlib>
+using namespace metal;
+
+struct uniforms {
+  int4 offset;
+  int4 stride;
+};
+
+$0
+kernel void ComputeFunction($1
+                            uint3 gid[[thread_position_in_grid]]) {
+)";
+  if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
+    c += "  int linear_id = static_cast<int>(gid.x);\n";
+    c += "  int X = linear_id / args.dst_tensor.Batch();\n";
+    c += "  int B = linear_id % args.dst_tensor.Batch();\n";
+    c += "  args.dst_tensor.SetBatchRef(B);\n";
+  } else {
+    c += "  int X = static_cast<int>(gid.x);\n";
+  }
+  c += "  int Y = static_cast<int>(gid.y);\n";
+  c += "  int Z = static_cast<int>(gid.z);\n";
+  c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
+       "Z >= args.dst_tensor.Slices()) { \n";
+  c += "    return; \n";
+  c += "  } \n";
+  c += "  int s_x = X * params.stride.x + params.offset.x;\n";
+  c += "  int s_y = Y * params.stride.y + params.offset.y;\n";
+  if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) {
+    c += "  int s_b = " + batch_id + " * params.stride.w + params.offset.w;\n";
+    c += "  args.src_tensor.SetBatchRef(s_b);\n";
+  }
+  if (alignedx4) {
+    c += "  int s_z = Z + params.offset.z;\n";
+    c += "  FLT4 result = args.src_tensor.Read(s_x, s_y, s_z);\n";
+  } else {
+    c += "  FLT4 result;\n";
+    const std::string postfixes[] = {"x", "y", "z", "w"};
+    for (int i = 0; i < 4; ++i) {
+      c += "  {\n";
+      const std::string ch = "(Z * 4 + " + std::to_string(i) + ")";
+      c += "    int s_ch = " + ch + " * params.stride.z + params.offset.z;\n";
+      c += "    int s_z = min(s_ch >> 2, args.src_tensor.Slices() - 1);\n";
+      c += "    int s_z_rem = s_ch & 3;\n";
+      c += "    FLT4 t = args.src_tensor.Read(s_x, s_y, s_z);\n";
+      c += "    result." + postfixes[i] + " = t[s_ch & 3];\n";
+      c += "  }\n";
+    }
+  }
+  c += "  FLT4 value = result;\n";
+  c += "  args.dst_tensor.GetAddress(linear_index, X, Y, Z);\n";
+  c += "  $2\n";
+  c += "  args.dst_tensor.Write(value, X, Y, Z);\n";
+  c += "}\n";
+  return c;
 }
 }  // namespace

 ComputeTaskDescriptor Slice(const OperationDef& definition,
                             const SliceAttributes& attr) {
   ComputeTaskDescriptor desc(definition);
-  desc.shader_source = GetSliceCode(attr);
+  desc.tensors_as_args = true;
+  desc.shader_source = GetSliceCode(definition, Is4Aligned(attr));

   desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
   desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);

   desc.uniform_buffers = {
       {"constant uniforms& params",
-       [](const std::vector<BHWC>& src_shapes,
+       [attr](const std::vector<BHWC>& src_shapes,
           const std::vector<BHWC>& dst_shapes) {
+         int4 offset = GetOffset(attr, src_shapes[0].w, src_shapes[0].h,
+                                 src_shapes[0].c, src_shapes[0].b);
          std::vector<int> uniform_params{
-             // int4 src_size
-             src_shapes[0].w,
-             src_shapes[0].h,
-             src_shapes[0].c,
-             DivideRoundUp(src_shapes[0].c, 4),
-             // int4 dst_size
-             dst_shapes[0].w,
-             dst_shapes[0].h,
-             dst_shapes[0].c,
-             DivideRoundUp(dst_shapes[0].c, 4),
+             // int4 offset
+             offset.x,
+             offset.y,
+             offset.z,
+             offset.w,
+             // int4 stride
+             attr.strides.w,
+             attr.strides.h,
+             attr.strides.c,
+             attr.strides.b,
          };
          return GetByteBuffer(uniform_params);
        }},

@@ -159,10 +182,10 @@ ComputeTaskDescriptor Slice(const OperationDef& definition,

   desc.resize_function = [attr](const std::vector<BHWC>& src_shapes,
                                 const std::vector<BHWC>& dst_shapes) {
-    const uint3 groups_size{16, 16, 1};
+    const uint3 groups_size{8, 4, 1};
+    const int dst_layers = DivideRoundUp(dst_shapes[0].c, 4);
     int groups_x = DivideRoundUp(dst_shapes[0].w, groups_size.x);
     int groups_y = DivideRoundUp(dst_shapes[0].h, groups_size.y);
-    const int dst_layers = DivideRoundUp(dst_shapes[0].c, 4);
     int groups_z = DivideRoundUp(dst_layers, groups_size.z);
     return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
   };
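The rewritten Slice moves the start/end resolution out of the generated shader and into the host-side `GetOffset` helper: for a positive stride the start index is used directly, while for a negative stride the start comes from `ends`, with non-positive `ends` interpreted relative to the source extent. A standalone sketch of that per-axis rule (illustrative values, not the TFLite attribute types):

```cpp
#include <iostream>

// Mirrors the per-axis branch in GetOffset above.
int GetAxisOffset(int stride, int start, int end, int src_extent) {
  if (stride > 0) return start;
  return end > 0 ? end : src_extent + end;
}

int main() {
  std::cout << GetAxisOffset(/*stride=*/1, /*start=*/2, /*end=*/5, 10) << "\n";    // 2
  std::cout << GetAxisOffset(/*stride=*/-1, /*start=*/2, /*end=*/-1, 10) << "\n";  // 9
  return 0;
}
```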
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
|
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/task/util.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
||||||
@ -40,7 +41,6 @@ std::string GetSoftmax1x1Code(const GpuInfo& gpu_info) {
|
|||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
struct uniforms {
|
struct uniforms {
|
||||||
int4 size;
|
|
||||||
float4 mask;
|
float4 mask;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -51,11 +51,11 @@ kernel void ComputeFunction($1
                             uint3 ugid[[thread_position_in_grid]])
 {

-  float4 maxx4 = float4(src_tensor[0].x);
-  for (int s = int(tid); s < params.size.x; s += 32) {
-    float4 mask_a = s == params.size.x - 1 ? params.mask : float4(1.0f);
+  float4 maxx4 = float4(args.src_tensor.Read(0, 0, 0).x);
+  for (int s = int(tid); s < args.src_tensor.Slices(); s += 32) {
+    float4 mask_a = s == args.src_tensor.Slices() - 1 ? params.mask : float4(1.0f);
     float4 mask_b = float4(1.0f) - mask_a;
-    float4 src = float4(src_tensor[s]);
+    float4 src = float4(args.src_tensor.Read(0, 0, s));
     src = src * mask_a + mask_b * src.x;
     maxx4 = max(maxx4, src);
   }

@@ -89,9 +89,9 @@ kernel void ComputeFunction($1
   maximum = tmpx1[0];

   float sum = 0.0f;
-  for (int s = int(tid); s < params.size.x; s += 32) {
-    float4 mask_temp = s == params.size.x - 1 ? params.mask : float4(1.0f);
-    float4 src = float4(src_tensor[s]) - float4(maximum);
+  for (int s = int(tid); s < args.src_tensor.Slices(); s += 32) {
+    float4 mask_temp = s == args.src_tensor.Slices() - 1 ? params.mask : float4(1.0f);
+    float4 src = float4(args.src_tensor.Read(0, 0, s)) - float4(maximum);
     sum += dot(mask_temp, exp(src));
   }

@@ -120,13 +120,13 @@ kernel void ComputeFunction($1
   sum = tmpx1[0];

   int dst_s = int(ugid.x);
-  if (dst_s < params.size.x) {
-    int linear_index = dst_s;
-    float4 src = float4(src_tensor[linear_index]) - float4(maximum);
+  if (dst_s < args.src_tensor.Slices()) {
+    float4 src = float4(args.src_tensor.Read(0, 0, dst_s)) - float4(maximum);
     FLT4 value = FLT4(exp(src) * sum);
-    uint3 gid = uint3(0, 0, linear_index);
+    uint3 gid = uint3(0, 0, dst_s);
+    args.dst_tensor.GetAddress(linear_index, 0, 0, dst_s);
     $2
-    dst_tensor[linear_index] = value;
+    args.dst_tensor.Write(value, 0, 0, dst_s);
   }
 })";
 return code;
@@ -135,28 +135,27 @@ kernel void ComputeFunction($1

 ComputeTaskDescriptor Softmax(const OperationDef& definition) {
   ComputeTaskDescriptor desc(definition);
+  desc.tensors_as_args = true;
   desc.shader_source = R"(
 #include <metal_stdlib>
 using namespace metal;

 struct uniforms {
-  int4 size;
   float4 mask;
 };
 $0
 kernel void ComputeFunction(
                             $1
                             uint3 gid[[thread_position_in_grid]]) {
-  if (int(gid.x) >= params.size.x || int(gid.y) >= params.size.y) {
+  if (int(gid.x) >= args.dst_tensor.Width() || int(gid.y) >= args.dst_tensor.Height()) {
     return;
   }

-  float maximum = src_tensor[gid.y * params.size.x + gid.x].x;
-  for (int d = 0; d < params.size.z; ++d) {
-    int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
-    float4 mask_a = d == params.size.z - 1 ? params.mask : float4(1.0f);
+  float maximum = args.src_tensor.Read(gid.x, gid.y, 0).x;
+  for (int d = 0; d < args.dst_tensor.Slices(); ++d) {
+    float4 mask_a = d == args.dst_tensor.Slices() - 1 ? params.mask : float4(1.0f);
     float4 mask_b = float4(1.0f) - mask_a;
-    float4 src = float4(src_tensor[buffer_index]);
+    float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d));
     src = src * mask_a + mask_b * src.x;
     maximum = max(maximum, src.x);
     maximum = max(maximum, src.y);

@@ -165,19 +164,18 @@ kernel void ComputeFunction(
   }

   float sum = 0.0f;
-  for (int d = 0; d < params.size.z; ++d) {
-    int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
-    float4 mask_temp = d == params.size.z - 1 ? params.mask : float4(1.0f);
-    float4 src = float4(src_tensor[buffer_index]) - float4(maximum);
+  for (int d = 0; d < args.dst_tensor.Slices(); ++d) {
+    float4 mask_temp = d == args.dst_tensor.Slices() - 1 ? params.mask : float4(1.0f);
+    float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d)) - float4(maximum);
     sum += dot(mask_temp, exp(src));
   }

-  for (int d = 0; d < params.size.z; ++d) {
-    const int linear_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
-    float4 src = float4(src_tensor[linear_index]) - float4(maximum);
+  for (int d = 0; d < args.dst_tensor.Slices(); ++d) {
+    float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d)) - float4(maximum);
     FLT4 value = FLT4(exp(src) / sum);
+    args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, d);
     $2
-    dst_tensor[linear_index] = value;
+    args.dst_tensor.Write(value, gid.x, gid.y, d);
   }
 }
 )";
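As an aside, both shaders above implement the standard numerically stable softmax (subtract the running maximum, exponentiate, normalize by the sum). A minimal NumPy sketch of the same math, for illustration only and not part of this change:

```python
# Illustrative only: the two-pass (max, then sum) softmax the Metal shaders compute.
import numpy as np

def stable_softmax(x, axis=-1):
    # Subtracting the maximum first keeps exp() from overflowing.
    maximum = np.max(x, axis=axis, keepdims=True)
    e = np.exp(x - maximum)
    return e / np.sum(e, axis=axis, keepdims=True)

print(stable_softmax(np.array([1.0, 2.0, 3.0])))  # ~[0.09, 0.24, 0.67]
```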
@@ -189,20 +187,9 @@ kernel void ComputeFunction(
       {"constant uniforms& params",
        [](const std::vector<BHWC>& src_shapes,
           const std::vector<BHWC>& dst_shapes) {
-         const int dst_depth = DivideRoundUp(dst_shapes[0].c, 4);
-         struct uniforms {
-           int4 size;
-           float4 mask;
-         };
-         uniforms params;
-         params.size = {dst_shapes[0].w, dst_shapes[0].h, dst_depth, 1};
-         params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
-         int reminder = dst_shapes[0].c % 4 == 0 ? 4 : dst_shapes[0].c % 4;
-         for (int i = 0; i < reminder; ++i) {
-           params.mask[i] = 1.0f;
-         }
-         const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&params);
-         return std::vector<uint8_t>(ptr, ptr + sizeof(uniforms));
+         float4 mask = GetMaskForLastPlane(dst_shapes[0].c);
+         const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&mask);
+         return std::vector<uint8_t>(ptr, ptr + sizeof(float4));
        }},
   };

@@ -220,6 +207,7 @@ kernel void ComputeFunction(
 ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
                                  const GpuInfo& gpu_info) {
   ComputeTaskDescriptor desc(definition);
+  desc.tensors_as_args = true;
   desc.shader_source = GetSoftmax1x1Code(gpu_info);

   desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
@@ -229,20 +217,9 @@ ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
       {"constant uniforms& params",
        [](const std::vector<BHWC>& src_shapes,
           const std::vector<BHWC>& dst_shapes) {
-         const int src_depth = DivideRoundUp(dst_shapes[0].c, 4);
-         struct uniforms {
-           int4 size;
-           float4 mask;
-         };
-         uniforms params;
-         params.size = {src_depth, DivideRoundUp(src_depth, 32), 1, 1};
-         params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
-         int reminder = dst_shapes[0].c % 4 == 0 ? 4 : dst_shapes[0].c % 4;
-         for (int i = 0; i < reminder; ++i) {
-           params.mask[i] = 1.0f;
-         }
-         const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&params);
-         return std::vector<uint8_t>(ptr, ptr + sizeof(uniforms));
+         float4 mask = GetMaskForLastPlane(dst_shapes[0].c);
+         const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&mask);
+         return std::vector<uint8_t>(ptr, ptr + sizeof(float4));
        }},
   };

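Both uniform writers above now hand the shader only a channel mask, computed by the shared `GetMaskForLastPlane` helper instead of the hand-rolled `reminder` loop that was deleted. For illustration, here is a sketch of the same idea in Python (not part of the change): with channels packed into groups of four, the final group may contain padding lanes that must not contribute to the max/sum reductions.

```python
# Hypothetical sketch of the "mask for the last 4-channel plane" computation.
def mask_for_last_plane(channels):
    remainder = channels % 4 or 4  # lanes in the last group that hold real data
    return [1.0 if i < remainder else 0.0 for i in range(4)]

print(mask_for_last_plane(6))  # [1.0, 1.0, 0.0, 0.0]
print(mask_for_last_plane(8))  # [1.0, 1.0, 1.0, 1.0]
```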
@@ -124,6 +124,8 @@ absl::Status MetalSpatialTensor::GetGPUResources(
   if (!tensor_desc) {
     return absl::InvalidArgumentError("Expected TensorDescriptor on input.");
   }
+  resources->ints.push_back(
+      {"slice_stride", tensor_desc->GetSliceStrideSize(shape_)});
   if (descriptor_.HasAxis(Axis::WIDTH)) {
     resources->ints.push_back({"width", Width()});
     resources->ints.push_back({"width_div2", Width() / 2});
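The new `slice_stride` resource is assumed here to be the element distance between consecutive depth slices of the packed buffer; the exact value comes from `GetSliceStrideSize()` and may include batch or alignment handling not shown in this hunk. A deliberately simplified sketch of that idea:

```python
# Hypothetical illustration only: if each WxH spatial plane of a 4-channel
# slice is stored contiguously, consecutive slices are width * height
# elements apart. GetSliceStrideSize() may use a different formula.
def slice_stride(width, height):
    return width * height

print(slice_stride(16, 8))  # 128 elements between depth slices
```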
@@ -310,7 +310,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Downloading data from https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media\u0026token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8\n",
+      "Downloading data from https://dl.fbaipublicfiles.com/glue/data/SST-2.zip\n",
       "7446528/7439277 [==============================] - 0s 0us/step\n"
      ]
     }

@@ -318,7 +318,7 @@
     "source": [
      "data_dir = tf.keras.utils.get_file(\n",
      "      fname='SST-2.zip',\n",
-     "      origin='https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media\u0026token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',\n",
+     "      origin='https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',\n",
      "      extract=True)\n",
      "data_dir = os.path.join(os.path.dirname(data_dir), 'SST-2')"
     ]
@@ -587,11 +587,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
       status = kTfLiteError;
     }

-    size_t dims_signature_rank = 0;
-    const int* dims_signature_data = nullptr;
+    std::vector<int> dims_signature = {};
     if (tensor->shape_signature()) {
-      dims_signature_rank = tensor->shape_signature()->size();
-      dims_signature_data = tensor->shape_signature()->data();
+      dims_signature = FlatBufferIntArrayToVector(tensor->shape_signature());
     }

     bool is_variable = tensor->is_variable();

@@ -623,7 +621,7 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
     } else {
       if (subgraph->SetTensorParametersReadWrite(
               i, type, get_name(tensor), dims, quantization, is_variable,
-              dims_signature_rank, dims_signature_data) != kTfLiteOk) {
+              dims_signature) != kTfLiteOk) {
         error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
                                 i);
         status = kTfLiteError;
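For context, the shape signature being forwarded here is assumed to be the usual TFLite convention: it mirrors the tensor shape but marks dynamic dimensions with -1, while `dims` holds the concrete shape of the currently allocated tensor. A short illustrative sketch:

```python
# Sketch (assumption for illustration): TFLite shape signatures mark dynamic
# dimensions with -1; the concrete shape is filled in at allocation time.
shape = [1, 224, 224, 3]            # concrete shape after allocation
shape_signature = [-1, 224, 224, 3]  # batch dimension is dynamic

def has_dynamic_dims(signature):
    return any(d == -1 for d in signature)

print(has_dynamic_dims(shape_signature))  # True
```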
@@ -80,6 +80,9 @@ TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
     case kTfLiteInt32:
       copyCast(in, out->data.i32, num_elements);
       break;
+    case kTfLiteInt16:
+      copyCast(in, out->data.i16, num_elements);
+      break;
     case kTfLiteUInt8:
       copyCast(in, out->data.uint8, num_elements);
       break;

@@ -113,6 +116,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       return copyToTensor(context, input->data.i64, output, num_elements);
     case kTfLiteInt32:
       return copyToTensor(context, input->data.i32, output, num_elements);
+    case kTfLiteInt16:
+      return copyToTensor(context, input->data.i16, output, num_elements);
     case kTfLiteUInt8:
       return copyToTensor(context, input->data.uint8, output, num_elements);
     case kTfLiteFloat32:
@@ -46,6 +46,22 @@ class CastOpModel : public SingleOpModel {
   int output_;
 };

+TEST(CastOpModel, CastInt16ToFloat) {
+  CastOpModel m({TensorType_INT16, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
+  m.PopulateTensor<int16_t>(m.input(), {100, 200, 300, 400, 500, 600});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<float>(m.output()),
+              ElementsAreArray({100.f, 200.f, 300.f, 400.f, 500.f, 600.f}));
+}
+
+TEST(CastOpModel, CastInt16ToInt32) {
+  CastOpModel m({TensorType_INT16, {2, 3}}, {TensorType_INT32, {2, 3}});
+  m.PopulateTensor<int16_t>(m.input(), {100, 200, 300, 400, 500, 600});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<int32_t>(m.output()),
+              ElementsAreArray({100, 200, 300, 400, 500, 600}));
+}
+
 TEST(CastOpModel, CastInt32ToFloat) {
   CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
   m.PopulateTensor<int32_t>(m.input(), {100, 200, 300, 400, 500, 600});

@@ -62,6 +78,14 @@ TEST(CastOpModel, CastFloatToInt32) {
               ElementsAreArray({100, 20, 3, 0, 0, 1}));
 }

+TEST(CastOpModel, CastFloatToInt16) {
+  CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT16, {3, 2}});
+  m.PopulateTensor<float>(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<int16_t>(m.output()),
+              ElementsAreArray({100, 20, 3, 0, 0, 1}));
+}
+
 TEST(CastOpModel, CastInt64ToFloat) {
   CastOpModel m({TensorType_INT64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
   m.PopulateTensor<int64_t>(m.input(), {100, 200, 300, 400, 500, 600});
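The expected values in the new float-to-int16 test follow from plain truncation toward zero, which is what a `static_cast` in `copyCast` produces. A quick NumPy check of the same semantics (illustrative only):

```python
# Illustrative: float -> int16 conversion truncates toward zero,
# matching the values asserted in CastFloatToInt16 above.
import numpy as np

src = np.array([100.0, 20.0, 3.0, 0.4, 0.999, 1.1], dtype=np.float32)
print(src.astype(np.int16))  # [100  20   3   0   0   1]
```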
@@ -121,7 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {

   // TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should
   // allocate and populate these during Prepare().
-  // TODO(ycling): Activation function parameter is ignored. For now we dont have
+  // TODO(ycling): Activation function parameter is ignored. For now we don't have
   // a model with a Concatenation with fused activation function.
 #define TF_LITE_CONCATENATION(scalar)                                         \
   {                                                                           \
@@ -494,6 +494,7 @@ cc_library(
         "reference/resize_nearest_neighbor.h",
         "reference/round.h",
         "reference/softmax.h",
+        "reference/space_to_depth.h",
        "reference/strided_slice.h",
         "reference/sub.h",
         "reference/svdf.h",

@@ -511,13 +512,14 @@ cc_library(
     }),
     compatible_with = get_compatible_with_portable(),
     copts = tflite_copts(),
-    # We are disabling parse_headers for the tf_lite_static_memory build to
-    # allow it to be consistent with the OSS bazel build. See b/175817116
-    # for more details.
-    features = select({
-        ":tf_lite_static_memory": ["-parse_headers"],
-        "//conditions:default": [],
-    }),
+    # We are disabling parse_headers for this header-only target so that the
+    # external and internal builds are consistent. The primary issue here is
+    # that parse_headers is not supported with bazel and the TFLM team would
+    # really like to have all build errors in shared Micro/Lite code be
+    # reproducible from the OSS build as well.
+    #
+    # See b/175817116 for more details.
+    features = ["-parse_headers"],
     deps = [
         ":common",
         ":compatibility",

@@ -588,6 +590,7 @@ cc_library(
         "reference/resize_nearest_neighbor.h",
         "reference/round.h",
         "reference/softmax.h",
+        "reference/space_to_depth.h",
        "reference/strided_slice.h",
         "reference/string_comparisons.h",
         "reference/sub.h",
@@ -15,16 +15,13 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV_H_
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV_H_

-#include "tensorflow/lite/kernels/internal/types.h"
 #include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/types.h"

 namespace tflite {

 namespace reference_ops {

 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
                  const float* input_data, const RuntimeShape& filter_shape,
                  const float* filter_data, const RuntimeShape& bias_shape,

@@ -108,8 +105,8 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
                  uint8_t* output_data, const RuntimeShape& im2col_shape,
                  uint8_t* im2col_data, void* cpu_backend_context) {
   (void)cpu_backend_context;  // only used in optimized code.
   (void)im2col_data;   // only used in optimized code.
   (void)im2col_shape;  // only used in optimized code.
   const int stride_width = params.stride_width;
   const int stride_height = params.stride_height;
   const int dilation_width_factor = params.dilation_width_factor;
@@ -61,6 +61,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
 #include "tensorflow/lite/kernels/internal/reference/round.h"
 #include "tensorflow/lite/kernels/internal/reference/softmax.h"
+#include "tensorflow/lite/kernels/internal/reference/space_to_depth.h"
 #include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
 #include "tensorflow/lite/kernels/internal/reference/string_comparisons.h"
 #include "tensorflow/lite/kernels/internal/reference/sub.h"
@@ -126,58 +127,6 @@ inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
   }
 }

-template <typename T>
-inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
-                         const RuntimeShape& unextended_input_shape,
-                         const T* input_data,
-                         const RuntimeShape& unextended_output_shape,
-                         T* output_data) {
-  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  const RuntimeShape input_shape =
-      RuntimeShape::ExtendedShape(4, unextended_input_shape);
-  const RuntimeShape output_shape =
-      RuntimeShape::ExtendedShape(4, unextended_output_shape);
-
-  const int input_depth = input_shape.Dims(3);
-  const int input_width = input_shape.Dims(2);
-  const int input_height = input_shape.Dims(1);
-  const int input_batch = input_shape.Dims(0);
-
-  const int output_depth = output_shape.Dims(3);
-  const int output_width = output_shape.Dims(2);
-  const int output_height = output_shape.Dims(1);
-  const int output_batch = output_shape.Dims(0);
-
-  const int32 block_size = op_params.block_size;
-
-  TFLITE_DCHECK_EQ(input_width, output_width * block_size);
-  TFLITE_DCHECK_EQ(input_height, output_height * block_size);
-  TFLITE_DCHECK_EQ(input_depth * block_size * block_size, output_depth);
-  TFLITE_DCHECK_EQ(input_batch, output_batch);
-
-  for (int in_b = 0; in_b < input_batch; ++in_b) {
-    for (int in_h = 0; in_h < input_height; ++in_h) {
-      for (int in_w = 0; in_w < input_width; ++in_w) {
-        for (int in_d = 0; in_d < input_depth; ++in_d) {
-          const int out_d =
-              in_d + ((in_h % block_size) * block_size + in_w % block_size) *
-                         input_depth;
-          const int out_w = in_w / block_size;
-          const int out_h = in_h / block_size;
-          const int out_b = in_b;
-
-          const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d);
-          const int output_index =
-              Offset(output_shape, out_b, out_h, out_w, out_d);
-
-          output_data[output_index] = input_data[input_index];
-        }
-      }
-    }
-  }
-}
-
 inline void Elu(const RuntimeShape& input_shape, const float* input_data,
                 const RuntimeShape& output_shape, float* output_data) {
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
tensorflow/lite/kernels/internal/reference/space_to_depth.h (new file, 78 lines)
@@ -0,0 +1,78 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPACE_TO_DEPTH_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPACE_TO_DEPTH_H_
+
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+template <typename T>
+inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
+                         const RuntimeShape& unextended_input_shape,
+                         const T* input_data,
+                         const RuntimeShape& unextended_output_shape,
+                         T* output_data) {
+  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+  const RuntimeShape input_shape =
+      RuntimeShape::ExtendedShape(4, unextended_input_shape);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+  const int input_depth = input_shape.Dims(3);
+  const int input_width = input_shape.Dims(2);
+  const int input_height = input_shape.Dims(1);
+  const int input_batch = input_shape.Dims(0);
+
+  const int output_depth = output_shape.Dims(3);
+  const int output_width = output_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_batch = output_shape.Dims(0);
+
+  const int32 block_size = op_params.block_size;
+
+  TFLITE_DCHECK_EQ(input_width, output_width * block_size);
+  TFLITE_DCHECK_EQ(input_height, output_height * block_size);
+  TFLITE_DCHECK_EQ(input_depth * block_size * block_size, output_depth);
+  TFLITE_DCHECK_EQ(input_batch, output_batch);
+
+  for (int in_b = 0; in_b < input_batch; ++in_b) {
+    for (int in_h = 0; in_h < input_height; ++in_h) {
+      for (int in_w = 0; in_w < input_width; ++in_w) {
+        for (int in_d = 0; in_d < input_depth; ++in_d) {
+          const int out_d =
+              in_d + ((in_h % block_size) * block_size + in_w % block_size) *
+                         input_depth;
+          const int out_w = in_w / block_size;
+          const int out_h = in_h / block_size;
+          const int out_b = in_b;
+
+          const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d);
+          const int output_index =
+              Offset(output_shape, out_b, out_h, out_w, out_d);
+
+          output_data[output_index] = input_data[input_index];
+        }
+      }
+    }
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPACE_TO_DEPTH_H_
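The reference `SpaceToDepth` above folds every `block_size x block_size` spatial patch into the channel dimension, with the output channel index given by `out_d = in_d + ((in_h % block) * block + in_w % block) * input_depth`. An illustrative NumPy sketch of the same index mapping (not part of the header):

```python
# Illustrative only: NumPy restatement of the SpaceToDepth mapping above.
import numpy as np

def space_to_depth(x, block_size):
    b, h, w, c = x.shape
    # Split H and W into (blocks, block_size), then move the two block_size
    # axes next to the channel axis, matching the out_d formula in the C++.
    y = x.reshape(b, h // block_size, block_size, w // block_size, block_size, c)
    y = y.transpose(0, 1, 3, 2, 4, 5)
    return y.reshape(b, h // block_size, w // block_size,
                     c * block_size * block_size)

x = np.arange(16).reshape(1, 4, 4, 1)
print(space_to_depth(x, 2).shape)  # (1, 2, 2, 4)
```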
@@ -170,7 +170,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
       intermediate_zp.push_back(0);
     }
   }
-  // In the absense of projection, hidden becomes otuput and this intermediate
+  // In the absence of projection, hidden becomes otuput and this intermediate
   // is ignored.
   TfLiteTensor* hidden;
   TF_LITE_ENSURE_OK(context, GetIntermediatesSafe(context, node, 4, &hidden));
@@ -204,7 +204,7 @@ to determine if the requested feature aligns with the TFLM roadmap.
 1. Run all the tests for x86, and any other platform that you are modifying.

    ```
-   tensorflow/lite/micro/tools/make/tools/ci_build/test_x86.sh
+   tensorflow/lite/micro/tools/ci_build/test_x86.sh
    ```

    Please check the READMEs in the optimized kernel directories for specific
@@ -1,2 +1,2 @@
 numpy==1.16.2
-tensorflow==2.0.0-beta1
+tensorflow==2.4.0
@@ -33,7 +33,7 @@ set +e
 # The pigweed scripts only work from a git repository and the Tensorflow CI
 # infrastructure does not always guarantee that. As an ugly workaround, we
 # create our own git repo when running on the CI servers.
-pushd tensorflow/lite/micro/
+pushd tensorflow/lite/
 if [[ ${1} == "PRESUBMIT" ]]; then
   git init .
   git config user.email "tflm@google.com"

@@ -43,9 +43,12 @@ if [[ ${1} == "PRESUBMIT" ]]; then
 fi

 # Check for license with the necessary exclusions.
-tools/make/downloads/pigweed/pw_presubmit/py/pw_presubmit/pigweed_presubmit.py \
-  . \
+micro/tools/make/downloads/pigweed/pw_presubmit/py/pw_presubmit/pigweed_presubmit.py \
+  kernels/internal/reference/ \
+  micro/ \
   -p copyright_notice \
+  -e kernels/internal/reference/integer_ops/ \
+  -e kernels/internal/reference/reference_ops.h \
   -e tools/make/downloads \
   -e tools/make/targets/ecm3531 \
   -e BUILD\

@@ -66,8 +69,11 @@ LICENSE_CHECK_RESULT=$?
 # Python files (with yapf as the formatter) because that needs additional setup.
 # We are also ignoring the markdown files to allow for a more gradual rollout of
 # this presubmit check.
-tools/make/downloads/pigweed/pw_presubmit/py/pw_presubmit/format_code.py \
-  . \
+micro/tools/make/downloads/pigweed/pw_presubmit/py/pw_presubmit/format_code.py \
+  kernels/internal/reference/ \
+  micro/ \
+  -e kernels/internal/reference/integer_ops/ \
+  -e kernels/internal/reference/reference_ops.h \
   -e "\.inc" \
   -e "\.md" \
   -e "\.py"

@@ -76,7 +82,7 @@ CLANG_FORMAT_RESULT=$?

 popd
 if [[ ${1} == "PRESUBMIT" ]]; then
-  rm -rf tensorflow/lite/micro/.git
+  rm -rf tensorflow/lite/.git
 fi

 # Re-enable exit on error now that we are done with the temporary git repo.
@@ -116,7 +116,7 @@ download_and_extract() {
   local tempdir=$(mktemp -d)
   local tempdir2=$(mktemp -d)
   local tempfile=${tempdir}/temp_file
-  local curl_retries=3
+  local curl_retries=5

   # Destionation already downloaded.
   if [ -d ${dir} ]; then

@@ -131,24 +131,21 @@ download_and_extract() {
   mkdir -p "${dir}"
   # We've been seeing occasional 56 errors from valid URLs, so set up a retry
   # loop to attempt to recover from them.
-  for (( i=1; i<=$curl_retries; ++i ))
-  do
+  for (( i=1; i<=$curl_retries; ++i )); do
     # We have to use this approach because we normally halt the script when
     # there's an error, and instead we want to catch errors so we can retry.
-    set +e
-    curl -Ls --fail --retry 5 "${url}" > ${tempfile}
+    set +ex
+    curl -LsS --fail --retry 5 "${url}" > ${tempfile}
     CURL_RESULT=$?
-    set -e
+    set -ex

     # Was the command successful? If so, continue.
-    if [[ $CURL_RESULT -eq 0 ]]
-    then
+    if [[ $CURL_RESULT -eq 0 ]]; then
       break
     fi

     # Keep trying if we see the '56' error code.
-    if [[ ( $CURL_RESULT -ne 56 ) || ( $i -eq $curl_retries ) ]]
-    then
+    if [[ ( $CURL_RESULT -ne 56 ) || ( $i -eq $curl_retries ) ]]; then
       echo "Error $CURL_RESULT downloading '${url}'"
       exit 1
     fi
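The retry policy above only retries when curl exits with code 56 (a network receive error) and gives up after `curl_retries` attempts. A rough Python analogue of the same policy, for illustration only (the helper name and behavior are assumptions, not part of the script):

```python
# Illustrative restatement of the download retry policy in download_and_extract().
import subprocess

def fetch_with_retries(url, dest, retries=5):
    for attempt in range(1, retries + 1):
        # Mirror the curl flags used by the script: follow redirects, fail on
        # HTTP errors, and let curl itself retry transient failures.
        result = subprocess.run(
            ["curl", "-LsS", "--fail", "--retry", "5", "-o", dest, url])
        if result.returncode == 0:
            return True
        if result.returncode != 56 or attempt == retries:
            raise RuntimeError(f"Error {result.returncode} downloading '{url}'")
    return False
```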
@@ -110,7 +110,7 @@ class OpsSet(enum.Enum):
       "EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8"

   def __str__(self):
-    return self.value
+    return str(self.value)

   @staticmethod
   def get_options():
@@ -394,22 +394,22 @@ def build_toco_convert_protos(input_tensors,
     input_tensors: List of input tensors. Type and shape are computed using
       `foo.shape` and `foo.dtype`.
     output_tensors: List of output tensors (only .name is used from this).
-    inference_type: Target data type of real-number arrays in the output file.
-      Must be `{tf.float32, tf.uint8, tf.int8}`. (default tf.float32)
-    inference_input_type: Target data type of real-number input arrays. Allows
-      for a different type for input arrays in the case of quantization. Must be
-      `{tf.float32, tf.uint8, tf.int8}`. (default `inference_type`)
-    input_format: Type of data to read Currently must be
-      `{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
-    input_shapes: Input array shape. It needs to be a list of the same length as
-      `input_tensors`, or None. (default None)
-    output_format: Output file format. Currently must be `{TFLITE,
-      GRAPHVIZ_DOT}`. (default TFLITE)
-    quantized_input_stats: List of tuples of floats representing the mean and
-      standard deviation. Each tuple maps to the corresponding input tensor.
-      Only need if `inference_input_type` is `QUANTIZED_UINT8` or `INT8`.
-      real_input_value = (quantized_input_value - mean_value) / std_dev_value.
-      (default None)
+    inference_type: Data type of numeric arrays, excluding the input layer.
+      (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
+    inference_input_type: Data type of the numeric arrays in the input layer. If
+      `inference_input_type` is in {tf.int8, tf.uint8}, then
+      `quantized_input_stats` must be provided. (default is the value assigned
+      to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8})
+    input_format: Type of data to read.
+      (default TENSORFLOW_GRAPHDEF, must be in {TENSORFLOW_GRAPHDEF})
+    input_shapes: Input array shape. (default None, must be None or a list of
+      the same length as `input_tensors`.)
+    output_format: Output file format. (default TFLITE, must be in
+      {TFLITE, GRAPHVIZ_DOT})
+    quantized_input_stats: Map of input tensor names to a tuple of floats
+      representing the mean and standard deviation of the training data.
+      (e.g., {"foo" : (0., 1.)}). Required if `inference_input_type` is tf.int8
+      or tf.uint8. (default None)
     default_ranges_stats: Tuple of integers representing (min, max) range values
       for all arrays without a specified range. Intended for experimenting with
       quantization via "dummy quantization". (default None)
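The mean/std pair in `quantized_input_stats` follows the relationship `real_input_value = (quantized_input_value - mean) / std_dev`, as the old docstring spelled out. A worked example (illustrative values only):

```python
# Illustrative: mapping a uint8 quantized input back to real values with the
# quantized_input_stats convention real = (quantized - mean) / std_dev.
mean, std_dev = 127.5, 127.5  # maps uint8 [0, 255] to roughly [-1, 1]

def dequantize(q):
    return (q - mean) / std_dev

print(dequantize(0), dequantize(255))  # -1.0 1.0
```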
@@ -574,8 +574,10 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
     if _requires_input_stats(toco_flags):
       if (("quantized_input_stats" not in kwargs) or
           (not kwargs["quantized_input_stats"])):
-        raise ValueError("std_dev and mean must be defined when inference_type "
-                         "or inference_input_type is QUANTIZED_UINT8 or INT8.")
+        raise ValueError(
+            "The `quantized_input_stats` flag must be defined when either "
+            "`inference_type` flag or `inference_input_type` flag is set to "
+            "tf.int8 or tf.uint8.")
       input_array.mean_value, input_array.std_value = kwargs[
           "quantized_input_stats"][idx]
       input_array.name = name
@@ -661,7 +663,7 @@ def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
   Typically this function is used to convert from TensorFlow GraphDef to TFLite.
   Conversion can be customized by providing arguments that are forwarded to
   `build_toco_convert_protos` (see documentation for details). This function has
-  been deprecated. Please use `lite.TFLiteConverter` instead.
+  been deprecated. Please use `tf.lite.TFLiteConverter` instead.

   Args:
     input_data: Input data (i.e. often `sess.graph_def`),
@@ -137,7 +137,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
     self.assertEqual("output", output_details[0]["name"])
     self.assertEqual(np.uint8, output_details[0]["dtype"])
     self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all())
-    self.assertTrue(output_details[0]["quantization"][0] > 0)  # scale
+    self.assertGreater(output_details[0]["quantization"][0], 0)  # scale

   def testGraphDefQuantizationInvalid(self):
     with ops.Graph().as_default():

@@ -159,9 +159,9 @@ class ConvertTest(test_util.TensorFlowTestCase):
           enable_mlir_converter=False,
           inference_type=dtypes.uint8)
     self.assertEqual(
-        "std_dev and mean must be defined when inference_type or "
-        "inference_input_type is QUANTIZED_UINT8 or INT8.",
-        str(error.exception))
+        "The `quantized_input_stats` flag must be defined when either "
+        "`inference_type` flag or `inference_input_type` flag is set to "
+        "tf.int8 or tf.uint8.", str(error.exception))


 class ConvertTestOpHint(test_util.TensorFlowTestCase):
@@ -61,6 +61,7 @@ from tensorflow.lite.python.util import get_debug_info as _get_debug_info
 from tensorflow.lite.python.util import get_grappler_config as _get_grappler_config
 from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
 from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
+from tensorflow.lite.python.util import get_tf_type_name as _get_tf_type_name
 from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
 from tensorflow.lite.python.util import model_input_signature as _model_input_signature
 from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type
@@ -89,19 +90,14 @@ from tensorflow.python.util.tf_export import tf_export as _tf_export

 @_tf_export("lite.Optimize")
 class Optimize(enum.Enum):
-  """Enum defining the optimizations to apply when generating tflite graphs.
-
-  Some optimizations may come at the cost of accuracy.
+  """Enum defining the optimizations to apply when generating a tflite model.

   DEFAULT
-    Default optimization strategy.
-
-    Converter will do its best to improve size and latency based on the
-    information provided.
-    Enhanced optimizations are gained by providing a representative_dataset.
-    This is recommended, and is currently equivalent to the modes below.
-    Currently, weights will be quantized and if representative_dataset is
-    provided, activations for quantizable operations will also be quantized.
+    Default optimization strategy that quantizes model weights. Enhanced
+    optimizations are gained by providing a representative dataset that
+    quantizes biases and activations as well.
+    Converter will do its best to reduce size and latency, while minimizing
+    the loss in accuracy.

   OPTIMIZE_FOR_SIZE
     Deprecated. Does the same as DEFAULT.

@@ -110,14 +106,11 @@ class Optimize(enum.Enum):
     Deprecated. Does the same as DEFAULT.
   """

-  # Default optimization strategy.
-  #
-  # Converter will do its best to improve size and latency based on the
-  # information provided.
-  # Enhanced optimizations can be gained by providing a representative_dataset.
-  # This is recommended, and is currently equivalent to the modes below.
-  # Currently, weights will be quantized and if representative_dataset is
-  # provided, activations for quantizable operations will also be quantized.
+  # Default optimization strategy that quantizes model weights. Enhanced
+  # optimizations are gained by providing a representative dataset that
+  # quantizes biases and activations as well.
+  # Converter will do its best to reduce size and latency, while minimizing
+  # the loss in accuracy.
   DEFAULT = "DEFAULT"

   # Deprecated. Does the same as DEFAULT.
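A minimal usage sketch for the `Optimize` flag documented above; the SavedModel path is a placeholder and the snippet is illustrative rather than part of this change:

```python
# Illustrative: enabling the DEFAULT optimization (weight quantization).
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model("saved_model_dir")  # placeholder path
converter.optimizations = {tf.lite.Optimize.DEFAULT}
tflite_model = converter.convert()
```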
@@ -132,48 +125,47 @@ class Optimize(enum.Enum):

 @_tf_export("lite.RepresentativeDataset")
 class RepresentativeDataset(object):
-  """Representative dataset to evaluate optimizations.
+  """Representative dataset used to optimize the model.

-  A representative dataset that can be used to evaluate optimizations by the
-  converter. E.g. converter can use these examples to estimate (min, max) ranges
-  by calibrating the model on inputs. This can allow converter to quantize a
-  converted floating point model.
+  This is a generator function that provides a small dataset to calibrate or
+  estimate the range, i.e, (min, max) of all floating-point arrays in the model
+  (such as model input, activation outputs of intermediate layers, and model
+  output) for quantization. Usually, this is a small subset of a few hundred
+  samples randomly chosen, in no particular order, from the training or
+  evaluation dataset.
   """

   def __init__(self, input_gen):
     """Creates a representative dataset.

     Args:
-      input_gen: an input generator that can be used to generate input samples
-        for the model. This must be a callable object that returns an object
-        that supports the `iter()` protocol (e.g. a generator function). The
-        elements generated must have same type and shape as inputs to the model.
+      input_gen: A generator function that generates input samples for the
+        model and has the same order, type and shape as the inputs to the model.
+        Usually, this is a small subset of a few hundred samples randomly
+        chosen, in no particular order, from the training or evaluation dataset.
     """
     self.input_gen = input_gen


 @_tf_export("lite.TargetSpec")
 class TargetSpec(object):
-  """Specification of target device.
-
-  Details about target device. Converter optimizes the generated model for
-  specific device.
+  """Specification of target device used to optimize the model.

   Attributes:
-    supported_ops: Experimental flag, subject to change. Set of OpsSet options
-      supported by the device. (default set([OpsSet.TFLITE_BUILTINS]))
-    supported_types: List of types for constant values on the target device.
-      Frequently, an optimization choice is driven by the most compact
-      (i.e. smallest) type in this list (default [tf.float32])
+    supported_ops: Experimental flag, subject to change. Set of `tf.lite.OpsSet`
+      options, where each option represents a set of operators supported by the
+      target device. (default {tf.lite.OpsSet.TFLITE_BUILTINS}))
+    supported_types: Set of `tf.dtypes.DType` data types supported on the target
+      device. If initialized, optimization might be driven by the smallest type
+      in this set. (default set())
     experimental_select_user_tf_ops: Experimental flag, subject to change. Set
       of user's TensorFlow operators' names that are required in the TensorFlow
       Lite runtime. These ops will be exported as select TensorFlow ops in the
-      model (in conjunction with the OpsSet.SELECT_TF_OPS flag). This is an
-      advanced feature that should only be used if the client is using TF ops
+      model (in conjunction with the tf.lite.OpsSet.SELECT_TF_OPS flag). This is
+      an advanced feature that should only be used if the client is using TF ops
       that may not be linked in by default with the TF ops that are provided
       when using the SELECT_TF_OPS path. The client is responsible for linking
      these ops into the target runtime.

   """

   def __init__(self,
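A hedged sketch of a representative dataset generator matching the docstring above; `calibration_images` stands in for a few hundred real training samples and the SavedModel path is a placeholder:

```python
# Illustrative: a representative_dataset generator for post-training quantization.
import numpy as np
import tensorflow as tf

calibration_images = np.random.rand(100, 224, 224, 3).astype(np.float32)  # placeholder data

def representative_dataset():
    for image in calibration_images:
        # Each yielded element is a list with one entry per model input,
        # matching the input order, type and shape.
        yield [image[np.newaxis, ...]]

converter = tf.lite.TFLiteConverter.from_saved_model("saved_model_dir")  # placeholder path
converter.optimizations = {tf.lite.Optimize.DEFAULT}
converter.representative_dataset = representative_dataset
```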
@@ -181,17 +173,17 @@ class TargetSpec(object):
                supported_types=None,
                experimental_select_user_tf_ops=None):
     if supported_ops is None:
-      supported_ops = set([OpsSet.TFLITE_BUILTINS])
+      supported_ops = {OpsSet.TFLITE_BUILTINS}
     self.supported_ops = supported_ops
     if supported_types is None:
-      supported_types = []
+      supported_types = set()
     self.supported_types = supported_types
     if experimental_select_user_tf_ops is None:
-      self.experimental_select_user_tf_ops = []
+      self.experimental_select_user_tf_ops = set()


 class QuantizationMode(object):
-  """QuantizationMode determines the quantized conversion from user options."""
+  """QuantizationMode determines the quantization type from user options."""

   def __init__(self, optimizations, target_spec, representative_dataset,
                graph_def):
@@ -205,7 +197,6 @@ class QuantizationMode(object):
   # TODO(b/162537905): Refactor the following quantization functions -
   # re-organize and refactor for better readability.
   def post_training_int8_no_float(self):
-    """Post training int8 quantize, disallow float fallback."""
     return (self._any_optimization_enabled() and
             self._is_int8_target_required() and
             not self._is_int16x8_target_required() and

@@ -213,19 +204,16 @@ class QuantizationMode(object):
             self._representative_dataset is not None)

   def post_training_int8_allow_float(self):
-    """Post training int8 quantize, allow float fallback."""
     return (self._any_optimization_enabled() and
             not self._is_int16x8_target_required() and
             self._representative_dataset is not None and
             self._smallest_supported_type() == _dtypes.int8)

   def is_post_training_integer_quantize_8(self):
-    """Post training integer 8 quantization."""
     return (self.post_training_int8_no_float() or
             self.post_training_int8_allow_float())

   def is_post_training_integer_quantize_16x8(self):
-    """Post training integer 16x8 quantization."""
     return (self.post_training_int16x8_no_float() or
             self.post_training_int16x8_allow_float())

@@ -239,7 +227,6 @@ class QuantizationMode(object):
             self.contains_training_quant_op())

   def post_training_int16x8_no_float(self):
-    """Post training int16x8 quantize, disallow float fallback."""
     return (self._any_optimization_enabled() and
             not self._is_int8_target_required() and
             self._is_int16x8_target_required() and

@@ -247,13 +234,11 @@ class QuantizationMode(object):
             self._representative_dataset is not None)

   def post_training_int16x8_allow_float(self):
-    """Post training int16x8 quantize, allow float fallback."""
     return (self._any_optimization_enabled() and
             self._is_int16x8_target_required() and
             self._is_allow_float())

   def post_training_dynamic_range_int8(self):
-    """Post training int8 const, on-the-fly int8 quantize of dynamic tensors."""
     # Post-training dynamic range quantization is only enabled if post-training
     # int8 quantization and training time quantization was not done.
     return (self._any_optimization_enabled() and

@@ -262,7 +247,6 @@ class QuantizationMode(object):
             self._smallest_supported_type() == _dtypes.int8)

   def post_training_fp16(self):
-    """Post training fp16 quantize."""
     return (self._any_optimization_enabled() and
             self._smallest_supported_type() == _dtypes.float16)

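Taken together, these predicates pick a post-training quantization mode from the user's options. A deliberately simplified and approximate restatement (illustrative only; the real class also distinguishes 16x8, float fallback, and training-time quantization):

```python
# Approximate sketch of how the predicates above combine, for intuition only.
def quantization_mode(optimizations, supported_types, has_representative_dataset):
    if not optimizations:
        return "none"
    if "float16" in supported_types:
        return "post_training_fp16"
    if has_representative_dataset:
        return "post_training_int8"  # full integer, with or without float fallback
    return "post_training_dynamic_range_int8"

print(quantization_mode({"DEFAULT"}, set(), True))   # post_training_int8
print(quantization_mode({"DEFAULT"}, set(), False))  # post_training_dynamic_range_int8
```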
@@ -416,21 +400,20 @@ class TFLiteConverterBase(object):
   """Converter subclass to share functionality between V1 and V2 converters."""

   def __init__(self):
-    self.allow_custom_ops = False
-    self.target_spec = TargetSpec()
-    self.optimizations = []
+    self.optimizations = set()
     self.representative_dataset = None
+    self.target_spec = TargetSpec()
+    self.allow_custom_ops = False
     self.experimental_new_converter = True
     self._experimental_new_quantizer = False
     self._experimental_calibrate_only = False
-    # The 'GraphDebugInfo' contains the stack traces of all the original nodes
-    # in the `GraphDef` to the converter.
-    self._debug_info = None
+    self._experimental_sparsify_model = False
+    self._debug_info = None  # contains the stack traces of all the original
+    # nodes in the `GraphDef` to the converter.
     self.saved_model_dir = None
     self._saved_model_tags = None
     self._saved_model_version = 0
     self._saved_model_exported_names = []
-    self._experimental_sparsify_model = False

   def _grappler_config(self, optimizers=None):
     """Creates a tf.compat.v1.ConfigProto for configuring Grappler.
@@ -684,9 +667,9 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
       saved_model_dir: Directory of the SavedModel.
       saved_model_tags: Set of tags identifying the MetaGraphDef within the
         SavedModel to analyze. All tags in the tag set must be present. (default
-        set(SERVING)).
-      saved_model_exported_names: Names to be exported (default: export all)
-        when the saved model import path is on.
+        {tf.saved_model.SERVING}).
+      saved_model_exported_names: Names to be exported when the saved model
+        import path is on.
       trackable_obj: tf.AutoTrackable object associated with `funcs`. A
         reference to this object needs to be maintained so that Variables do not
         get garbage collected since functions have a weak reference to
@@ -972,20 +955,21 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
   """Converts a TensorFlow model into TensorFlow Lite model.

   Attributes:
-    allow_custom_ops: Boolean indicating whether to allow custom operations.
-      When False, any unknown operation is an error. When True, custom ops are
-      created for any op that is unknown. The developer needs to provide these
-      to the TensorFlow Lite runtime with a custom resolver. (default False)
-    optimizations: Experimental flag, subject to change. A list of optimizations
-      to apply when converting the model. E.g. `[Optimize.DEFAULT]`
-    representative_dataset: A representative dataset that can be used to
-      generate input and output samples for the model. The converter can use the
-      dataset to evaluate different optimizations. Note that this is an optional
-      attribute but it is necessary if INT8 is the only support builtin ops in
-      target ops.
+    optimizations: Experimental flag, subject to change. Set of optimizations
+      to apply. e.g {tf.lite.Optimize.DEFAULT}. (default None, must be None or a
+      set of values of type `tf.lite.Optimize`)
+    representative_dataset: A generator function used for integer quantization
+      where each generated sample has the same order, type and shape as the
+      inputs to the model. Usually, this is a small subset of a few hundred
+      samples randomly chosen, in no particular order, from the training or
+      evaluation dataset. This is an optional attribute, but required for full
+      integer quantization, i.e, if `tf.int8` is the only supported type in
+      `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`.
+      (default None)
     target_spec: Experimental flag, subject to change. Specifications of target
       device, including supported ops set, supported types and a set of user's
       defined TensorFlow operators required in the TensorFlow Lite runtime.
+      Refer to `tf.lite.TargetSpec`.
     inference_input_type: Data type of the input layer. Note that integer types
       (tf.int8 and tf.uint8) are currently only supported for post training
       integer quantization and quantization aware training. (default tf.float32,
@@ -994,8 +978,13 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
       types (tf.int8 and tf.uint8) are currently only supported for post
       training integer quantization and quantization aware training. (default
       tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
+    allow_custom_ops: Boolean indicating whether to allow custom operations.
+      When False, any unknown operation is an error. When True, custom ops are
+      created for any op that is unknown. The developer needs to provide these
+      to the TensorFlow Lite runtime with a custom resolver. (default False)
     experimental_new_converter: Experimental flag, subject to change. Enables
       MLIR-based conversion instead of TOCO conversion. (default True)

   Example usage:

   ```python
@@ -1063,7 +1052,8 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
         `signatures` attribute of the MetaGraphdef is used. (default
         saved_model.signatures)
       tags: Set of tags identifying the MetaGraphDef within the SavedModel to
-        analyze. All tags in the tag set must be present. (default set(SERVING))
+        analyze. All tags in the tag set must be present. (default
+        {tf.saved_model.SERVING} or {'serve'})

     Returns:
       TFLiteConverter object.
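A hedged usage sketch tying the two documented defaults together: loading a SavedModel with an explicit tag set and supplying a `representative_dataset` generator for full-integer quantization (the SavedModel path and input shape are placeholders):

```python
import numpy as np
import tensorflow as tf

saved_model_dir = "/tmp/my_saved_model"  # placeholder path

def representative_dataset():
  # Yield a few hundred samples with the same shape/dtype as the model input.
  for _ in range(100):
    yield [np.random.rand(1, 224, 224, 3).astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_saved_model(
    saved_model_dir, tags={tf.saved_model.SERVING})
converter.optimizations = {tf.lite.Optimize.DEFAULT}
converter.representative_dataset = representative_dataset
# Restricting supported ops to int8 builtins is what makes the representative
# dataset mandatory, as the docstring above describes.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_model = converter.convert()
```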
@@ -1209,9 +1199,13 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):

     if (requires_quantized_input_stats and
         not converter_kwargs["quantized_input_stats"]):
-      raise ValueError("The `quantized_input_stats` flag must be defined when "
-                       "either `inference_type` flag or `inference_input_type` "
-                       "flag is set to tf.uint8 or tf.int8.")
+      raise ValueError(
+          "The `quantized_input_stats` flag must be defined when either "
+          "`inference_type` flag or `inference_input_type` flag is set to "
+          "tf.int8 or tf.uint8. Currently, `inference_type={}` and "
+          "`inference_input_type={}`.".format(
+              _get_tf_type_name(converter_kwargs["inference_type"]),
+              _get_tf_type_name(converter_kwargs["inference_input_type"])))

   def convert(self):
     """Converts a TensorFlow GraphDef based on instance variables.
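The reworked error message above fires from the TF1 conversion path. A small sketch of the flag combination it asks for (assuming the `tf.compat.v1` frozen-graph converter; the graph file and tensor names are placeholders):

```python
import tensorflow as tf

# Placeholder names; in practice these come from the frozen GraphDef.
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file="/tmp/frozen_graph.pb",
    input_arrays=["input"],
    output_arrays=["output"])

# Requesting a quantized inference type without stats would raise the
# ValueError shown above; providing (mean, std_dev) per input avoids it.
converter.inference_type = tf.uint8
converter.quantized_input_stats = {"input": (127.5, 127.5)}
tflite_model = converter.convert()
```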
@@ -1424,9 +1418,9 @@ class TFLiteSavedModelConverter(TFLiteConverterBaseV1):
       saved_model_dir: Directory of the SavedModel.
       saved_model_tags: Set of tags identifying the MetaGraphDef within the
         SavedModel to analyze. All tags in the tag set must be present. (default
-        set(SERVING)).
-      saved_model_exported_names: Names to be exported (default: export all)
-        when the saved model import path is on.
+        {tf.saved_model.SERVING}).
+      saved_model_exported_names: Names to be exported when the saved model
+        import path is on.
       experimental_debug_info_func: An experimental function to retrieve the
         graph debug info for a set of nodes from the `graph_def`.

@ -1645,33 +1639,42 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
|
|||||||
model into either a TFLite FlatBuffer or graph visualization.
|
model into either a TFLite FlatBuffer or graph visualization.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
inference_type: Target data type of real-number arrays in the output file.
|
optimizations: Experimental flag, subject to change. Set of optimizations to
|
||||||
Must be `{tf.float32, tf.uint8}`. If `optimzations` are provided, this
|
apply. e.g {tf.lite.Optimize.DEFAULT}. (default None, must be None or a
|
||||||
parameter is ignored. (default tf.float32)
|
set of values of type `tf.lite.Optimize`)
|
||||||
inference_input_type: Target data type of real-number input arrays. Allows
|
representative_dataset: A generator function used for integer quantization
|
||||||
for a different type for input arrays. If an integer type is provided and
|
where each generated sample has the same order, type and shape as the
|
||||||
`optimizations` are not used, `quantized_input_stats` must be provided. If
|
inputs to the model. Usually, this is a small subset of a few hundred
|
||||||
`inference_type` is tf.uint8, signaling conversion to a fully quantized
|
samples randomly chosen, in no particular order, from the training or
|
||||||
model from a quantization-aware trained input model, then
|
evaluation dataset. This is an optional attribute, but required for full
|
||||||
`inference_input_type` defaults to tf.uint8. In all other cases,
|
integer quantization, i.e, if `tf.int8` is the only supported type in
|
||||||
`inference_input_type` defaults to tf.float32. Must be `{tf.float32,
|
`target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`.
|
||||||
tf.uint8, tf.int8}`
|
(default None)
|
||||||
inference_output_type: Target data type of real-number output arrays. Allows
|
target_spec: Experimental flag, subject to change. Specifications of target
|
||||||
for a different type for output arrays. If `inference_type` is tf.uint8,
|
device, including supported ops set, supported types and a set of user's
|
||||||
signaling conversion to a fully quantized model from a quantization-aware
|
defined TensorFlow operators required in the TensorFlow Lite runtime.
|
||||||
trained output model, then `inference_output_type` defaults to tf.uint8.
|
Refer to `tf.lite.TargetSpec`.
|
||||||
In all other cases, `inference_output_type` must be tf.float32, an error
|
inference_type: Data type of numeric arrays, excluding the input layer.
|
||||||
will be thrown otherwise. Must be `{tf.float32, tf.uint8, tf.int8}`
|
(default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
|
||||||
output_format: Output file format. Currently must be `{TFLITE,
|
inference_input_type: Data type of the numeric arrays in the input layer. If
|
||||||
GRAPHVIZ_DOT}`. (default TFLITE)
|
`inference_input_type` is in {tf.int8, tf.uint8}, then
|
||||||
quantized_input_stats: Dict of strings representing input tensor names
|
`quantized_input_stats` must be provided. (default is the value assigned
|
||||||
mapped to tuple of floats representing the mean and standard deviation
|
to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8})
|
||||||
of the training data (e.g., {"foo" : (0., 1.)}). Only need if
|
inference_output_type: Data type of the numeric arrays in the output layer.
|
||||||
`inference_input_type` is `QUANTIZED_UINT8`. real_input_value =
|
(default is the value assigned to `inference_type`, must be in
|
||||||
(quantized_input_value - mean_value) / std_dev_value. (default {})
|
{tf.float32, tf.int8, tf.uint8})
|
||||||
default_ranges_stats: Tuple of integers representing (min, max) range values
|
quantized_input_stats: Map of input tensor names to a tuple of floats
|
||||||
for all arrays without a specified range. Intended for experimenting with
|
representing the mean and standard deviation of the training data.
|
||||||
quantization via "dummy quantization". (default None)
|
(e.g., {"foo" : (0., 1.)}). Required if `inference_input_type` is tf.int8
|
||||||
|
or tf.uint8. (default None)
|
||||||
|
default_ranges_stats: Tuple of integers (min, max) representing range values
|
||||||
|
for all numeric arrays without a specified range. Intended for
|
||||||
|
experimenting with quantization via "dummy quantization". (default None)
|
||||||
|
allow_custom_ops: Boolean indicating whether to allow custom operations.
|
||||||
|
When False any unknown operation is an error. When True, custom ops are
|
||||||
|
created for any op that is unknown. The developer will need to provide
|
||||||
|
these to the TensorFlow Lite runtime with a custom resolver. (default
|
||||||
|
False)
|
||||||
drop_control_dependency: Boolean indicating whether to drop control
|
drop_control_dependency: Boolean indicating whether to drop control
|
||||||
dependencies silently. This is due to TFLite not supporting control
|
dependencies silently. This is due to TFLite not supporting control
|
||||||
dependencies. (default True)
|
dependencies. (default True)
|
||||||
@ -1683,37 +1686,25 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
|
|||||||
change_concat_input_ranges: Boolean to change behavior of min/max ranges for
|
change_concat_input_ranges: Boolean to change behavior of min/max ranges for
|
||||||
inputs and outputs of the concat operator for quantized models. Changes
|
inputs and outputs of the concat operator for quantized models. Changes
|
||||||
the ranges of concat operator overlap when true. (default False)
|
the ranges of concat operator overlap when true. (default False)
|
||||||
allow_custom_ops: Boolean indicating whether to allow custom operations.
|
output_format: Output file format. (default
|
||||||
When false any unknown operation is an error. When true, custom ops are
|
tf.compat.v1.lite.constants.TFLITE, must be in
|
||||||
created for any op that is unknown. The developer will need to provide
|
{tf.compat.v1.lite.constants.TFLITE,
|
||||||
these to the TensorFlow Lite runtime with a custom resolver. (default
|
tf.compat.v1.lite.constants.GRAPHVIZ_DOT})
|
||||||
False)
|
|
||||||
post_training_quantize: Deprecated. Please specify `[Optimize.DEFAULT]` for
|
|
||||||
`optimizations` instead. Boolean indicating whether to quantize the
|
|
||||||
weights of the converted float model. Model size will be reduced and
|
|
||||||
there will be latency improvements (at the cost of accuracy). (default
|
|
||||||
False)
|
|
||||||
dump_graphviz_dir: Full filepath of folder to dump the graphs at various
|
dump_graphviz_dir: Full filepath of folder to dump the graphs at various
|
||||||
stages of processing GraphViz .dot files. Preferred over
|
stages of processing GraphViz .dot files. Preferred over
|
||||||
--output_format=GRAPHVIZ_DOT in order to keep the requirements of the
|
`output_format=tf.compat.v1.lite.constants.GRAPHVIZ_DOT` in order to keep
|
||||||
output file. (default None)
|
the requirements of the output file. (default None)
|
||||||
dump_graphviz_video: Boolean indicating whether to dump the graph after
|
dump_graphviz_video: Boolean indicating whether to dump the GraphViz .dot
|
||||||
every graph transformation. (default False)
|
files after every graph transformation. Requires the `dump_graphviz_dir`
|
||||||
conversion_summary_dir: A string indicating the path to the generated
|
flag to be specified. (default False)
|
||||||
conversion logs.
|
conversion_summary_dir: Full path of the directory to store conversion logs.
|
||||||
target_ops: Deprecated. Please specify `target_spec.supported_ops` instead.
|
(default None)
|
||||||
Set of OpsSet options indicating which converter to use. (default
|
target_ops: Deprecated. Please use `target_spec.supported_ops` instead.
|
||||||
set([OpsSet.TFLITE_BUILTINS]))
|
post_training_quantize: Deprecated. Please use `optimizations` instead and
|
||||||
target_spec: Experimental flag, subject to change. Specifications of target
|
set it to `{tf.lite.Optimize.DEFAULT}`. (default False)
|
||||||
device, including supported ops set, supported types and a set of user's
|
|
||||||
defined TensorFlow operators required in the TensorFlow Lite runtime.
|
|
||||||
optimizations: Experimental flag, subject to change. A list of optimizations
|
|
||||||
to apply when converting the model. E.g. `[Optimize.DEFAULT]`
|
|
||||||
representative_dataset: A representative dataset that can be used to
|
|
||||||
generate input and output samples for the model. The converter can use the
|
|
||||||
dataset to evaluate different optimizations.
|
|
||||||
experimental_new_converter: Experimental flag, subject to change. Enables
|
experimental_new_converter: Experimental flag, subject to change. Enables
|
||||||
MLIR-based conversion instead of TOCO conversion. (default True)
|
MLIR-based conversion instead of TOCO conversion. (default True)
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@@ -1911,9 +1902,10 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
       output_arrays: List of output tensors to freeze graph with. Uses output
         arrays from SignatureDef when none are provided. (default None)
       tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
-        analyze. All tags in the tag set must be present. (default set("serve"))
+        analyze. All tags in the tag set must be present. (default
+        {tf.saved_model.SERVING})
       signature_key: Key identifying SignatureDef containing inputs and outputs.
-        (default DEFAULT_SERVING_SIGNATURE_DEF_KEY)
+        (default tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY)

     Returns:
       TFLiteConverter class.
@@ -1239,18 +1239,22 @@ class FromSessionTest(TestModels, parameterized.TestCase):
       quantized_converter.inference_type = quantized_type
      quantized_converter.convert()
     self.assertEqual(
-        'The `quantized_input_stats` flag must be defined when '
-        'either `inference_type` flag or `inference_input_type` '
-        'flag is set to tf.uint8 or tf.int8.', str(error.exception))
+        'The `quantized_input_stats` flag must be defined when either '
+        '`inference_type` flag or `inference_input_type` flag is set to '
+        'tf.int8 or tf.uint8. Currently, `inference_type=tf.{}` and '
+        '`inference_input_type=None`.'.format(quantized_type.name),
+        str(error.exception))

     with self.assertRaises(ValueError) as error:
       quantized_converter.inference_type = dtypes.float32
       quantized_converter.inference_input_type = quantized_type
       quantized_converter.convert()
     self.assertEqual(
-        'The `quantized_input_stats` flag must be defined when '
-        'either `inference_type` flag or `inference_input_type` '
-        'flag is set to tf.uint8 or tf.int8.', str(error.exception))
+        'The `quantized_input_stats` flag must be defined when either '
+        '`inference_type` flag or `inference_input_type` flag is set to '
+        'tf.int8 or tf.uint8. Currently, `inference_type=tf.float32` and '
+        '`inference_input_type=tf.{}`.'.format(quantized_type.name),
+        str(error.exception))

     quantized_converter.inference_type = quantized_type
     quantized_converter.inference_input_type = quantized_type
@@ -127,9 +127,9 @@ def _convert_tflite_enum_type_to_tf_type(tflite_enum_type):
   return tf_type


-def _get_tf_type_name(tf_type):
+def get_tf_type_name(tf_type):
   """Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32")."""
-  return "tf." + tf_type.name
+  return "tf." + tf_type.name if tf_type else None


 def get_tensor_name(tensor):
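The renamed helper now tolerates `None` instead of raising on `tf_type.name`. A tiny illustrative check (assuming the helper is importable from the TFLite `util` module this hunk belongs to):

```python
from tensorflow.python.framework import dtypes
from tensorflow.lite.python import util  # assumed import path for this module

print(util.get_tf_type_name(dtypes.float32))  # -> "tf.float32"
print(util.get_tf_type_name(None))            # -> None (no longer raises)
```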
@ -674,7 +674,7 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Initial model input type must be tf.float32. Expected type for "
|
"Initial model input type must be tf.float32. Expected type for "
|
||||||
"tensor with name '{}' is tf.float32, instead type is {}".format(
|
"tensor with name '{}' is tf.float32, instead type is {}".format(
|
||||||
float_tensor.name, _get_tf_type_name(float_type)))
|
float_tensor.name, get_tf_type_name(float_type)))
|
||||||
# If found, validate that the operator output is quantized and compatible
|
# If found, validate that the operator output is quantized and compatible
|
||||||
# with the final model input type
|
# with the final model input type
|
||||||
quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
|
quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
|
||||||
@ -683,17 +683,17 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
|
|||||||
"Initial model input is not quantized. Expected type for "
|
"Initial model input is not quantized. Expected type for "
|
||||||
"tensor with name '{}' should be in {}, instead type is {}".format(
|
"tensor with name '{}' should be in {}, instead type is {}".format(
|
||||||
quant_tensor.name,
|
quant_tensor.name,
|
||||||
tuple(_get_tf_type_name(t) for t in
|
tuple(get_tf_type_name(t) for t in
|
||||||
_MAP_QUANT_TO_IO_TYPES.keys()),
|
_MAP_QUANT_TO_IO_TYPES.keys()),
|
||||||
_get_tf_type_name(quant_type)))
|
get_tf_type_name(quant_type)))
|
||||||
else:
|
else:
|
||||||
inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
|
inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
|
||||||
if inference_input_type not in inference_io_types:
|
if inference_input_type not in inference_io_types:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unsupported `inference_input_type` value. Expected to be in "
|
"Unsupported `inference_input_type` value. Expected to be in "
|
||||||
"{}, instead got {}.".format(
|
"{}, instead got {}.".format(
|
||||||
tuple(_get_tf_type_name(t) for t in inference_io_types),
|
tuple(get_tf_type_name(t) for t in inference_io_types),
|
||||||
_get_tf_type_name(inference_input_type)))
|
get_tf_type_name(inference_input_type)))
|
||||||
input_quant_ops.append(op)
|
input_quant_ops.append(op)
|
||||||
|
|
||||||
if len(subgraph.inputs) != len(input_quant_ops):
|
if len(subgraph.inputs) != len(input_quant_ops):
|
||||||
@ -725,7 +725,7 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unsupported `inference_input_type` value {}.".format(
|
"Unsupported `inference_input_type` value {}.".format(
|
||||||
_get_tf_type_name(inference_input_type)))
|
get_tf_type_name(inference_input_type)))
|
||||||
|
|
||||||
|
|
||||||
def _modify_model_output_type(model, inference_output_type=dtypes.float32):
|
def _modify_model_output_type(model, inference_output_type=dtypes.float32):
|
||||||
@ -768,7 +768,7 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Initial model output type must be tf.float32. Expected type for "
|
"Initial model output type must be tf.float32. Expected type for "
|
||||||
"tensor with name '{}' is tf.float32, instead type is {}".format(
|
"tensor with name '{}' is tf.float32, instead type is {}".format(
|
||||||
float_tensor.name, _get_tf_type_name(float_type)))
|
float_tensor.name, get_tf_type_name(float_type)))
|
||||||
# If found, validate that the operator input is quantized and compatible
|
# If found, validate that the operator input is quantized and compatible
|
||||||
# with the final model output type
|
# with the final model output type
|
||||||
quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
|
quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
|
||||||
@ -777,17 +777,17 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
|
|||||||
"Initial model output is not dequantized. Expected type for "
|
"Initial model output is not dequantized. Expected type for "
|
||||||
"tensor with name '{}' should be in {}, instead type is {}".format(
|
"tensor with name '{}' should be in {}, instead type is {}".format(
|
||||||
quant_tensor.name,
|
quant_tensor.name,
|
||||||
tuple(_get_tf_type_name(t) for t in
|
tuple(get_tf_type_name(t) for t in
|
||||||
_MAP_QUANT_TO_IO_TYPES.keys()),
|
_MAP_QUANT_TO_IO_TYPES.keys()),
|
||||||
_get_tf_type_name(quant_type)))
|
get_tf_type_name(quant_type)))
|
||||||
else:
|
else:
|
||||||
inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
|
inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
|
||||||
if inference_output_type not in inference_io_types:
|
if inference_output_type not in inference_io_types:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unsupported `inference_output_type` value. Expected to be in "
|
"Unsupported `inference_output_type` value. Expected to be in "
|
||||||
"{}, instead got {}.".format(
|
"{}, instead got {}.".format(
|
||||||
tuple(_get_tf_type_name(t) for t in inference_io_types),
|
tuple(get_tf_type_name(t) for t in inference_io_types),
|
||||||
_get_tf_type_name(inference_output_type)))
|
get_tf_type_name(inference_output_type)))
|
||||||
output_dequant_ops.append(op)
|
output_dequant_ops.append(op)
|
||||||
|
|
||||||
if len(subgraph.outputs) != len(output_dequant_ops):
|
if len(subgraph.outputs) != len(output_dequant_ops):
|
||||||
@ -834,7 +834,7 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unsupported `inference_output_type` value {}.".format(
|
"Unsupported `inference_output_type` value {}.".format(
|
||||||
_get_tf_type_name(inference_output_type)))
|
get_tf_type_name(inference_output_type)))
|
||||||
|
|
||||||
|
|
||||||
def modify_model_io_type(
|
def modify_model_io_type(
|
||||||
|
@@ -26,11 +26,30 @@ from tensorflow.lite.testing.zip_test_utils import register_make_test_function
 @register_make_test_function()
 def make_cast_tests(options):
   """Generate examples for cast."""
-  test_parameters = [{
-      "input_dtype": [tf.int32],
-      "output_dtype": [tf.float32],
-      "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
-  }]
+  if options.use_experimental_converter:
+    test_parameters = [
+        {
+            "input_dtype": [tf.float32],
+            "output_dtype": [tf.int16],
+            "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
+        },
+        {
+            "input_dtype": [tf.int16],
+            "output_dtype": [tf.float32],
+            "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
+        },
+        {
+            "input_dtype": [tf.int32],
+            "output_dtype": [tf.float32],
+            "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
+        }]
+  else:
+    test_parameters = [
+        {
+            "input_dtype": [tf.int32],
+            "output_dtype": [tf.float32],
+            "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
+        }]

   def build_graph(parameters):
     """Build the cast testing graph."""
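The new parameters extend the generated cast tests to int16 when the experimental (MLIR) converter is used. A standalone sketch of the dtype pairs being exercised, in plain TF with no converter involved:

```python
import tensorflow as tf

x = tf.constant([[1.5, -2.25]], dtype=tf.float32)
as_int16 = tf.cast(x, tf.int16)       # float32 -> int16 (new coverage)
back = tf.cast(as_int16, tf.float32)  # int16 -> float32 (new coverage)
legacy = tf.cast(tf.constant([1, 2], dtype=tf.int32), tf.float32)  # old case
print(as_int16.numpy(), back.numpy(), legacy.numpy())
```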
|
@ -763,6 +763,7 @@ py_library(
|
|||||||
"//tensorflow/python/util:_pywrap_tfprof",
|
"//tensorflow/python/util:_pywrap_tfprof",
|
||||||
"//tensorflow/python/util:_pywrap_transform_graph",
|
"//tensorflow/python/util:_pywrap_transform_graph",
|
||||||
"//tensorflow/python/util:_pywrap_util_port",
|
"//tensorflow/python/util:_pywrap_util_port",
|
||||||
|
"//tensorflow/python/platform:_pywrap_tf2",
|
||||||
":_pywrap_utils",
|
":_pywrap_utils",
|
||||||
":composite_tensor",
|
":composite_tensor",
|
||||||
":config",
|
":config",
|
||||||
@ -5266,7 +5267,10 @@ pywrap_tensorflow_macro(
|
|||||||
tf_additional_plugin_deps() +
|
tf_additional_plugin_deps() +
|
||||||
tf_additional_profiler_deps()) + if_xla_available([
|
tf_additional_profiler_deps()) + if_xla_available([
|
||||||
"//tensorflow/compiler/aot:tfcompile_lib",
|
"//tensorflow/compiler/aot:tfcompile_lib",
|
||||||
]) + if_static(extra_deps = ["//tensorflow/core/platform:tensor_float_32_utils"]),
|
]) + if_static(extra_deps = [
|
||||||
|
"//tensorflow/core/platform:tensor_float_32_utils",
|
||||||
|
"//tensorflow/core/platform:enable_tf2_utils",
|
||||||
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
# ** Targets for Windows build (start) **
|
# ** Targets for Windows build (start) **
|
||||||
@ -6779,6 +6783,8 @@ py_test(
|
|||||||
":client_testlib",
|
":client_testlib",
|
||||||
":framework_combinations",
|
":framework_combinations",
|
||||||
":tf2",
|
":tf2",
|
||||||
|
"//tensorflow/python/compat:v2_compat",
|
||||||
|
"//tensorflow/python/data/kernel_tests:test_base",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2021, 1, 2)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2021, 1, 5)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None

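The horizon bump above is consumed through `tf.compat.forward_compatible` and the environment variable named in the comment. A hedged sketch (the 2021-1-5 date mirrors the new constant; the delta-days value is purely illustrative):

```python
import os
# Set before TensorFlow reads it, so the delta is applied to the horizon.
os.environ["TF_FORWARD_COMPATIBILITY_DELTA_DAYS"] = "7"

import tensorflow as tf

# Gate behavior that producers may only emit once the horizon has passed.
if tf.compat.forward_compatible(2021, 1, 5):
  print("Safe to rely on functionality tied to the new horizon date")
else:
  print("Fall back to the older behavior")
```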
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
from tensorflow.python.compat import v2_compat
|
from tensorflow.python.compat import v2_compat
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.platform import _pywrap_tf2
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -29,9 +30,13 @@ class DisableV2BehaviorTest(test.TestCase):
|
|||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
t = constant_op.constant([1, 2, 3]) # creates a hidden context
|
t = constant_op.constant([1, 2, 3]) # creates a hidden context
|
||||||
self.assertTrue(isinstance(t, ops.EagerTensor))
|
self.assertTrue(isinstance(t, ops.EagerTensor))
|
||||||
|
t = _pywrap_tf2.is_enabled()
|
||||||
|
self.assertTrue(t)
|
||||||
v2_compat.disable_v2_behavior()
|
v2_compat.disable_v2_behavior()
|
||||||
t = constant_op.constant([1, 2, 3])
|
t = constant_op.constant([1, 2, 3])
|
||||||
self.assertFalse(isinstance(t, ops.EagerTensor))
|
self.assertFalse(isinstance(t, ops.EagerTensor))
|
||||||
|
t = _pywrap_tf2.is_enabled()
|
||||||
|
self.assertFalse(t)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -36,7 +36,10 @@ py_library(
|
|||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "trt_convert_py",
|
name = "trt_convert_py",
|
||||||
srcs = ["trt_convert.py"],
|
srcs = [
|
||||||
|
"trt_convert.py",
|
||||||
|
"utils.py",
|
||||||
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/compiler/tf2tensorrt:_pywrap_py_utils",
|
"//tensorflow/compiler/tf2tensorrt:_pywrap_py_utils",
|
||||||
|
@ -35,7 +35,9 @@ from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import get_linked_tensorrt
|
|||||||
from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
|
from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
|
||||||
from tensorflow.core.framework import graph_pb2
|
from tensorflow.core.framework import graph_pb2
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
|
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||||
from tensorflow.python.compiler.tensorrt import trt_convert
|
from tensorflow.python.compiler.tensorrt import trt_convert
|
||||||
|
from tensorflow.python.compiler.tensorrt import utils as trt_utils
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import graph_io
|
from tensorflow.python.framework import graph_io
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -331,17 +333,21 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
|||||||
"""Get config proto based on specific settings."""
|
"""Get config proto based on specific settings."""
|
||||||
conversion_params = self.GetConversionParams(run_params)
|
conversion_params = self.GetConversionParams(run_params)
|
||||||
max_batch_size = self.GetMaxBatchSize(run_params)
|
max_batch_size = self.GetMaxBatchSize(run_params)
|
||||||
|
|
||||||
if graph_state == GraphState.INFERENCE and run_params.convert_online:
|
if graph_state == GraphState.INFERENCE and run_params.convert_online:
|
||||||
rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
|
rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
|
||||||
conversion_params,
|
conversion_params,
|
||||||
is_dynamic_op=run_params.dynamic_engine,
|
is_dynamic_op=run_params.dynamic_engine,
|
||||||
max_batch_size=max_batch_size)
|
max_batch_size=max_batch_size,
|
||||||
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
|
disable_non_trt_optimizers=self._disable_non_trt_optimizers)
|
||||||
else:
|
else:
|
||||||
graph_options = config_pb2.GraphOptions()
|
rewriter_cfg = rewriter_config_pb2.RewriterConfig()
|
||||||
|
if self._disable_non_trt_optimizers:
|
||||||
|
trt_utils.disable_non_trt_optimizers_in_rewriter_config(rewriter_cfg)
|
||||||
|
|
||||||
config = config_pb2.ConfigProto(
|
config = config_pb2.ConfigProto(
|
||||||
gpu_options=self._GetGPUOptions(), graph_options=graph_options)
|
gpu_options=self._GetGPUOptions(),
|
||||||
|
graph_options=config_pb2.GraphOptions(rewrite_options=rewriter_cfg))
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def _GetFeedNames(self):
|
def _GetFeedNames(self):
|
||||||
|
@ -30,6 +30,7 @@ from tensorflow.core.protobuf import config_pb2
|
|||||||
from tensorflow.core.protobuf import meta_graph_pb2
|
from tensorflow.core.protobuf import meta_graph_pb2
|
||||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
|
from tensorflow.python.compiler.tensorrt import utils as trt_utils
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import wrap_function
|
from tensorflow.python.eager import wrap_function
|
||||||
from tensorflow.python.framework import convert_to_constants
|
from tensorflow.python.framework import convert_to_constants
|
||||||
@ -271,11 +272,14 @@ def _get_tensorrt_rewriter_config(conversion_params,
|
|||||||
# need to run constant folding again.
|
# need to run constant folding again.
|
||||||
rewriter_config_with_trt.optimizers.extend(
|
rewriter_config_with_trt.optimizers.extend(
|
||||||
["constfold", "layout", "constfold"])
|
["constfold", "layout", "constfold"])
|
||||||
|
|
||||||
rewriter_config_with_trt.meta_optimizer_iterations = (
|
rewriter_config_with_trt.meta_optimizer_iterations = (
|
||||||
rewriter_config_pb2.RewriterConfig.ONE)
|
rewriter_config_pb2.RewriterConfig.ONE)
|
||||||
optimizer = rewriter_config_with_trt.custom_optimizers.add()
|
optimizer = rewriter_config_with_trt.custom_optimizers.add()
|
||||||
# Add a constfold optimizer to cleanup the unused Const nodes.
|
|
||||||
rewriter_config_with_trt.custom_optimizers.add().name = "constfold"
|
if not disable_non_trt_optimizers:
|
||||||
|
# Add a constfold optimizer to cleanup the unused Const nodes.
|
||||||
|
rewriter_config_with_trt.custom_optimizers.add().name = "constfold"
|
||||||
|
|
||||||
optimizer.name = "TensorRTOptimizer"
|
optimizer.name = "TensorRTOptimizer"
|
||||||
optimizer.parameter_map[
|
optimizer.parameter_map[
|
||||||
@ -295,25 +299,11 @@ def _get_tensorrt_rewriter_config(conversion_params,
|
|||||||
optimizer.parameter_map["max_batch_size"].i = max_batch_size
|
optimizer.parameter_map["max_batch_size"].i = max_batch_size
|
||||||
optimizer.parameter_map["use_implicit_batch"].b = use_implicit_batch
|
optimizer.parameter_map["use_implicit_batch"].b = use_implicit_batch
|
||||||
|
|
||||||
# Disabling optimizers should happen after CopyFrom the template
|
# Disabling optimizers should happen after defining the TF-TRT grappler pass
|
||||||
# otherwise the template can overwrite the disablement.
|
# otherwise the template can overwrite the disablement.
|
||||||
if disable_non_trt_optimizers:
|
if disable_non_trt_optimizers:
|
||||||
off = rewriter_config_pb2.RewriterConfig.OFF
|
trt_utils.disable_non_trt_optimizers_in_rewriter_config(
|
||||||
rewriter_config_with_trt.layout_optimizer = off
|
rewriter_config_with_trt)
|
||||||
rewriter_config_with_trt.constant_folding = off
|
|
||||||
rewriter_config_with_trt.shape_optimization = off
|
|
||||||
rewriter_config_with_trt.remapping = off
|
|
||||||
rewriter_config_with_trt.arithmetic_optimization = off
|
|
||||||
rewriter_config_with_trt.dependency_optimization = off
|
|
||||||
rewriter_config_with_trt.loop_optimization = off
|
|
||||||
rewriter_config_with_trt.function_optimization = off
|
|
||||||
rewriter_config_with_trt.debug_stripper = off
|
|
||||||
rewriter_config_with_trt.disable_model_pruning = True
|
|
||||||
rewriter_config_with_trt.scoped_allocator_optimization = off
|
|
||||||
rewriter_config_with_trt.memory_optimization = (
|
|
||||||
rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
|
|
||||||
rewriter_config_with_trt.pin_to_host_optimization = off
|
|
||||||
rewriter_config_with_trt.auto_parallel.enable = False
|
|
||||||
|
|
||||||
return rewriter_config_with_trt
|
return rewriter_config_with_trt
|
||||||
|
|
||||||
@ -652,10 +642,19 @@ class TrtGraphConverter(object):
|
|||||||
return_elements=fetch_names,
|
return_elements=fetch_names,
|
||||||
name="")
|
name="")
|
||||||
|
|
||||||
|
calibrate_rewriter_cfg = rewriter_config_pb2.RewriterConfig()
|
||||||
|
if self._test_only_disable_non_trt_optimizers:
|
||||||
|
trt_utils.disable_non_trt_optimizers_in_rewriter_config(
|
||||||
|
calibrate_rewriter_cfg)
|
||||||
|
|
||||||
# Set allow_soft_placement=True to run the graph for calibration so that
|
# Set allow_soft_placement=True to run the graph for calibration so that
|
||||||
# OPs supported by TensorRT but don't have a GPU implementation are allowed
|
# OPs supported by TensorRT but don't have a GPU implementation are allowed
|
||||||
# to execute on CPU.
|
# to execute on CPU.
|
||||||
calibrate_config = config_pb2.ConfigProto(allow_soft_placement=True)
|
calibrate_config = config_pb2.ConfigProto(
|
||||||
|
allow_soft_placement=True,
|
||||||
|
graph_options=config_pb2.GraphOptions(
|
||||||
|
rewrite_options=calibrate_rewriter_cfg))
|
||||||
|
|
||||||
with session.Session(
|
with session.Session(
|
||||||
graph=self._calibration_graph,
|
graph=self._calibration_graph,
|
||||||
config=calibrate_config) as calibration_sess:
|
config=calibrate_config) as calibration_sess:
|
||||||
|
tensorflow/python/compiler/tensorrt/utils.py (new file, 47 lines)
@@ -0,0 +1,47 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Exposes the Python wrapper conversion to trt_graph."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import rewriter_config_pb2
+
+
+def disable_non_trt_optimizers_in_rewriter_config(rewriter_config):
+  """Modifies rewriter_config to disable all non-TRT optimizations."""
+  off = rewriter_config_pb2.RewriterConfig.OFF
+
+  rewriter_config.arithmetic_optimization = off
+  rewriter_config.auto_mixed_precision = off
+  rewriter_config.auto_parallel.enable = False
+  rewriter_config.constant_folding = off
+  rewriter_config.debug_stripper = off
+  rewriter_config.dependency_optimization = off
+  # This one needs to be ON to allow TF-TRT
+  rewriter_config.disable_meta_optimizer = False
+  rewriter_config.disable_model_pruning = True
+  rewriter_config.function_optimization = off
+  rewriter_config.implementation_selector = off
+  rewriter_config.layout_optimizer = off
+  rewriter_config.loop_optimization = off
+  rewriter_config.memory_optimization = (
+      rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
+  rewriter_config.min_graph_nodes = -1
+  rewriter_config.pin_to_host_optimization = off
+  rewriter_config.remapping = off
+  rewriter_config.scoped_allocator_optimization = off
+  rewriter_config.shape_optimization = off
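A hedged sketch of how the new helper is meant to be consumed when building a session config, mirroring the call sites added elsewhere in this change (the import path assumes the new `utils.py` shown above):

```python
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.compiler.tensorrt import utils as trt_utils

# Start from an empty rewriter config and switch off everything except TF-TRT.
rewriter_cfg = rewriter_config_pb2.RewriterConfig()
trt_utils.disable_non_trt_optimizers_in_rewriter_config(rewriter_cfg)

config = config_pb2.ConfigProto(
    graph_options=config_pb2.GraphOptions(rewrite_options=rewriter_cfg))
# `config` can then be passed to a tf.compat.v1.Session, e.g. for calibration.
```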
@@ -442,7 +442,8 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
     results = {}
     for _ in range(elements_to_read):
       val = next(it).numpy()
-      if val not in results: results[val] = 0
+      if val not in results:
+        results[val] = 0
       results[val] += 1
     for i in range(num_elements):
       self.assertGreater(results[i], elements_to_read / num_elements / 2)
@ -527,6 +528,37 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
|
|||||||
data_service_ops.distribute(
|
data_service_ops.distribute(
|
||||||
processing_mode="invalid", service="grpc://localhost:5000"))
|
processing_mode="invalid", service="grpc://localhost:5000"))
|
||||||
|
|
||||||
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
|
def testZipDifferentProcessingModesDatasets(self):
|
||||||
|
cluster = self.create_cluster(num_workers=1)
|
||||||
|
num_elements = 100
|
||||||
|
ds1 = dataset_ops.Dataset.range(num_elements)
|
||||||
|
ds1 = self.make_distributed_dataset(
|
||||||
|
ds1, cluster, processing_mode="distributed_epoch")
|
||||||
|
ds2 = dataset_ops.Dataset.range(num_elements)
|
||||||
|
ds2 = self.make_distributed_dataset(
|
||||||
|
ds2, cluster, processing_mode="parallel_epochs")
|
||||||
|
ds = dataset_ops.Dataset.zip((ds1, ds2))
|
||||||
|
self.assertDatasetProduces(
|
||||||
|
ds,
|
||||||
|
list(zip(range(num_elements), range(num_elements))),
|
||||||
|
assert_items_equal=True)
|
||||||
|
|
||||||
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
|
def testZipDifferentProcessingModesDatasetsSharedJobName(self):
|
||||||
|
cluster = self.create_cluster(num_workers=1)
|
||||||
|
num_elements = 100
|
||||||
|
ds1 = dataset_ops.Dataset.range(num_elements)
|
||||||
|
ds1 = self.make_distributed_dataset(
|
||||||
|
ds1, cluster, processing_mode="distributed_epoch", job_name="job_name")
|
||||||
|
ds2 = dataset_ops.Dataset.range(num_elements)
|
||||||
|
ds2 = self.make_distributed_dataset(
|
||||||
|
ds2, cluster, processing_mode="parallel_epochs", job_name="job_name")
|
||||||
|
ds = dataset_ops.Dataset.zip((ds1, ds2))
|
||||||
|
with self.assertRaisesRegex(errors.FailedPreconditionError,
|
||||||
|
"but there is already an existing job"):
|
||||||
|
self.getDatasetOutput(ds)
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
def testFromDatasetId(self):
|
def testFromDatasetId(self):
|
||||||
cluster = self.create_cluster(num_workers=1)
|
cluster = self.create_cluster(num_workers=1)
|
||||||
|
@@ -88,7 +88,8 @@ class DispatchServer(object):

   >>> dispatcher = tf.data.experimental.service.DispatchServer()
   >>> dispatcher_address = dispatcher.target.split("://")[1]
-  >>> worker = tf.data.experimental.service.WorkerServer(WorkerConfig(
+  >>> worker = tf.data.experimental.service.WorkerServer(
+  ...     tf.data.experimental.service.WorkerConfig(
   ...     dispatcher_address=dispatcher_address))
   >>> dataset = tf.data.Dataset.range(10)
   >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
@@ -169,7 +169,7 @@ class _WorkerContext(object):
   def _get_master_target(self):
     """Return the master target for a task."""
     # If cluster_spec is None or empty, we use local master.
-    if not self._cluster_spec:
+    if not self._cluster_spec or self._task_type == _TaskType.EVALUATOR:
       return ""

     # If task_type is None, then it is in-graph replicated training. In this
@@ -842,7 +842,8 @@ def run_distribute_coordinator(worker_fn,
                                             session_config, cluster_spec,
                                             task_type, task_id)

-  if not getattr(strategy.extended, "_std_server_started", False):
+  if (task_type != _TaskType.EVALUATOR and
+      not getattr(strategy.extended, "_std_server_started", False)):
     # Right now, with eager mode, context is configured with a std server at
     # the very beginning while with graph mode the std server is started when
     # distribute coordinator is called. We should consolidate these two paths.
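Both changes above make the evaluator task behave like a local, single-machine job: it now gets an empty master target and no std server. A small sketch that mirrors the new guard, not the actual implementation (cluster addresses are placeholders):

```python
# Hypothetical cluster layout; only the shape of the dict matters here.
cluster_spec = {
    "chief": ["host0:2222"],
    "worker": ["host1:2222", "host2:2222"],
    "evaluator": ["host3:2222"],  # after this change: local session, no std server
}

def master_target(cluster, task_type):
  """Mirrors the new guard: evaluators fall back to the local master ("")."""
  if not cluster or task_type == "evaluator":
    return ""
  return cluster[task_type][0]

print(master_target(cluster_spec, "evaluator"))  # -> ""
print(master_target(cluster_spec, "worker"))     # -> "host1:2222"
```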
@@ -589,8 +589,7 @@ class DistributeCoordinatorTestStandaloneMode(DistributeCoordinatorTestBase):
     # and distributed_mode.
     self.assertEqual(self._worker_context["None"][0], (_strip_protocol(
         _bytes_to_str(self._workers[0].target)), 3, True, True))
-    self.assertEqual(self._worker_context[EVALUATOR][0],
-                     ("fake_evaluator", 3, True, False))
+    self.assertEqual(self._worker_context[EVALUATOR][0], ("", 3, True, False))


 class DistributeCoordinatorTestIndependentWorkerMode(
@ -755,19 +754,15 @@ class DistributeCoordinatorTestIndependentWorkerMode(
|
|||||||
# and distributed_mode.
|
# and distributed_mode.
|
||||||
self.assertEqual(self._worker_context["None"][0],
|
self.assertEqual(self._worker_context["None"][0],
|
||||||
(_bytes_to_str(cluster_spec[WORKER][0]), 3, True, True))
|
(_bytes_to_str(cluster_spec[WORKER][0]), 3, True, True))
|
||||||
self.assertEqual(self._worker_context[EVALUATOR][0],
|
self.assertEqual(self._worker_context[EVALUATOR][0], ("", 3, True, False))
|
||||||
(cluster_spec[EVALUATOR][0], 3, True, False))
|
|
||||||
|
|
||||||
# Make sure each worker runs a std server.
|
# Make sure each worker runs a std server.
|
||||||
self.assertEqual(len(self._std_servers), 2)
|
self.assertEqual(len(self._std_servers), 1)
|
||||||
self.assertTrue(WORKER in self._std_servers)
|
self.assertTrue(WORKER in self._std_servers)
|
||||||
self.assertTrue(EVALUATOR in self._std_servers)
|
|
||||||
self.assertEqual(len(self._std_servers[WORKER]), 3)
|
self.assertEqual(len(self._std_servers[WORKER]), 3)
|
||||||
self.assertEqual(len(self._std_servers[EVALUATOR]), 1)
|
|
||||||
self.assertFalse(self._std_servers[WORKER][0].joined)
|
self.assertFalse(self._std_servers[WORKER][0].joined)
|
||||||
self.assertTrue(self._std_servers[WORKER][1].joined)
|
self.assertTrue(self._std_servers[WORKER][1].joined)
|
||||||
self.assertTrue(self._std_servers[WORKER][2].joined)
|
self.assertTrue(self._std_servers[WORKER][2].joined)
|
||||||
self.assertFalse(self._std_servers[EVALUATOR][0].joined)
|
|
||||||
|
|
||||||
def testRunStdServerInGoogleEnvironment(self):
|
def testRunStdServerInGoogleEnvironment(self):
|
||||||
cluster_spec = {"worker": ["fake_worker"], "ps": ["localhost:0"]}
|
cluster_spec = {"worker": ["fake_worker"], "ps": ["localhost:0"]}
|
||||||
|
@@ -1007,6 +1007,7 @@ class GradientTape(object):
     Raises:
       RuntimeError: If called on a used, non-persistent tape.
       RuntimeError: If called inside the context of the tape.
+      TypeError: If the target is a None object.
       ValueError: If the target is a variable or if unconnected gradients is
        called with an unknown value.
     """
@@ -1028,6 +1029,11 @@ class GradientTape(object):
                   "gradient in order to compute higher order "
                   "derivatives.", 1)

+    if target is None:
+      raise TypeError("Target should be a list or nested structure"
+                      " of Tensors or Variables to be differentiated,"
+                      " but recieved %r" % (target))
+
     num_ndarrays = 0
     flat_targets = []
     for t in nest.flatten(target):
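A quick illustration of the new guard: calling `tape.gradient` with a `None` target now fails fast with a `TypeError` instead of an opaque error further down (plain TF 2.x eager code):

```python
import tensorflow as tf

x = tf.Variable(3.0)
with tf.GradientTape() as tape:
  y = x * x

try:
  tape.gradient(None, x)  # target is None -> TypeError after this change
except TypeError as e:
  print("Raised as expected:", e)

# The normal path is unchanged.
with tf.GradientTape() as tape2:
  y = x * x
print(tape2.gradient(y, x).numpy())  # -> 6.0
```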
@ -1000,38 +1000,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3)
|
self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3)
|
||||||
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
|
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
|
||||||
|
|
||||||
def test_experimental_get_tracing_count_function(self):
|
|
||||||
|
|
||||||
@def_function.function
|
|
||||||
def double(a):
|
|
||||||
return a + a
|
|
||||||
|
|
||||||
double(constant_op.constant(1))
|
|
||||||
double(constant_op.constant(2))
|
|
||||||
self.assertAllEqual(double.experimental_get_tracing_count(), 1)
|
|
||||||
double(constant_op.constant('a'))
|
|
||||||
self.assertAllEqual(double.experimental_get_tracing_count(), 2)
|
|
||||||
|
|
||||||
def test_experimental_get_tracing_count_method(self):
|
|
||||||
|
|
||||||
class TestClass():
|
|
||||||
|
|
||||||
@def_function.function
|
|
||||||
def testDouble(self, a):
|
|
||||||
return a + a
|
|
||||||
|
|
||||||
obj1 = TestClass()
|
|
||||||
obj1.testDouble(constant_op.constant(1))
|
|
||||||
obj1.testDouble(constant_op.constant(2))
|
|
||||||
obj1.testDouble(constant_op.constant(1.1))
|
|
||||||
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
|
|
||||||
obj2 = TestClass()
|
|
||||||
obj2.testDouble(constant_op.constant(1))
|
|
||||||
obj2.testDouble(constant_op.constant(1.1))
|
|
||||||
obj2.testDouble(constant_op.constant('a'))
|
|
||||||
self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3)
|
|
||||||
self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
|
@@ -2047,6 +2047,10 @@ class TensorFlowTestCase(googletest.TestCase):
     self._tempdir = None
     self._cached_session = None
     self._test_start_time = None
+    # This flag provides the ability to control whether the graph mode gets
+    # initialized for TF1 or not. Initializing for TF1, which is what was
+    # happening earlier, was preventing enablement of 'eager mode' in the test.
+    self._set_default_seed = True

   def setUp(self):
     super(TensorFlowTestCase, self).setUp()
@@ -2061,7 +2065,8 @@ class TensorFlowTestCase(googletest.TestCase):
     # cleared first.
     ops._default_graph_stack.reset()  # pylint: disable=protected-access
     ops.reset_default_graph()
-    random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
+    if self._set_default_seed:
+      random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
     # Reset summary writer in case another test used set_as_default() with their
     # summary writer.
     summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
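A hedged sketch of how a test can opt out of the default graph seed with the new flag. The flag is protected, so this only mirrors the pattern used by TensorFlow's own TF2-behavior test updated later in this change:

```python
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class MyEagerishTest(test_util.TensorFlowTestCase):

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    super().__init__(methodName)
    # Skip random_seed.set_random_seed() in setUp(), which would otherwise
    # touch TF1 graph state before the test body runs.
    self._set_default_seed = False

  def test_something(self):
    self.assertTrue(True)


if __name__ == "__main__":
  test.main()
```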
@@ -18,68 +18,73 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import os
-
 from absl.testing import parameterized
 
 from tensorflow.python import tf2
+from tensorflow.python.compat import v2_compat
+from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.framework import combinations
+from tensorflow.python.platform import _pywrap_tf2
 from tensorflow.python.platform import test
 
 
-def set_environ():
-  os.environ['TF2_BEHAVIOR'] = '1'
-
-
-def unset_environ():
-  os.environ['TF2_BEHAVIOR'] = '0'
-
-
 class EnablingTF2Behavior(test.TestCase, parameterized.TestCase):
 
-  def setUp(self):
-    super(EnablingTF2Behavior, self).setUp()
-    tf2._force_enable = None
-    if 'TF2_BEHAVIOR' in os.environ:
-      del os.environ['TF2_BEHAVIOR']
-
-  actions = [tf2.enable, tf2.disable, set_environ, unset_environ]
-
-  @combinations.generate(
-      combinations.combine(
-          action_0=actions, action_1=actions,
-          action_2=actions, action_3=actions))
-  def test_scenarios(self, action_0, action_1, action_2, action_3):
-
-    def state(action, enabled, disabled):
-      """Returns bool tuple (tf2_enabled, force_enabled, force_disabled)."""
-      if action is tf2.enable:
-        return True, True, False
-      elif action is tf2.disable:
-        return False, False, True
-      elif action is set_environ:
-        return not disabled, enabled, disabled
-      elif action is unset_environ:
-        return enabled, enabled, disabled
-      else:
-        raise ValueError('Unexpected action {}. {} are supported'.format(
-            action, EnablingTF2Behavior.actions))
-
-    action_0()
-    expected, enabled, disabled = state(action_0, False, False)
-    self.assertEqual(tf2.enabled(), expected)
-
-    action_1()
-    expected, enabled, disabled = state(action_1, enabled, disabled)
-    self.assertEqual(tf2.enabled(), expected)
-
-    action_2()
-    expected, enabled, disabled = state(action_2, enabled, disabled)
-    self.assertEqual(tf2.enabled(), expected)
-
-    action_3()
-    expected, enabled, disabled = state(action_3, enabled, disabled)
-    self.assertEqual(tf2.enabled(), expected)
+  def __init__(self, methodName):
+    super().__init__(methodName)
+    self._set_default_seed = False
+
+  @combinations.generate(test_base.v1_only_combinations())
+  def test_tf1_enable_tf2_behaviour(self):
+    self.assertFalse(tf2.enabled())
+    self.assertFalse(_pywrap_tf2.is_enabled())
+
+    v2_compat.enable_v2_behavior()
+    self.assertTrue(tf2.enabled())
+    self.assertTrue(_pywrap_tf2.is_enabled())
+
+    v2_compat.disable_v2_behavior()
+    self.assertFalse(tf2.enabled())
+    self.assertFalse(_pywrap_tf2.is_enabled())
+
+  @combinations.generate(test_base.v1_only_combinations())
+  def test_tf1_disable_tf2_behaviour(self):
+    self.assertFalse(tf2.enabled())
+    self.assertFalse(_pywrap_tf2.is_enabled())
+
+    v2_compat.disable_v2_behavior()
+    self.assertFalse(tf2.enabled())
+    self.assertFalse(_pywrap_tf2.is_enabled())
+
+    v2_compat.enable_v2_behavior()
+    self.assertTrue(tf2.enabled())
+    self.assertTrue(_pywrap_tf2.is_enabled())
+
+  @combinations.generate(test_base.v2_only_combinations())
+  def test_tf2_enable_tf2_behaviour(self):
+    self.assertTrue(tf2.enabled())
+    self.assertTrue(_pywrap_tf2.is_enabled())
+
+    v2_compat.enable_v2_behavior()
+    self.assertTrue(tf2.enabled())
+    self.assertTrue(_pywrap_tf2.is_enabled())
+
+    v2_compat.disable_v2_behavior()
+    self.assertFalse(tf2.enabled())
+    self.assertFalse(_pywrap_tf2.is_enabled())
+
+  @combinations.generate(test_base.v2_only_combinations())
+  def test_tf2_disable_tf2_behaviour(self):
+    self.assertTrue(tf2.enabled())
+    self.assertTrue(_pywrap_tf2.is_enabled())
+
+    v2_compat.disable_v2_behavior()
+    self.assertFalse(tf2.enabled())
+    self.assertFalse(_pywrap_tf2.is_enabled())
+
+    v2_compat.enable_v2_behavior()
+    self.assertTrue(tf2.enabled())
+    self.assertTrue(_pywrap_tf2.is_enabled())
 
 
 if __name__ == '__main__':
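For reference, a standalone sketch of the toggling pattern the rewritten test exercises. It uses only the internal modules shown in the diff above and is not a public API example:

from tensorflow.python import tf2
from tensorflow.python.compat import v2_compat

# Enable TF2 behavior at runtime and confirm the flag flips.
v2_compat.enable_v2_behavior()
assert tf2.enabled()

# Disable it again; tf2.enabled() tracks the current state.
v2_compat.disable_v2_behavior()
assert not tf2.enabled()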
@@ -41,6 +41,7 @@ def index_directory(directory,
     directory: The target directory (string).
     labels: Either "inferred"
         (labels are generated from the directory structure),
+        None (no labels),
         or a list/tuple of integer labels of the same size as the number of
         valid files found in the directory. Labels should be sorted according
         to the alphanumeric order of the image file paths
@@ -61,19 +62,24 @@ def index_directory(directory,
     labels: list of matching integer labels (same length as file_paths)
     class_names: names of the classes corresponding to these labels, in order.
   """
-  inferred_class_names = []
-  for subdir in sorted(os.listdir(directory)):
-    if os.path.isdir(os.path.join(directory, subdir)):
-      inferred_class_names.append(subdir)
-  if not class_names:
-    class_names = inferred_class_names
+  if labels is None:
+    # in the no-label case, index from the parent directory down.
+    subdirs = ['']
+    class_names = subdirs
   else:
-    if set(class_names) != set(inferred_class_names):
-      raise ValueError(
-          'The `class_names` passed did not match the '
-          'names of the subdirectories of the target directory. '
-          'Expected: %s, but received: %s' %
-          (inferred_class_names, class_names))
+    subdirs = []
+    for subdir in sorted(os.listdir(directory)):
+      if os.path.isdir(os.path.join(directory, subdir)):
+        subdirs.append(subdir)
+    if not class_names:
+      class_names = subdirs
+    else:
+      if set(class_names) != set(subdirs):
+        raise ValueError(
+            'The `class_names` passed did not match the '
+            'names of the subdirectories of the target directory. '
+            'Expected: %s, but received: %s' %
+            (subdirs, class_names))
   class_indices = dict(zip(class_names, range(len(class_names))))
 
   # Build an index of the files
@@ -81,7 +87,8 @@ def index_directory(directory,
   pool = multiprocessing.pool.ThreadPool()
   results = []
   filenames = []
-  for dirpath in (os.path.join(directory, subdir) for subdir in class_names):
+
+  for dirpath in (os.path.join(directory, subdir) for subdir in subdirs):
     results.append(
         pool.apply_async(index_subdirectory,
                          (dirpath, class_indices, follow_links, formats)))
@@ -90,7 +97,7 @@ def index_directory(directory,
     partial_filenames, partial_labels = res.get()
     labels_list.append(partial_labels)
     filenames += partial_filenames
-  if labels != 'inferred':
+  if labels not in ('inferred', None):
    if len(labels) != len(filenames):
      raise ValueError('Expected the lengths of `labels` to match the number '
                       'of files in the target directory. len(labels) is %s '
@@ -103,8 +110,11 @@ def index_directory(directory,
       labels[i:i + len(partial_labels)] = partial_labels
       i += len(partial_labels)
 
-  print('Found %d files belonging to %d classes.' %
-        (len(filenames), len(class_names)))
+  if labels is None:
+    print('Found %d files.' % (len(filenames),))
+  else:
+    print('Found %d files belonging to %d classes.' %
+          (len(filenames), len(class_names)))
   pool.close()
   pool.join()
   file_paths = [os.path.join(directory, fname) for fname in filenames]
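The comment above ("in the no-label case, index from the parent directory down") is the crux: with labels=None the only "subdirectory" indexed is the directory itself, so loose files in the parent directory are picked up too. A minimal standalone sketch of that inference logic, using a hypothetical helper name rather than the actual TF function:

import os


def infer_subdirs(directory, labels):
  # Hypothetical helper mirroring the logic above; not the actual TF code path.
  if labels is None:
    # No labels: index from the parent directory down ('' means the directory itself).
    return ['']
  # Labelled case: every immediate subdirectory becomes a class.
  return sorted(
      d for d in os.listdir(directory)
      if os.path.isdir(os.path.join(directory, d)))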
@@ -74,6 +74,7 @@ def image_dataset_from_directory(directory,
       Otherwise, the directory structure is ignored.
     labels: Either "inferred"
         (labels are generated from the directory structure),
+        None (no labels),
         or a list/tuple of integer labels of the same size as the number of
         image files found in the directory. Labels should be sorted according
         to the alphanumeric order of the image file paths
@@ -139,7 +140,7 @@ def image_dataset_from_directory(directory,
     - if `color_mode` is `rgba`,
       there are 4 channel in the image tensors.
   """
-  if labels != 'inferred':
+  if labels not in ('inferred', None):
     if not isinstance(labels, (list, tuple)):
       raise ValueError(
           '`labels` argument should be a list/tuple of integer labels, of '
@@ -156,6 +157,9 @@ def image_dataset_from_directory(directory,
     raise ValueError(
         '`label_mode` argument must be one of "int", "categorical", "binary", '
         'or None. Received: %s' % (label_mode,))
+  if labels is None or label_mode is None:
+    labels = None
+    label_mode = None
   if color_mode == 'rgb':
     num_channels = 3
   elif color_mode == 'rgba':
@@ -188,6 +192,8 @@ def image_dataset_from_directory(directory,
 
   image_paths, labels = dataset_utils.get_training_or_validation_split(
       image_paths, labels, validation_split, subset)
+  if not image_paths:
+    raise ValueError('No images found.')
 
   dataset = paths_and_labels_to_dataset(
       image_paths=image_paths,
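Assuming a TensorFlow build that includes this change, the new no-label path can be used roughly as below; the directory path is a placeholder and `tf.keras.preprocessing.image_dataset_from_directory` is the public entry point for the function modified above:

import tensorflow as tf

# Sketch: load image batches without labels (path is a placeholder).
ds = tf.keras.preprocessing.image_dataset_from_directory(
    '/path/to/images',   # placeholder directory
    labels=None,         # new: no labels; batches contain images only
    label_mode=None,
    image_size=(180, 180),
    batch_size=32)

for images in ds.take(1):
  print(images.shape)    # (batch_size, 180, 180, 3) with color_mode='rgb'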
@@ -82,7 +82,7 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
     # Save images to the paths
     i = 0
     for img in self._get_images(color_mode=color_mode, count=count):
-      path = paths[count % len(paths)]
+      path = paths[i % len(paths)]
       if color_mode == 'rgb':
         ext = 'jpg'
       else:
@@ -92,6 +92,32 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
       i += 1
     return temp_dir
 
+  def test_image_dataset_from_directory_standalone(self):
+    # Test retrieving images without labels from a directory and its subdirs.
+    if PIL is None:
+      return  # Skip test if PIL is not available.
+
+    # Save a few extra images in the parent directory.
+    directory = self._prepare_directory(count=7, num_classes=2)
+    for i, img in enumerate(self._get_images(3)):
+      filename = 'image_%s.jpg' % (i,)
+      img.save(os.path.join(directory, filename))
+
+    dataset = image_dataset.image_dataset_from_directory(
+        directory, batch_size=5, image_size=(18, 18), labels=None)
+    batch = next(iter(dataset))
+    # We return plain images
+    self.assertEqual(batch.shape, (5, 18, 18, 3))
+    self.assertEqual(batch.dtype.name, 'float32')
+    # Count samples
+    batch_count = 0
+    sample_count = 0
+    for batch in dataset:
+      batch_count += 1
+      sample_count += batch.shape[0]
+    self.assertEqual(batch_count, 2)
+    self.assertEqual(sample_count, 10)
+
   def test_image_dataset_from_directory_binary(self):
     if PIL is None:
       return  # Skip test if PIL is not available.
@@ -253,6 +279,11 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
       sample_count += batch.shape[0]
     self.assertEqual(sample_count, 25)
 
+  def test_image_dataset_from_directory_no_images(self):
+    directory = self._prepare_directory(num_classes=2, count=0)
+    with self.assertRaisesRegex(ValueError, 'No images found.'):
+      _ = image_dataset.image_dataset_from_directory(directory)
+
   def test_image_dataset_from_directory_errors(self):
     if PIL is None:
       return  # Skip test if PIL is not available.
@@ -261,7 +292,7 @@ class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
 
     with self.assertRaisesRegex(ValueError, '`labels` argument should be'):
       _ = image_dataset.image_dataset_from_directory(
-          directory, labels=None)
+          directory, labels='other')
 
     with self.assertRaisesRegex(ValueError, '`label_mode` argument must be'):
       _ = image_dataset.image_dataset_from_directory(
@@ -66,6 +66,7 @@ def text_dataset_from_directory(directory,
       Otherwise, the directory structure is ignored.
     labels: Either "inferred"
         (labels are generated from the directory structure),
+        None (no labels),
        or a list/tuple of integer labels of the same size as the number of
        text files found in the directory. Labels should be sorted according
        to the alphanumeric order of the text file paths
@@ -114,7 +115,7 @@ def text_dataset_from_directory(directory,
       of shape `(batch_size, num_classes)`, representing a one-hot
       encoding of the class index.
   """
-  if labels != 'inferred':
+  if labels not in ('inferred', None):
     if not isinstance(labels, (list, tuple)):
       raise ValueError(
           '`labels` argument should be a list/tuple of integer labels, of '
@@ -131,6 +132,9 @@ def text_dataset_from_directory(directory,
     raise ValueError(
         '`label_mode` argument must be one of "int", "categorical", "binary", '
         'or None. Received: %s' % (label_mode,))
+  if labels is None or label_mode is None:
+    labels = None
+    label_mode = None
   dataset_utils.check_validation_split_arg(
       validation_split, subset, shuffle, seed)
 
@@ -152,6 +156,8 @@ def text_dataset_from_directory(directory,
 
   file_paths, labels = dataset_utils.get_training_or_validation_split(
       file_paths, labels, validation_split, subset)
+  if not file_paths:
+    raise ValueError('No text files found.')
 
   dataset = paths_and_labels_to_dataset(
       file_paths=file_paths,
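The text counterpart gains the same no-label path. Assuming a build with this change, usage would look roughly like the sketch below; the directory path is a placeholder and `tf.keras.preprocessing.text_dataset_from_directory` is the public entry point:

import tensorflow as tf

# Sketch: load raw text batches without labels (path is a placeholder).
ds = tf.keras.preprocessing.text_dataset_from_directory(
    '/path/to/texts',   # placeholder directory
    labels=None,        # new: no labels; batches are plain string tensors
    batch_size=32)

for texts in ds.take(1):
  print(texts.shape)    # (batch_size,), dtype=string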
@@ -58,7 +58,7 @@ class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):
       paths += class_paths
 
     for i in range(count):
-      path = paths[count % len(paths)]
+      path = paths[i % len(paths)]
       filename = os.path.join(path, 'text_%s.txt' % (i,))
       f = open(os.path.join(temp_dir, filename), 'w')
       text = ''.join([random.choice(string.printable) for _ in range(length)])
@@ -66,6 +66,32 @@ class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):
       f.close()
     return temp_dir
 
+  def test_text_dataset_from_directory_standalone(self):
+    # Test retrieving txt files without labels from a directory and its subdirs.
+    # Save a few extra files in the parent directory.
+    directory = self._prepare_directory(count=7, num_classes=2)
+    for i in range(3):
+      filename = 'text_%s.txt' % (i,)
+      f = open(os.path.join(directory, filename), 'w')
+      text = ''.join([random.choice(string.printable) for _ in range(20)])
+      f.write(text)
+      f.close()
+
+    dataset = text_dataset.text_dataset_from_directory(
+        directory, batch_size=5, label_mode=None, max_length=10)
+    batch = next(iter(dataset))
+    # We just return the texts, no labels
+    self.assertEqual(batch.shape, (5,))
+    self.assertEqual(batch.dtype.name, 'string')
+    # Count samples
+    batch_count = 0
+    sample_count = 0
+    for batch in dataset:
+      batch_count += 1
+      sample_count += batch.shape[0]
+    self.assertEqual(batch_count, 2)
+    self.assertEqual(sample_count, 10)
+
   def test_text_dataset_from_directory_binary(self):
     directory = self._prepare_directory(num_classes=2)
     dataset = text_dataset.text_dataset_from_directory(
@@ -172,12 +198,17 @@ class TextDatasetFromDirectoryTest(keras_parameterized.TestCase):
       sample_count += batch.shape[0]
     self.assertEqual(sample_count, 25)
 
+  def test_text_dataset_from_directory_no_files(self):
+    directory = self._prepare_directory(num_classes=2, count=0)
+    with self.assertRaisesRegex(ValueError, 'No text files found.'):
+      _ = text_dataset.text_dataset_from_directory(directory)
+
   def test_text_dataset_from_directory_errors(self):
     directory = self._prepare_directory(num_classes=3, count=5)
 
     with self.assertRaisesRegex(ValueError, '`labels` argument should be'):
       _ = text_dataset.text_dataset_from_directory(
-          directory, labels=None)
+          directory, labels='other')
 
     with self.assertRaisesRegex(ValueError, '`label_mode` argument must be'):
       _ = text_dataset.text_dataset_from_directory(
@@ -3246,10 +3246,12 @@ cuda_py_test(
     srcs = ["extract_volume_patches_grad_test.py"],
     shard_count = 50,
     tags = [
+        "no_gpu",  # b/171837334
+        "no_oss",  # Test times out on oss-nightly cpu builds
         "no_pip",
-        "nogpu",  # http://b/171837334
-        "nomac",  # http://b/139946976
-        "notap",  # http://b/31080670
+        "nogpu",  # b/171837334
+        "nomac",  # b/139946976
+        "notap",  # b/31080670
     ],
     deps = [
         "//tensorflow/python:array_ops",
@@ -1557,6 +1557,14 @@ class AssertTypeTest(test.TestCase):
     with self.assertRaisesRegexp(TypeError, "must be of type.*float32"):
       check_ops.assert_type(sparse_float16, dtypes.float32)
 
+  def test_raise_when_tf_type_is_not_dtype(self):
+    # Test case for GitHub issue:
+    # https://github.com/tensorflow/tensorflow/issues/45975
+    value = constant_op.constant(0.0)
+    with self.assertRaisesRegexp(TypeError,
+                                 "Cannot convert.*to a TensorFlow DType"):
+      check_ops.assert_type(value, (dtypes.float32,))
+
 
 class AssertShapesTest(test.TestCase):
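The new test above pins down the error path when a tuple of dtypes is passed where a single dtype is expected. A hedged sketch of the same behavior through the public alias `tf.debugging.assert_type`, assuming a build with this change:

import tensorflow as tf

x = tf.constant(0.0)
tf.debugging.assert_type(x, tf.float32)        # passes: dtype matches

try:
  tf.debugging.assert_type(x, (tf.float32,))   # a tuple is not a single DType
except TypeError as e:
  # Expected to match "Cannot convert ... to a TensorFlow DType" per the test.
  print(e)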