Merge branch 'master' of github.com:tensorflow/tensorflow
* 'master' of github.com:tensorflow/tensorflow: (41 commits)
  [XLA:CPU] Fuse reduce-window
  Minor docstring change: delete "normal".
  [XLA][MLIR] Lower complex operations to std dialect.
  Internal BUILD file change.
  compat: Update forward compatibility horizon to 2020-05-08
  Update GraphDef version to 395.
  Add a kernel generator tool.
  Add i386 to ObjC podspec
  Go: Update generated wrapper functions for TensorFlow ops.
  Enable Keras/RNN model via conversion routes that accept a keras model.
  Go: Update generated wrapper functions for TensorFlow ops.
  Add op sanity checks to the following TFLite ops:
  Internal change
  Fix enable_v2_dtype_behavior() doctest failure.
  CalculateOutputShape for concatenation of BHWDC tensors.
  Try https and see if that works.
  Make regularizers API more consistent.
  Integrate LLVM at https://github.com/llvm/llvm-project/commit/910871532101
  Don't concatenate empty tensors
  [Executor] Optimize `PropagatorState::FindOrCreateChildFrame()`.
  ...
commit 08bc507e4c
@@ -247,7 +247,14 @@ class TFL_TFTypesWithSameBits<int i, int j, int num> :
     Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Quint" # num # "Type>()">,
         CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>;

-class TFL_OperandIsNoneOrHasRankLessThanOrEqualTo<int n, int m> :
+class TFL_TFOperandTypesWithSameBits<int i, int j, int num> :
+  And<[
+    Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # i # ")).isa<mlir::TF::Quint" # num # "Type>()">,
+        CPred<"getElementTypeOrSelf($_op.getOperand(" # i # ")).isUnsignedInteger(" # num # ")">]>,
+    Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Quint" # num # "Type>()">,
+        CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>;
+
+class TFL_OperandIsNoneOrHasRankAtMost<int n, int m> :
   PredOpTrait<"operand " # n # " is at most " # m # "-D",
     Or<[
       CPred<"$_op.getOperand(" # n # ").getType().isa<NoneType>()">,
@@ -255,13 +262,13 @@ class TFL_OperandIsNoneOrHasRankLessThanOrEqualTo<int n, int m> :
       CPred<"$_op.getOperand(" # n #
         ").getType().cast<ShapedType>().getRank() <= " # m>]>>;

-class TFL_OperandHasRankLessThanOrEqualTo<int n, int m> :
+class TFL_OperandHasRankAtMost<int n, int m> :
   PredOpTrait<"operand " # n # " is at most " # m # "-D",
     Or<[TFL_OperandIsUnrankedPred<n>,
         CPred<"$_op.getOperand(" # n #
           ").getType().cast<ShapedType>().getRank() <= " # m>]>>;

-class TFL_OperandHasRankGreaterThanOrEqualTo<int n, int m> :
+class TFL_OperandHasRankAtLeast<int n, int m> :
   PredOpTrait<"operand " # n # " is at least " # m # "-D",
     Or<[TFL_OperandIsUnrankedPred<n>,
         CPred<"$_op.getOperand(" # n #
@@ -300,6 +307,18 @@ class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
           "quant::QuantizedType::castToStorageType("
           "getElementTypeOrSelf($_op.getOperand(" # j # ")))">]>]>]>;

+// This is a quantization-aware version of TCresVTEtIsSameAsOp
+class TFL_TCopVTEtAreSameAt<int i, int j> : Or<[
+  TCopVTEtAreSameAt<[i, j]>,
+  TFL_TFOperandTypesWithSameBits<i, j, 8>,
+  And<[
+    SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(" # j # "))",
+      quant_QuantizedType.predicate>,
+    CPred<"quant::QuantizedType::castToStorageType("
+          "getElementTypeOrSelf($_op.getOperand(" # i # "))) == "
+          "quant::QuantizedType::castToStorageType("
+          "getElementTypeOrSelf($_op.getOperand(" # j # ")))">]>]>;
+
 //===----------------------------------------------------------------------===//
 // TFL op common constraints.
 //===----------------------------------------------------------------------===//
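For illustration, a hypothetical op definition showing how these renamed trait classes compose in an op's trait list (the op name and types below are illustrative, not part of this change):

def TFL_ExampleOp : TFL_Op<"example", [
    NoSideEffect,
    // Operand 0 may have any rank up to 4 (or be unranked).
    TFL_OperandHasRankAtMost<0, 4>,
    // Operand 1 must be at least 1-D (or unranked).
    TFL_OperandHasRankAtLeast<1, 1>,
    // Quantization-aware element-type match between operands 0 and 1:
    // quantized tensors are compared via their storage types.
    PredOpTrait<"lhs and rhs must have compatible element types",
      TFL_TCopVTEtAreSameAt<0, 1>>]> {
  let arguments = (ins AnyTensor:$lhs, AnyTensor:$rhs);
  let results = (outs AnyTensor:$output);
}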
@ -963,7 +982,11 @@ def TFL_ScatterNdOp : TFL_Op<"scatter_nd", [
|
||||
|
||||
// Same type check of lhs and rhs is handled by the ResultsBroadcastableShape trait.
|
||||
def TFL_LessEqualOp : TFL_Op<"less_equal", [
|
||||
ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
|
||||
ResultsBroadcastableShape,
|
||||
BinaryOpSameElementTypeConstraint,
|
||||
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
|
||||
NoSideEffect,
|
||||
NoQuantizableResult]> {
|
||||
let summary = "Less_equal operator";
|
||||
|
||||
let description = [{
|
||||
@ -971,8 +994,8 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$lhs,
|
||||
TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$rhs);
|
||||
ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8]>:$lhs,
|
||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8]>:$rhs);
|
||||
|
||||
let results = (outs TFL_BoolTensor:$output);
|
||||
|
||||
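As a sketch of what the tightened constraints accept (an assumed example, not from this change's tests): both operands must share one of the listed element types and have identical or broadcastable shapes of rank at most 4, while the result stays boolean.

// Both operands are f32 and broadcast-compatible, so the op verifies.
func @less_equal_example(%lhs: tensor<2x3xf32>, %rhs: tensor<3xf32>) -> tensor<2x3xi1> {
  %0 = "tfl.less_equal"(%lhs, %rhs) : (tensor<2x3xf32>, tensor<3xf32>) -> tensor<2x3xi1>
  return %0 : tensor<2x3xi1>
}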
@@ -985,9 +1008,12 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [
   let hasOptions = 0;
 }

-def TFL_LocalResponseNormalizationOp : TFL_Op<"local_response_normalization",
-    [NoSideEffect]> {
-  let summary = "Local Response Normalization.";
+def TFL_LocalResponseNormalizationOp : TFL_Op<"local_response_normalization", [
+    TFL_OperandHasRank<0, 4>,
+    SameOperandsAndResultShape,
+    SameOperandsAndResultType,
+    NoSideEffect]> {
+  let summary = "Local Response Normalization.";

   let description = [{
 The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last
@@ -1004,7 +1030,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
   }];

   let arguments = (ins
-    TFL_TensorOf<[F32, QI8, QUI8]>:$input,
+    TFL_FpTensor:$input,
     I32Attr:$radius,
     F32Attr:$bias,
     F32Attr:$alpha,
@@ -1012,7 +1038,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
   );

   let results = (outs
-    TFL_TensorOf<[F32, QI8, QUI8]>:$output
+    TFL_FpTensor:$output
   );

   let hasOptions = 1;
@@ -1048,7 +1074,7 @@ def TFL_MatrixDiagOp : TFL_Op<"matrix_diag", [
     NoSideEffect,
     TFL_OperandHasAtleastRank<0, 1>,
     PredOpTrait<"operand and result must have the same element type",
-      TCresVTEtIsSameAsOp<0, 0>>]> {
+      TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
   let summary = [{
 Returns a tensor with the provided diagonal and everything else padded with zeros.
   }];
@@ -1061,17 +1087,21 @@ def TFL_MatrixDiagOp : TFL_Op<"matrix_diag", [
   }];

   let arguments = (ins
-    TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$diagonal
+    TFL_TensorOf<[F32, I8, I16, I32, I64, TFL_Uint8, QUI8, QI8, TFL_Quint8]>:$diagonal
   );

   let results = (outs
-    TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$output
+    TFL_TensorOf<[F32, I8, I16, I32, I64, TFL_Uint8, QUI8, QI8, TFL_Quint8]>:$output
   );

   let hasOptions = 0;
 }

-def TFL_MatrixSetDiagOp : TFL_Op<"matrix_set_diag", [NoSideEffect]> {
+def TFL_MatrixSetDiagOp : TFL_Op<"matrix_set_diag", [
+    TFL_OperandHasAtleastRank<0, 2>,
+    PredOpTrait<"input and result must have the same element type",
+      TFL_TCresVTEtIsSameAsOp<0, 0>>,
+    NoSideEffect]> {
   let summary = [{
 Returns a batched matrix tensor with new batched diagonal values.
   }];
@@ -1083,12 +1113,12 @@ innermost matrices. These will be overwritten by the values in `diagonal`.
   }];

   let arguments = (ins
-    TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$input,
-    TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$diagonal
+    TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$input,
+    TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$diagonal
   );

   let results = (outs
-    TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$output
+    TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$result
   );

   let hasOptions = 0;
@@ -1206,7 +1236,12 @@ larger than 0.
 }

 def TFL_NotEqualOp : TFL_Op<"not_equal", [
-    ResultsBroadcastableShape, Commutative, NoSideEffect, NoQuantizableResult]> {
+    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
+    BinaryOpSameElementTypeConstraint,
+    ResultsBroadcastableShape,
+    Commutative,
+    NoSideEffect,
+    NoQuantizableResult]> {
   let summary = "Not_equal operator";

   let description = [{
@@ -1214,8 +1249,8 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [
   }];

   let arguments = (
-    ins AnyTensor:$lhs,
-    AnyTensor:$rhs);
+    ins TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, TFL_Quint8, TFL_Str]>:$lhs,
+    TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, TFL_Quint8, TFL_Str]>:$rhs);

   let results = (outs TFL_BoolTensor:$output);

@@ -1284,7 +1319,7 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup",
     PredOpTrait<"value and output must have same element type",
       TFL_TCresVTEtIsSameAsOp<0, 1>>,
     TFL_OperandHasRank<0, 1>,
-    TFL_OperandHasRankGreaterThanOrEqualTo<1, 2>
+    TFL_OperandHasRankAtLeast<1, 2>
   ]> {
   let summary = "Embedding lookup operator";

@@ -1502,7 +1537,11 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [
 }

 def TFL_GreaterOp : TFL_Op<"greater", [
-    ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
+    ResultsBroadcastableShape,
+    BinaryOpSameElementTypeConstraint,
+    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
+    NoSideEffect,
+    NoQuantizableResult]> {
   let summary = "Greater operator";

   let description = [{
@@ -1510,10 +1549,10 @@ def TFL_GreaterOp : TFL_Op<"greater", [
   }];

   let arguments = (
-    ins AnyTensor:$lhs,
-    AnyTensor:$rhs);
+    ins TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$lhs,
+    TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$rhs);

-  let results = (outs AnyTensor:$output);
+  let results = (outs TFL_BoolTensor:$output);

   let builders = [TFL_ComparisonBinaryBuilder];

@@ -1523,8 +1562,9 @@ def TFL_GreaterOp : TFL_Op<"greater", [
 }

 def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect,
-    SameOperandsAndResultShape,
-    TFL_GpuTargetOp]> {
+    SameOperandsAndResultShape,
+    SameOperandsAndResultType,
+    TFL_GpuTargetOp]> {
   let summary = "Hardswish activation function.";
   let description = [{
     Computes hard-swish activation function
@@ -1563,29 +1603,34 @@ def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect,
   let customOption = "L2NormOptions";
 }

-def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultType]> {
+def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [
+    SameOperandsAndResultShape,
+    NoSideEffect,
+    SameOperandsAndResultType]> {
   let summary = "Leaky Relu operator";

-  // TODO(jpienaar): Add type restriction. This op is only defined for
-  // restricted (floating point) types.
   let description = [{
     Element-wise Leaky ReLU operator
       x -> x >= 0 ? x : (alpha * x)
   }];

   let arguments = (
-    ins AnyTensor:$input,
+    ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$input,
     // Slope of the activation function at x < 0.
     F32Attr:$alpha
   );

-  let results = (outs AnyTensor:$output);
+  let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output);

   let hasOptions = 0b1;
 }

 def TFL_LessOp : TFL_Op<"less", [
-    ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
+    ResultsBroadcastableShape,
+    BinaryOpSameElementTypeConstraint,
+    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
+    NoSideEffect,
+    NoQuantizableResult]> {
   let summary = "Less operator";

   let description = [{
@@ -1593,8 +1638,8 @@ def TFL_LessOp : TFL_Op<"less", [
   }];

   let arguments = (
-    ins AnyTensor:$lhs,
-    AnyTensor:$rhs);
+    ins TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$lhs,
+    TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$rhs);

   let results = (outs TFL_BoolTensor:$output);

@@ -1655,6 +1700,8 @@ def TFL_LogicalOrOp : TFL_Op<"logical_or", [NoSideEffect]> {

 def TFL_LogisticOp: TFL_Op<"logistic", [
     NoSideEffect,
+    PredOpTrait<"x and y must have same element type",
+      TFL_TCresVTEtIsSameAsOp<0, 0>>,
     SameOperandsAndResultShape,
     // zero_point = 0
     // scale = 1. / (max_value + 1)
@@ -1667,9 +1714,9 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
     Computes element-wise Sigmoid of input
   }];

-  let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$x);
+  let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$x);

-  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$y);
+  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$y);
 }

 def TFL_LogOp: TFL_Op<"log", [
@@ -1690,10 +1737,10 @@ def TFL_LogOp: TFL_Op<"log", [
   let hasFolder = 1;
 }

-// TODO(b/130643170): Adds some constraint for the input/output element types.
 def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [
     NoSideEffect,
     SameOperandsAndResultShape,
+    SameOperandsAndResultType,
     // zero_point = max_value
     // scale = -log_softmax_output_min / (max_value + 1)
     FixedResultScale<Int8UniformQuantizedType<127, 625, -4>>,
@@ -1706,9 +1753,9 @@ def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [
     input - log(reduce_sum(exp(input), dim))
   }];

-  let arguments = (ins AnyTensor:$input);
+  let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$input);

-  let results = (outs AnyTensor:$output);
+  let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output);

   let hasOptions = 1;
 }
@@ -1727,6 +1774,9 @@ def MaxPoolOperandAndResultConstraints : PredOpTrait<"MaxPool2D operand and "
       TFL_TCresVTEtIsSameAsOp<0, 0>]>>;

 def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
+    TFL_OperandHasRank<0, 4>,
+    PredOpTrait<"input and output must have same element type",
+      TFL_TCresVTEtIsSameAsOp<0, 0>>,
     NoSideEffect,
     MaxPoolOperandAndResultConstraints,
     SameOperandsAndResultsScale,
@@ -1741,7 +1791,7 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
   }];

   let arguments = (
-    ins AnyTensor:$input,
+    ins TFL_TensorOf<[F32, QUI8, QI8, QI16, TFL_Quint8]>:$input,
     TFL_PaddingAttr:$padding,
     I32Attr:$stride_w,
     I32Attr:$stride_h,
@@ -1750,7 +1800,7 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
     TFL_AFAttr:$fused_activation_function
   );

-  let results = (outs AnyTensor:$output);
+  let results = (outs TFL_TensorOf<[F32, QUI8, QI8, QI16, TFL_Quint8]>:$output);

   let hasOptions = 1;

@@ -1782,7 +1832,11 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
   let hasOptions = 0;
 }

-def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, TFL_GpuTargetOp]> {
+def TFL_MeanOp : TFL_Op<"mean", [
+    PredOpTrait<"input and output must have same element type",
+      TFL_TCresVTEtIsSameAsOp<0, 0>>,
+    NoSideEffect,
+    TFL_GpuTargetOp]> {
   let summary = "Mean operator";

   let description = [{
@@ -1794,13 +1848,13 @@ def TFL_MeanOp : TFL_Op<"mean", [
   }];

   let arguments = (ins
-    TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8, TFL_Uint8]>:$input,
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Uint8]>:$input,
     TFL_TensorOf<[I32, I64]>:$axis,
     BoolAttr:$keep_dims
   );

   let results = (outs
-    TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$output);
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Uint8]>:$output);

   let hasOptions = 1;
   let customOption = "ReducerOptions";
@@ -1821,14 +1875,14 @@ def TFL_OneHotOp : TFL_Op<"one_hot", [NoSideEffect]> {
   let arguments = (ins
     TFL_TensorOf<[I32, I64]>:$indices,
     TFL_I32Tensor:$depth,
-    TFL_TensorOf<[F32, I32, I64, I1]>:$on_value,
-    TFL_TensorOf<[F32, I32, I64, I1]>:$off_value,
+    TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$on_value,
+    TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$off_value,

     I32Attr:$axis
   );

   let results = (outs
-    TFL_TensorOf<[F32, I32, I64, I1]>:$output
+    TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$output
   );

   let hasOptions = 1;
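A sketch of IR the widened one_hot signature now admits (a hypothetical example, not part of this change's tests): ui8 on/off values yield a ui8 result.

func @one_hot_ui8(%indices: tensor<3xi32>, %depth: tensor<i32>, %on: tensor<ui8>, %off: tensor<ui8>) -> tensor<*xui8> {
  %0 = "tfl.one_hot"(%indices, %depth, %on, %off) {axis = -1 : i32} : (tensor<3xi32>, tensor<i32>, tensor<ui8>, tensor<ui8>) -> tensor<*xui8>
  return %0 : tensor<*xui8>
}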
@@ -2032,7 +2086,11 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
   let hasFolder = 1;
 }

-def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
+def TFL_PackOp : TFL_Op<"pack", [
+    PredOpTrait<"values and output must have same element type",
+      TFL_TCresVTEtIsSameAsOp<0, 0>>,
+    NoSideEffect,
+    SameOperandsAndResultsScale]> {
   let summary = "Packs a list of tensors along a dimension into one tensor";

   let description = [{
@@ -2063,14 +2121,14 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
   }];

   let arguments = (ins
-    TFL_VariadicTensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$values,
+    TFL_VariadicTensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QUI8, QI16, TFL_Quint8]>:$values,

-    I32Attr:$values_count,
+    Confined<I32Attr, [IntPositive]>:$values_count,
     I32Attr:$axis
   );

   let results = (outs
-    TFL_TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$output
+    TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QUI8, QI16, TFL_Quint8]>:$output
   );

   let verifier = [{ return Verify(*this); }];
@@ -2081,8 +2139,11 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
 }

 def TFL_PadOp : TFL_Op<"pad", [
+    PredOpTrait<"input and output must have same element type",
+      TFL_TCresVTEtIsSameAsOp<0, 0>>,
     NoSideEffect,
     SameOperandsAndResultsScale,
+    TFL_OperandHasRankAtMost<0, 4>,
     TFL_OperandHasRank<1, 2>,
     TFL_OperandRankEquals1DimOfOperand<0, 1>,
     TFL_GpuTargetOp]> {
@@ -2113,22 +2174,25 @@ def TFL_PadOp : TFL_Op<"pad", [
     ```
   }];

-  let arguments = (ins TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input,
+  let arguments = (ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
     TFL_I32OrI64Tensor:$padding);

-  let results = (outs TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);

   let hasOptions = 1;
 }

 def TFL_PadV2Op : TFL_Op<"padv2", [
+    PredOpTrait<"input and output must have same element type",
+      TFL_TCresVTEtIsSameAsOp<0, 0>>,
     NoSideEffect,
     SameOperandsAndResultsScale,
+    TFL_OperandHasRankAtMost<0, 4>,
     TFL_OperandHasRank<1, 2>,
     TFL_OperandHasRank<2, 0>,
     TFL_OperandRankEquals1DimOfOperand<0, 1>,
     PredOpTrait<"input and constant value operands must have same element type",
-      TCopVTEtAreSameAt<[0, 2]>>]> {
+      TFL_TCopVTEtAreSameAt<0, 2>>]> {
   let summary = "Padding operator v2";

   let description = [{
@@ -2159,11 +2223,11 @@ def TFL_PadV2Op : TFL_Op<"padv2", [
   }];

   let arguments = (
-    ins TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input,
+    ins TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, TFL_Quint8]>:$input,
     TFL_I32OrI64Tensor:$padding,
-    TFL_TensorOf<[F32, I8, I32, I64]>:$constant_values);
+    TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, TFL_Quint8]>:$constant_values);

-  let results = (outs TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, TFL_Quint8]>:$output);

   let hasOptions = 1;
 }
@@ -2191,9 +2255,22 @@ def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape,
   let builders = [TFL_BroadcastableBinaryBuilder];
 }

-def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect,
-    TFL_GpuTargetOp,
-    SameOperandsAndResultsScale]> {
+def TFL_PReluOp : TFL_Op<"prelu", [
+    NoSideEffect,
+    ResultsBroadcastableShape,
+    TFL_GpuTargetOp,
+    TFL_OperandHasRankAtMost<0, 4>,
+    TFL_OperandHasRankAtMost<1, 4>,
+    BinaryOpSameElementTypeConstraint,
+    PredOpTrait<"input and output must have the same element type",
+      TFL_TCresVTEtIsSameAsOp<0, 0>>,
+    PredOpTrait<"'alpha' should have one less rank than 'input'.",
+      Or<[TFL_OperandIsUnrankedPred<0>,
+          TFL_OperandIsUnrankedPred<1>,
+          CPred<"$_op.getOperand(0).getType().cast<ShapedType>().getRank() == "
+                "$_op.getOperand(1).getType().cast<ShapedType>().getRank() "
+                "+ 1">]>>,
+    SameOperandsAndResultsScale]> {
   let summary = "Parameterized Relu operator";

   let description = [{
@@ -2206,11 +2283,11 @@ def TFL_PReluOp : TFL_Op<"prelu", [
   }];

   let arguments = (
-    ins TFL_TensorOf<[F32, QUI8]>:$input,
-    TFL_TensorOf<[F32, QUI8]>:$alpha
+    ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input,
+    TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$alpha
   );

-  let results = (outs TFL_TensorOf<[F32, QUI8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output);

   let verifier = [{ return Verify(*this); }];
 }
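For contrast with the negative tests later in this diff, a sketch of IR that satisfies the new PRelu constraints (an assumed example): 'alpha' has rank 3, exactly one less than the rank-4 'input', and is broadcast-compatible with it.

func @prelu_valid(%input: tensor<1x2x3x4xf32>, %alpha: tensor<2x3x4xf32>) -> tensor<1x2x3x4xf32> {
  %0 = "tfl.prelu"(%input, %alpha) : (tensor<1x2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<1x2x3x4xf32>
  return %0 : tensor<1x2x3x4xf32>
}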
@@ -2887,7 +2964,7 @@ def TFL_DepthToSpaceOp: TFL_Op<"depth_to_space", [
     SameOperandsAndResultsScale,
     PredOpTrait<"input and output must have same element type",
       TFL_TCresVTEtIsSameAsOp<0, 0>>,
-    TFL_OperandHasRankLessThanOrEqualTo<0, 4>
+    TFL_OperandHasRankAtMost<0, 4>
   ]> {
   let summary = "DepthToSpace operator";

@@ -3224,7 +3301,7 @@ def TFL_QConstOp : Op<TFL_Dialect, "pseudo_qconst", [
     ElementsAttr:$value
   );

-  let results = (outs AnyTensor:$output);
+  let results = (outs TFL_TensorOf<[QUI8, QI8, QI16, QUI16, TFL_Quint8]>:$output);

   let builders = [OpBuilder<
     "OpBuilder &, OperationState &state, TypeAttr qtype, Attribute value",
@@ -3849,7 +3926,7 @@ def TFL_NumericVerifyOp : Op<TFL_Dialect, "NumericVerify", [
   }];

   let arguments = (ins
-    TFL_TensorOf<[QI8, QUI8, QI16, QUI16]>:$input,
+    TFL_TensorOf<[QI8, QUI8, QI16, F16, TFL_Quint8]>:$input,
     TFL_TensorOf<[F32]>:$ref,

     // Attributes
@@ -573,7 +573,7 @@ func @testLogistic(tensor<1x2x3x4x5xf32>) -> tensor<1x2x3x4x5xf32> {
 // test invalid Logistic input
 func @testLogisticWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
 ^bb0(%arg0: tensor<?xi32>):
-  // expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of 32-bit float or QI8 type or QUI8 type or QI16 type or QUI16 type values}}
+  // expected-error @+1 {{'tfl.logistic' op operand #0 must be tensor of 32-bit float or QI8 type or QUI8 type or QI16 type or TFLite quint8 type values, but got 'tensor<?xi32>'}}
   %0 = "tfl.logistic"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
   return %0#0 : tensor<?xi32>
 }
@@ -1252,10 +1252,10 @@ func @testOneHot(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<f32>, %

 // -----

-func @testOneHotWithInvalidOutputType(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<*xi8> {
-  // expected-error @+1 {{'tfl.one_hot' op result #0 must be tensor of 32-bit float or 32-bit signless integer or 64-bit signless integer or 1-bit signless integer values}}
-  %0 = "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<*xi8>
-  return %0 : tensor<*xi8>
+func @testOneHotWithInvalidOutputType(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<*xi16> {
+  // expected-error @+1 {{'tfl.one_hot' op result #0 must be tensor of 32-bit float or 32-bit signless integer or 64-bit signless integer or 1-bit signless integer or 8-bit signless integer or 8-bit unsigned integer values, but got 'tensor<*xi16>'}}
+  %0 = "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<*xi16>
+  return %0 : tensor<*xi16>
 }

 // -----
@@ -1489,7 +1489,8 @@ func @testEmbeddingLookupValueAndResultElementTypeTraitFailed(%arg0 : tensor<?xi

 // -----

-func @testQuantizedLocalResponseNormalization(%arg0 : tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>) -> tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>> {
+func @testWrongQuantizedLocalResponseNormalization(%arg0 : tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>) -> tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>> {
+  // expected-error @+1 {{'tfl.local_response_normalization' op operand #0 must be tensor of 32-bit float values, but got 'tensor<1x56x56x192x!quant.uniform<u8:f32, 2.000000e-02>>'}}
   %0 = "tfl.local_response_normalization"(%arg0) {alpha = 9.99999974E-5 : f32, beta = 5.000000e-01 : f32, bias = 2.000000e+00 : f32, radius = 5 : i32} : (tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>) -> tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>
   return %0 : tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>
 }
@@ -1523,32 +1524,32 @@ func @testDepthToSpaceInvalidOutputType(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x

 // -----

-func @testPReluWrongOutputRank(%arg0: tensor<10x10x10x10xf32>, %arg1: tensor<1x1x10xf32>) -> tensor<10x10x10xf32> {
-  // expected-error @+1 {{'input' and 'output' should have the same rank}}
-  %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<10x10x10x10xf32>, tensor<1x1x10xf32>) -> tensor<10x10x10xf32>
-  return %0 : tensor<10x10x10xf32>
+func @testPReluWrongOutputRank(%arg0: tensor<10x10x10x10xf32>, %arg1: tensor<10x10x10x10xf32>) -> tensor<10x10xf32> {
+  // expected-error @+1 {{'tfl.prelu' op result type '10x10' not broadcast compatible with broadcasted operands's shapes '10x10x10x10'}}
+  %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<10x10x10x10xf32>, tensor<10x10x10x10xf32>) -> tensor<10x10xf32>
+  return %0 : tensor<10x10xf32>
 }

 // -----

 func @testPReluWrongOutputShape(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<1x2x3x5xf32> {
-  // expected-error @+1 {{'input' and 'output' should have the same shape}}
+  // expected-error @+1 {{'tfl.prelu' op result type '1x2x3x5' not broadcast compatible with broadcasted operands's shapes '1x2x3x4'}}
   %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<1x2x3x5xf32>
   return %0 : tensor<1x2x3x5xf32>
 }

 // -----

-func @testPReluWrongAlphaRank(%arg0: tensor<7x3x2x14xf32>, %arg1: tensor<2x7x3x2x14xf32>) -> tensor<7x3x2x14xf32> {
+func @testPReluWrongAlphaRank(%arg0: tensor<7x3x2x14xf32>, %arg1: tensor<7x3x2x14xf32>) -> tensor<7x3x2x14xf32> {
   // expected-error @+1 {{'alpha' should have one less rank than 'input'.}}
-  %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<7x3x2x14xf32>, tensor<2x7x3x2x14xf32>) -> tensor<7x3x2x14xf32>
+  %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<7x3x2x14xf32>, tensor<7x3x2x14xf32>) -> tensor<7x3x2x14xf32>
   return %0 : tensor<7x3x2x14xf32>
 }

 // -----

 func @testPReluInvalidBroadcast(%arg0: tensor<15x14x2x14xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<15x14x2x14xf32> {
-  // expected-error @+1 {{'alpha' is not broadcastable at dimension 2.}}
+  // expected-error @+1 {{'tfl.prelu' op operands don't have broadcast-compatible shapes}}
   %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<15x14x2x14xf32>, tensor<1x1x3xf32>) -> tensor<15x14x2x14xf32>
   return %0 : tensor<15x14x2x14xf32>
 }
@@ -12,6 +12,22 @@ cc_library(
         "//tensorflow/c:tf_status_helper",
         "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
         "//tensorflow/compiler/mlir/tensorflow:error_util",
+        # (yongtang) The graph_optimization_pass_registration needs to be part
+        # of a shared object that will be loaded whenever `import tensorflow`
+        # is run. The natural place is libtensorflow_framework.so.
+        # While adding graph_optimization_pass_registration to
+        # libtensorflow_framework.so is possible with some modification in
+        # dependency, many tests will fail due to multiple copies of LLVM.
+        # See https://github.com/tensorflow/tensorflow/pull/39231 for details.
+        # Alternatively, we place graph_optimization_pass_registration here
+        # because:
+        #   - tensorflow/python/_pywrap_mlir.so already depends on LLVM anyway
+        #   - tensorflow/python/_pywrap_mlir.so always loaded as part of python
+        #     binding
+        # TODO: It might be still preferable to place graph_optimization_pass
+        # as part of the libtensorflow_framework.so, as it is the central
+        # place for core related components.
+        "//tensorflow/compiler/mlir/tensorflow:graph_optimization_pass_registration",
        "//tensorflow/compiler/mlir/tensorflow:import_utils",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
@@ -1217,7 +1217,7 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }

-def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect]> {
   let summary = "Clips tensor values to a specified min and max.";

   let description = [{
@@ -1682,6 +1682,27 @@ Given an input tensor, this function computes hyperbolic cosine of every
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }

+def TF_CrossOp : TF_Op<"Cross", [NoSideEffect]> {
+  let summary = "Compute the pairwise cross product.";
+
+  let description = [{
+`a` and `b` must be the same shape; they can either be simple 3-element vectors,
+or any shape where the innermost dimension is 3. In the latter case, each pair
+of corresponding 3-element vectors is cross-multiplied independently.
+  }];
+
+  let arguments = (ins
+    TF_IntOrFpTensor:$a,
+    TF_IntOrFpTensor:$b
+  );
+
+  let results = (outs
+    TF_IntOrFpTensor:$product
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [AllTypesMatch<["input", "output"]>, NoSideEffect]> {
   let summary = "An Op to sum inputs across replicated TPU instances.";

@@ -447,4 +447,27 @@ std::string GetDeviceAliasForLogicalCore(int core_index) {
   return llvm::formatv("{0}_{1}", kTPUReplicatedCore, core_index).str();
 }

+StatusOr<std::string> GetCPUHostForTPUDevice(llvm::StringRef tpu_device) {
+  Device device;
+  if (!DeviceNameUtils::ParseFullName(tpu_device.str(), &device))
+    return errors::InvalidArgument("'", tpu_device.str(),
+                                   "' is not a valid device");
+
+  device.type = DEVICE_CPU;
+  device.id = 0;
+  return DeviceNameUtils::ParsedNameToString(device);
+}
+
+StatusOr<llvm::SmallVector<std::string, 8>> GetCPUHostsForTPUDevices(
+    llvm::ArrayRef<std::string> tpu_devices) {
+  llvm::SmallVector<std::string, 8> cpu_devices;
+  cpu_devices.reserve(tpu_devices.size());
+  for (const auto& tpu_device : tpu_devices) {
+    TF_ASSIGN_OR_RETURN(cpu_devices.emplace_back(),
+                        GetCPUHostForTPUDevice(tpu_device));
+  }
+
+  return cpu_devices;
+}
+
 } // namespace tensorflow
@@ -216,6 +216,17 @@ StatusOr<TPUDeviceAssignment> GetTPUCompilationAndExecutionDevices(
 // logical core.
 std::string GetDeviceAliasForLogicalCore(int core_index);

+// Finds associated CPU host device for given TPU device. This assumes a
+// matching CPU host device exists based on TPU device name. An error will be
+// returned if the TPU device name is invalid.
+StatusOr<std::string> GetCPUHostForTPUDevice(llvm::StringRef tpu_device);
+
+// Finds associated CPU host devices for given TPU devices. This assumes a
+// matching CPU host device exists based on each TPU device name. An error will
+// be returned if a TPU device name is invalid.
+StatusOr<llvm::SmallVector<std::string, 8>> GetCPUHostsForTPUDevices(
+    llvm::ArrayRef<std::string> tpu_devices);
+
 } // namespace tensorflow

 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_REWRITE_DEVICE_UTIL_H_
@@ -552,5 +552,44 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) {
   EXPECT_EQ(computation_device_2.replica_device_ids(1), 3);
 }

+struct ParameterizedCPUHostForTPUDeviceTest
+    : ::testing::TestWithParam<std::tuple<std::string, std::string>> {};
+
+TEST_P(ParameterizedCPUHostForTPUDeviceTest, CPUHostForTPUDevice) {
+  auto status_or_device = GetCPUHostForTPUDevice(std::get<0>(GetParam()));
+  TF_ASSERT_OK(status_or_device.status());
+  EXPECT_EQ(status_or_device.ValueOrDie(), std::get<1>(GetParam()));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    CPUHostForTPUDevice, ParameterizedCPUHostForTPUDeviceTest,
+    ::testing::Values(
+        std::make_tuple("/job:worker/replica:0/task:0/device:TPU:0",
+                        "/job:worker/replica:0/task:0/device:CPU:0"),
+        std::make_tuple("/job:worker/replica:0/task:1/device:TPU:1",
+                        "/job:worker/replica:0/task:1/device:CPU:0")));
+
+TEST(TPURewriteDeviceUtilTest, CPUHostForTPUDeviceInvalidDevice) {
+  auto status_or_device = GetCPUHostForTPUDevice("bad_device");
+  ASSERT_FALSE(status_or_device.ok());
+}
+
+TEST(TPURewriteDeviceUtilTest, CPUHostsForTPUDevices) {
+  auto status_or_devices =
+      GetCPUHostsForTPUDevices({"/job:worker/replica:0/task:0/device:TPU:0",
+                                "/job:worker/replica:0/task:1/device:TPU:1"});
+  TF_ASSERT_OK(status_or_devices.status());
+  const auto& devices = status_or_devices.ValueOrDie();
+  ASSERT_EQ(devices.size(), 2);
+  EXPECT_EQ(devices[0], "/job:worker/replica:0/task:0/device:CPU:0");
+  EXPECT_EQ(devices[1], "/job:worker/replica:0/task:1/device:CPU:0");
+}
+
+TEST(TPURewriteDeviceUtilTest, CPUHostsForTPUDevicesInvalidDevice) {
+  auto status_or_devices = GetCPUHostsForTPUDevices(
+      {"/job:worker/replica:0/task:0/device:TPU:0", "bad_device"});
+  ASSERT_FALSE(status_or_devices.ok());
+}
+
 } // anonymous namespace
 } // namespace tensorflow
tensorflow/compiler/mlir/tools/kernel_gen/BUILD (new file, 49 lines)

load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")

licenses(["notice"])

cc_library(
    name = "cubin_creator",
    srcs = ["cubin_creator.cc"],
    hdrs = ["cubin_creator.h"],
    copts = if_cuda(["-DGOOGLE_CUDA=1"]),
    deps = [
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:GPUDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LLVMDialect",
        "@llvm-project//mlir:Parser",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:TargetNVVMIR",
        "@llvm-project//mlir:Transforms",
        "//tensorflow/compiler/mlir/xla:hlo",
        "//tensorflow/compiler/mlir/xla:lhlo",
        "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
        "//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",  # buildcleaner: keep
        "//tensorflow/compiler/mlir/xla:xla_unfuse_batch_norm",  # buildcleaner: keep
        "//tensorflow/compiler/xla:debug_options_flags",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/service/gpu:stream_executor_util",
        "//tensorflow/compiler/xla/service/gpu:target_constants",
        "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
        "//tensorflow/compiler/xla/service/mlir_gpu:kernel_lowering",
        "//tensorflow/core:cuda_libdevice_path",
        "//tensorflow/core:lib",
    ] + if_cuda(["//tensorflow/stream_executor/gpu:asm_compiler"]),
)

tf_cc_binary(
    name = "tf_to_cubin",
    srcs = ["tf_to_cubin.cc"],
    deps = [
        ":cubin_creator",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
    ],
)
tensorflow/compiler/mlir/tools/kernel_gen/build_defs.bzl (new file, 96 lines)

load("//third_party/gpus/cuda:build_defs.bzl", "cuda_gpu_select_list")

def _lookup_file(filegroup, path):
    """Extracts file at (relative) path in filegroup."""
    for file in filegroup.files.to_list():
        if file.path.endswith(path):
            return file
    return None

def _gen_kernel_image_hdr_impl(ctx):
    if not ctx.attr.gpu_archs:
        fail("No GPU architecture specified, use --config=cuda or similar.")

    name = ctx.attr.name
    tile_sizes = ctx.attr.tile_size.replace("x", ",")
    same_shape = []
    if ctx.attr.same_shape:
        same_shape.append("--same_shape=%s" % ctx.attr.same_shape)

    cubins = []
    images = []
    for arch in ctx.attr.gpu_archs:
        filename = "%s.%s.cubin" % (name, arch)
        cubin = ctx.actions.declare_file(filename)
        ctx.actions.run(
            outputs = [cubin],
            executable = ctx.executable._tool,
            arguments = same_shape + [
                "--tile_sizes=%s" % tile_sizes,
                "--arch=%s" % arch.split("_")[1],
                "--output=%s" % cubin.path,
                ctx.attr.op,
            ],
            mnemonic = "compile",
        )
        cubins.append(cubin)
        images.append("--image=profile=%s,file=%s" % (arch, cubin.path))

    # Generate fatbin file from all cubins.
    fatbin = ctx.actions.declare_file("%s.fatbin" % name)
    ctx.actions.run(
        outputs = [fatbin],
        inputs = cubins,
        executable = _lookup_file(ctx.attr._cuda_root, "bin/fatbinary"),
        arguments = [
            "--64",
            "--cmdline=--compile-only",
            "--link",
            "--compress-all",
            "--create=%s" % fatbin.path,
        ] + images,
        mnemonic = "fatbinary",
    )

    bin2c = _lookup_file(ctx.attr._cuda_root, "bin/bin2c")
    ctx.actions.run_shell(
        outputs = [ctx.outputs.out],
        inputs = [fatbin],
        tools = [bin2c],
        command = "%s --static --const --type=int --name=%s %s 1> %s" %
                  (bin2c.path, ctx.attr.symbol, fatbin.path, ctx.outputs.out.path),
        mnemonic = "bin2c",
    )

_gen_kernel_image_hdr = rule(
    implementation = _gen_kernel_image_hdr_impl,
    output_to_genfiles = True,
    attrs = {
        "op": attr.string(mandatory = True),
        "tile_size": attr.string(mandatory = True),
        "same_shape": attr.string(),
        "out": attr.output(mandatory = True),
        "symbol": attr.string(mandatory = True),
        "gpu_archs": attr.string_list(mandatory = True),
        "_cuda_root": attr.label(
            default = Label("//third_party/gpus/cuda:cuda_root"),
        ),
        "_tool": attr.label(
            executable = True,
            default = Label("//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_cubin"),
            cfg = "host",
        ),
    },
)

def gen_kernel_image_hdr(name, op, tile_size, same_shape = None):
    """Generates a C header with fatbin data from a Tensorflow op."""
    _gen_kernel_image_hdr(
        name = name,
        op = op,
        tile_size = tile_size,
        same_shape = same_shape,
        out = "include/tfrt/gpu/ops/tf/%s.h" % name,
        symbol = "k%s" % name.replace("_", " ").title().replace(" ", ""),
        gpu_archs = cuda_gpu_select_list("sm_{}"),
    )
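A hypothetical BUILD usage of this macro (the target name, op string, and tile size below are illustrative, not part of this change):

load("//tensorflow/compiler/mlir/tools/kernel_gen:build_defs.bzl", "gen_kernel_image_hdr")

# Would emit include/tfrt/gpu/ops/tf/tanh_f32_kernel.h containing a fatbin
# array named kTanhF32Kernel, built for the configured GPU architectures.
gen_kernel_image_hdr(
    name = "tanh_f32_kernel",
    op = "func @tanh(%arg0: tensor<?xf32>) -> tensor<?xf32> { ... }",
    tile_size = "256",
)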
tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc (new file, 264 lines)

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

//===- cubin_creator.cc -----------------------------------------*- C++ -*-===//
//
// This file implements the function to compile a TF kernel function to a cubin.
//
//===----------------------------------------------------------------------===//
#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/escaping.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/GPU/GPUDialect.h"  // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Parser.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Target/NVVMIR.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h"
#include "tensorflow/core/platform/cuda_libdevice_path.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/path.h"
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/gpu/asm_compiler.h"
#endif

namespace {
using tensorflow::Status;
using xla::InternalError;
using xla::StatusOr;

StatusOr<std::string> GetLibdeviceDir(
    const xla::HloModuleConfig& hlo_module_config) {
  for (const string& cuda_root : tensorflow::CandidateCudaRoots(
           hlo_module_config.debug_options().xla_gpu_cuda_data_dir())) {
    string libdevice_dir =
        tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice");
    VLOG(2) << "Looking for libdevice at " << libdevice_dir;
    if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) {
      VLOG(2) << "Found libdevice dir " << libdevice_dir;
      return libdevice_dir;
    }
  }
  return InternalError(
      "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice");
}

struct MaterializeBroadcastsPass
    : public mlir::PassWrapper<MaterializeBroadcastsPass, mlir::FunctionPass> {
  void runOnFunction() override {
    mlir::ConversionTarget conversionTarget(getContext());
    mlir::OwningRewritePatternList conversionPatterns;

    // Consider the xla_hlo dialect legal for tests.
    conversionTarget.addLegalDialect<mlir::xla_hlo::XlaHloDialect>();
    // The conversion uses helpers from the Standard dialect.
    conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();

    mlir::xla_hlo::SetupMaterializeBroadcastsLegality(&getContext(),
                                                      &conversionTarget);
    mlir::xla_hlo::PopulateMaterializeBroadcastsPatterns(&getContext(),
                                                         &conversionPatterns);

    if (failed(applyPartialConversion(getFunction(), conversionTarget,
                                      conversionPatterns))) {
      return signalPassFailure();
    }
  }
};

struct UnfuseBatchNormPass
    : public mlir::PassWrapper<UnfuseBatchNormPass, mlir::FunctionPass> {
  void runOnFunction() override {
    mlir::OwningRewritePatternList patterns;
    mlir::xla_hlo::PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
    mlir::applyPatternsAndFoldGreedily(getOperation(), patterns);
  }
};

Status LowerTfOpToLhloWithDynamicShapes(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  auto enable_if_vlog_is_on = [](mlir::Pass* pass, mlir::Operation* op) {
    return VLOG_IS_ON(1);
  };
  pm.enableIRPrinting(/*shouldPrintBeforePass=*/{},
                      /*shouldPrintAfterPass=*/enable_if_vlog_is_on,
                      /*printModuleScope=*/false,
                      /*printAfterOnlyOnChange=*/false, llvm::dbgs());
  pm.addNestedPass<mlir::FuncOp>(mlir::xla_hlo::createLegalizeTFPass(false));
  pm.addNestedPass<mlir::FuncOp>(
      absl::make_unique<MaterializeBroadcastsPass>());
  pm.addNestedPass<mlir::FuncOp>(absl::make_unique<UnfuseBatchNormPass>());
  pm.addPass(mlir::xla_hlo::createLegalizeToLhloPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::xla_lhlo::createLhloCopyRemovalPass());

  if (failed(pm.run(module))) {
    return InternalError("Lowering TF to LHLO failed.");
  }
  return Status::OK();
}

struct PropagateStaticKnowledge
    : public mlir::PassWrapper<PropagateStaticKnowledge,
                               mlir::OperationPass<mlir::LLVM::LLVMFuncOp>> {
  explicit PropagateStaticKnowledge(mlir::FunctionType type,
                                    llvm::ArrayRef<unsigned> same_shape_)
      : func_type(type), same_shape(same_shape_) {}

  void runOnOperation() override {
    // We know due to tensorflow ABI that the offset is always 0 and that the
    // innermost stride is always 1. To make this visible to the compiler,
    // we insert constants into the code and replace usages accordingly.
    // We do not change the signature so that we keep a somewhat stable ABI
    // that is easy to understand by tools.
    mlir::LLVM::LLVMFuncOp func = getOperation();
    mlir::OpBuilder b(func.getBody());
    auto index_type = func.getArgument(3).getType();
    mlir::Value one = b.create<mlir::LLVM::ConstantOp>(
        func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 1));
    mlir::Value zero = b.create<mlir::LLVM::ConstantOp>(
        func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 0));
    unsigned arg_pos = 0;
    std::vector<unsigned> positions;
    for (mlir::Type arg_type : func_type.getInputs()) {
      positions.push_back(arg_pos);
      func.getArgument(arg_pos + 2).replaceAllUsesWith(zero);
      arg_pos += 3 + arg_type.cast<mlir::ShapedType>().getRank() * 2;
      func.getArgument(arg_pos - 1).replaceAllUsesWith(one);
    }

    // If we have knowledge that some arguments have the same shape, we
    // can use that here. Simply replace usages of the shape parameters within
    // the function body to a single shape parameter.
    if (!same_shape.empty()) {
      int first = same_shape.front();
      int first_offset = positions.at(first);
      mlir::ShapedType first_type =
          func_type.getInput(first).cast<mlir::ShapedType>();
      unsigned rank = first_type.getRank();
      for (int same : same_shape.drop_front(1)) {
        unsigned same_offset = positions.at(same);
        auto same_type = func_type.getInput(same).cast<mlir::ShapedType>();
        if (same_type.getRank() != rank) {
          func.emitOpError() << "same shape constraints on arguments with "
                                "non-matching shapes: #"
                             << first << " and #" << same;
          signalPassFailure();
        }

        for (int i = 0; i < 2 * rank; ++i) {
          // Replace uses for second arg data with first arg.
          auto same_arg = func.getArgument(same_offset + 3 + i);
          auto first_arg = func.getArgument(first_offset + 3 + i);
          same_arg.replaceAllUsesWith(first_arg);
        }
      }
    }
  }

  mlir::FunctionType func_type;
  llvm::ArrayRef<unsigned> same_shape;
};

Status PropagateStaticShapeKnowledgeToKernel(
    mlir::ModuleOp module, llvm::ArrayRef<unsigned> same_shape) {
  // Grab the original signature from the single function.
  auto func = *module.getBody()->op_begin<mlir::FuncOp>();

  mlir::PassManager pm(module.getContext());
  auto enable_if_vlog_is_on = [](mlir::Pass*, mlir::Operation*) {
    return VLOG_IS_ON(1);
  };
  pm.enableIRPrinting(/*shouldPrintBeforePass=*/{},
                      /*shouldPrintAfterPass=*/enable_if_vlog_is_on,
                      /*printModuleScope=*/false,
                      /*printAfterOnlyOnChange=*/false, llvm::dbgs());
  auto& kernel_pm = pm.nest<::mlir::gpu::GPUModuleOp>();
  kernel_pm.addNestedPass<mlir::LLVM::LLVMFuncOp>(
      absl::make_unique<PropagateStaticKnowledge>(func.getType(), same_shape));

  if (failed(pm.run(module))) {
    return InternalError("Static knowledge propagation failed.");
  }
  return Status::OK();
}
}  // namespace

StatusOr<std::vector<uint8>> tensorflow::kernel_gen::GenerateCubinForTfCode(
    llvm::StringRef tf_code, std::pair<int, int> compute_capability,
    llvm::ArrayRef<unsigned> tile_sizes, llvm::ArrayRef<unsigned> same_shape,
    llvm::ArrayRef<unsigned> unroll_factors) {
  mlir::MLIRContext context;
  context.allowUnregisteredDialects();  // TODO(b/152572127)
  mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);

  TF_RETURN_IF_ERROR(LowerTfOpToLhloWithDynamicShapes(module.get()));
  TF_RETURN_IF_ERROR(
      xla::mlir_gpu::LowerLHLOToGPU(module.get(), tile_sizes, unroll_factors,
                                    /*collapseParallelLoops=*/false));
  TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
  TF_RETURN_IF_ERROR(
      PropagateStaticShapeKnowledgeToKernel(module.get(), same_shape));

  mlir::OwningModuleRef kernel_module =
      xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie();
  auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module);
  if (!llvmModule) {
    return InternalError("Could not translate MLIR module to NVVM");
  }

  llvmModule->setModuleIdentifier("acme");
  llvmModule->setDataLayout(xla::gpu::nvptx::kDataLayout);

  xla::HloModuleConfig config;
  config.set_debug_options(xla::GetDebugOptionsFromFlags());

  TF_ASSIGN_OR_RETURN(std::string libdevice_dir, GetLibdeviceDir(config));
  TF_ASSIGN_OR_RETURN(std::string ptx, xla::gpu::nvptx::CompileToPtx(
                                           llvmModule.get(), compute_capability,
                                           config, libdevice_dir));
  VLOG(1) << ptx;

#if GOOGLE_CUDA
  return tensorflow::se::CompileGpuAsm(
      std::get<0>(compute_capability), std::get<1>(compute_capability),
      ptx.c_str(), xla::gpu::PtxOptsFromConfig(config));
#else
  return InternalError(
      "GOOGLE_CUDA not defined. Did you specify --config=cuda ?");
#endif
}
tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h (new file, 41 lines)

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

//===- cubin_creator.h ------------------------------------------*- C++ -*-===//
//
// This file declares the function to compile a TF kernel function to a cubin.
//
//===----------------------------------------------------------------------===//
#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_
#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_

#include <utility>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace tensorflow {
namespace kernel_gen {
xla::StatusOr<std::vector<uint8>> GenerateCubinForTfCode(
    llvm::StringRef tf_code, std::pair<int, int> compute_capability = {7, 5},
    llvm::ArrayRef<unsigned> tile_sizes = {16, 64},
    llvm::ArrayRef<unsigned> same_shape = {},
    llvm::ArrayRef<unsigned> unroll_factors = {});
}  // namespace kernel_gen
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_
118
tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc
Normal file
118
tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc
Normal file
@ -0,0 +1,118 @@
|
||||
// Copyright 2020 The TensorFlow Runtime Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//===- tf_to_cubin.cc -------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file implements the entry point to compile a tf op to a cubin file.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace {
|
||||
bool ParseStringList(std::string string_list, std::vector<uint32>* result) {
|
||||
result->clear();
|
||||
uint32 item;
|
||||
auto items = absl::StrSplit(string_list, ',');
|
||||
for (const auto& item_str : items) {
|
||||
if (!absl::SimpleAtoi(item_str, &item)) {
|
||||
LOG(ERROR) << "Expected token " << item_str << " to be an integer";
|
||||
return false;
|
||||
}
|
||||
result->push_back(item);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
std::string output_file = "foo.bin";
|
||||
int32 architecture = 50;
|
||||
std::vector<uint32> tile_sizes;
|
||||
std::vector<uint32> unroll_factors;
|
||||
std::vector<uint32> same_shape;
|
||||
|
||||
auto parse_tile_sizes = [&tile_sizes](std::string tile_sizes_str) {
|
||||
if (!ParseStringList(tile_sizes_str, &tile_sizes)) {
|
||||
return false;
|
||||
}
|
||||
// Initialize with the default.
|
||||
if (tile_sizes.empty()) {
|
||||
tile_sizes.push_back(16);
|
||||
tile_sizes.push_back(64);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
auto parse_unroll_factors =
|
||||
[&unroll_factors](std::string unroll_factors_str) {
|
||||
return ParseStringList(unroll_factors_str, &unroll_factors);
|
||||
};
|
||||
|
||||
auto parse_same_shape = [&same_shape](std::string same_shape_str) {
|
||||
return ParseStringList(same_shape_str, &same_shape);
|
||||
};
|
||||
|
||||
std::vector<tensorflow::Flag> flag_list = {
|
||||
tensorflow::Flag("output", &output_file, "output file"),
|
||||
tensorflow::Flag("arch", &architecture,
|
||||
"target architecture (e.g. 50 for sm_50)"),
|
||||
tensorflow::Flag("tile_sizes", parse_tile_sizes, "16,64",
|
||||
"tile sizes to use"),
|
||||
tensorflow::Flag("unroll_factors", parse_unroll_factors, "",
|
||||
"factors to unroll by, separated by commas"),
|
||||
tensorflow::Flag("same_shape", parse_same_shape, "",
|
||||
"arguments with same shape, separated by commas"),
|
||||
};
|
||||
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
||||
tensorflow::port::InitMain("usage", &argc, &argv);
|
||||
if (!parse_ok) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::pair<int32, int32> compute_capability(architecture / 10,
|
||||
architecture % 10);
|
||||
|
||||
auto cubin = tensorflow::kernel_gen::GenerateCubinForTfCode(
|
||||
argv[1], compute_capability, tile_sizes, same_shape, unroll_factors);
|
||||
|
||||
if (!cubin.ok()) {
|
||||
LOG(ERROR) << cubin.status();
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::vector<uint8> cubin_data = cubin.ConsumeValueOrDie();
|
||||
|
||||
auto status = tensorflow::WriteStringToFile(
|
||||
tensorflow::Env::Default(), output_file,
|
||||
absl::string_view{reinterpret_cast<char*>(cubin_data.data()),
|
||||
cubin_data.size()});
|
||||
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << status;
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
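Reviewer note: the TF code to compile arrives as the first positional argument left over after flag parsing (argv[1]), and --arch is split into a (major, minor) compute capability by integer division and modulo. A hypothetical invocation (the binary name and the inline MLIR snippet are illustrative, not taken from this commit) could look like:

tf_to_cubin --arch=75 --tile_sizes=16,64 --output=tanh.cubin 'func @tanh(...) { ... }'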
@@ -23,7 +23,6 @@ package_group(
"//tensorflow/compiler/xla/...",
"//third_party/iree/...",
"//third_party/mlir_edge/...",
"//third_party/tf_runtime/tools/tf_kernel_gen/...",
],
)

@@ -95,6 +95,7 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> {
// XLA unary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions

class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
Type TensorType>: HLO_Op<mnemonic,
!listconcat(traits, [InferShapedTypeOpInterface])> {
@@ -161,6 +162,16 @@ def HLO_Expm1Op: HLO_UnaryElementwiseOp<"exponential_minus_one",
def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor",
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp;

def HLO_ImagOp: HLO_Op<
"imag", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ImagOp {
let builders = [OpBuilder<
"OpBuilder &, OperationState &tblgen_state, Value val">];

let arguments = (ins HLO_ComplexTensor);
let results = (outs HLO_FpTensor);
let hasFolder = 1;
}

def HLO_IsFiniteOp: HLO_UnaryElementwiseOp<"is_finite",
[NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>,
BASE_HLO_IsFiniteOp {
@@ -188,6 +199,16 @@ def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt",
[NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>,
BASE_HLO_PopulationCountOp;

def HLO_RealOp: HLO_Op<
"real", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_RealOp {
let builders = [OpBuilder<
"OpBuilder &, OperationState &tblgen_state, Value val">];

let arguments = (ins HLO_ComplexTensor);
let results = (outs HLO_FpTensor);
let hasFolder = 1;
}

def HLO_RoundOp: HLO_UnaryElementwiseOp<"round_nearest_afz",
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp;

@@ -212,47 +233,11 @@ def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh",
[ResultsAreFloatLike, NoSideEffect, SameOperandsAndResultType],
HLO_FpOrComplexTensor>, BASE_HLO_TanhOp;

//===----------------------------------------------------------------------===//
// XLA complex unary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions

def HLO_ComplexOp: HLO_Op<"complex",
[NoSideEffect, SameOperandsElementType, SameOperandsAndResultShape]>,
BASE_HLO_ComplexOp {
let builders = [OpBuilder<
"OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">];

let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs);
let results = (outs HLO_ComplexTensor);
let hasFolder = 1;
}

def HLO_ImagOp: HLO_Op<
"imag", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ImagOp {
let builders = [OpBuilder<
"OpBuilder &, OperationState &tblgen_state, Value val">];

let arguments = (ins HLO_ComplexTensor);
let results = (outs HLO_FpTensor);
let hasFolder = 1;
}

def HLO_RealOp: HLO_Op<
"real", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_RealOp {
let builders = [OpBuilder<
"OpBuilder &, OperationState &tblgen_state, Value val">];

let arguments = (ins HLO_ComplexTensor);
let results = (outs HLO_FpTensor);
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions.
//===----------------------------------------------------------------------===//

// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations

class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
HLO_Op<mnemonic, !listconcat(traits, [InferShapedTypeOpInterface])> {
let arguments = (ins
@@ -293,6 +278,17 @@ def HLO_AddOp : HLO_BinaryElementwiseOp<"add",
def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2",
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_Atan2Op;

def HLO_ComplexOp: HLO_Op<"complex",
[NoSideEffect, SameOperandsElementType, SameOperandsAndResultShape]>,
BASE_HLO_ComplexOp {
let builders = [OpBuilder<
"OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">];

let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs);
let results = (outs HLO_ComplexTensor);
let hasFolder = 1;
}

def HLO_DivOp : HLO_BinaryElementwiseOp<"divide",
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_DivOp {
}

@@ -150,15 +150,6 @@ class BASE_HLO_ClzOp {
}];
}

class BASE_HLO_ComplexOp {
string summary = "Complex operator";

string description = [{
Performs element-wise conversion of a pair of real and imaginary values to
a complex value.
}];
}

class BASE_HLO_ConvertOp {
string summary = "Convert operator";

@@ -400,6 +391,15 @@ class BASE_HLO_AddOp {
}];
}

class BASE_HLO_ComplexOp {
string summary = "Complex operator";

string description = [{
Performs element-wise conversion of a pair of real and imaginary values to
a complex value.
}];
}

class BASE_HLO_DivOp {
string summary = "Division operator";

@@ -92,10 +92,20 @@ def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cosine">, BASE_HLO_CosOp;

def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exponential">, BASE_HLO_ExpOp;

def LHLO_ImagOp: LHLO_Op<"imag", [SameOperandsShape]>, BASE_HLO_ImagOp {
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}

def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log">, BASE_HLO_LogOp;

def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp;

def LHLO_RealOp: LHLO_Op<"real", [SameOperandsShape]>, BASE_HLO_RealOp {
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}

def LHLO_RsqrtOp: LHLO_UnaryElementwiseOp<"rsqrt">, BASE_HLO_RsqrtOp;

def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt">, BASE_HLO_SqrtOp;
@@ -106,27 +116,6 @@ def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine">, BASE_HLO_SinOp;

def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh">, BASE_HLO_TanhOp;

//===----------------------------------------------------------------------===//
// XLA complex unary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions

def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]>, BASE_HLO_ComplexOp {
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}

def LHLO_ImagOp: LHLO_Op<"imag", [SameOperandsShape]>, BASE_HLO_ImagOp {
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}

def LHLO_RealOp: LHLO_Op<"real", [SameOperandsShape]>, BASE_HLO_RealOp {
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}

//===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions.
//===----------------------------------------------------------------------===//
@@ -144,6 +133,12 @@ class LHLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :

def LHLO_AddOp : LHLO_BinaryElementwiseOp<"add", []>, BASE_HLO_AddOp;

def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]>, BASE_HLO_ComplexOp {
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}

def LHLO_DivOp : LHLO_BinaryElementwiseOp<"divide", []>, BASE_HLO_DivOp;

def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"maximum", []>, BASE_HLO_MaxOp;

@@ -2205,13 +2205,6 @@ func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @round
func @round(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: "xla_hlo.round_nearest_afz"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "tf.Round"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @rsqrt
func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: "xla_hlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>

@@ -523,6 +523,48 @@ func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-NEXT: %[[RESULT:.*]] = tanh %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32

// -----

// CHECK-LABEL: func @complex
func @complex(%real: memref<2x2xf32>,
%imag: memref<2x2xf32>,
%cplx: memref<2x2xcomplex<f32>>) {
"xla_lhlo.complex"(%real, %imag, %cplx)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xcomplex<f32>>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[RE:.*]]: f32, %[[IM:.*]]: f32, %[[CP:.*]]: complex<f32>):
// CHECK-NEXT: %[[RESULT:.*]] = create_complex %[[RE]], %[[IM]] : complex<f32>
// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f32>

// -----

// CHECK-LABEL: func @real
func @real(%cplx: memref<2x2xcomplex<f32>>,
%real: memref<2x2xf32>) {
"xla_lhlo.real"(%cplx, %real)
: (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex<f32>, %[[REAL_OUT:.*]]: f32):
// CHECK-NEXT: %[[REAL:.*]] = re %[[CPLX_IN:.*]] : complex<f32>
// CHECK-NEXT: linalg.yield %[[REAL]] : f32

// -----

// CHECK-LABEL: func @imag
func @imag(%cplx: memref<2x2xcomplex<f32>>,
%imag: memref<2x2xf32>) {
"xla_lhlo.imag"(%cplx, %imag)
: (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex<f32>, %[[IMAG_OUT:.*]]: f32):
// CHECK-NEXT: %[[IMAG:.*]] = im %[[CPLX_IN:.*]] : complex<f32>
// CHECK-NEXT: linalg.yield %[[IMAG]] : f32

// -----

@@ -543,7 +543,6 @@ foreach Mapping = [
[TF_LogicalNotOp, HLO_NotOp],
[TF_NegOp, HLO_NegOp],
[TF_RealOp, HLO_RealOp],
[TF_RoundOp, HLO_RoundOp],
[TF_RsqrtOp, HLO_RsqrtOp],
[TF_SinOp, HLO_SinOp],
[TF_SqrtOp, HLO_SqrtOp],

@@ -83,31 +83,50 @@ static bool IsOpWhitelisted(Operation* op) {
// clang-format off
static llvm::SmallDenseSet<mlir::TypeID, 512> ops = {
TypeID::get<TF::AbsOp>(),
TypeID::get<TF::AcoshOp>(),
TypeID::get<TF::AcosOp>(),
TypeID::get<TF::AddNOp>(),
TypeID::get<TF::AddV2Op>(),
TypeID::get<TF::ApproximateEqualOp>(),
TypeID::get<TF::AsinhOp>(),
TypeID::get<TF::AsinOp>(),
TypeID::get<TF::Atan2Op>(),
TypeID::get<TF::AtanhOp>(),
TypeID::get<TF::AtanOp>(),
TypeID::get<TF::BatchMatMulV2Op>(),
TypeID::get<TF::BiasAddOp>(),
TypeID::get<TF::BiasAddGradOp>(),
TypeID::get<TF::BiasAddOp>(),
TypeID::get<TF::BitwiseAndOp>(),
TypeID::get<TF::BitwiseOrOp>(),
TypeID::get<TF::BitwiseXorOp>(),
TypeID::get<TF::CastOp>(),
TypeID::get<TF::ClipByValueOp>(),
TypeID::get<TF::ComplexAbsOp>(),
TypeID::get<TF::CoshOp>(),
TypeID::get<TF::CrossOp>(),
TypeID::get<TF::DataFormatDimMapOp>(),
TypeID::get<TF::DataFormatVecPermuteOp>(),
TypeID::get<TF::DigammaOp>(),
TypeID::get<TF::DivNoNanOp>(),
TypeID::get<TF::EluGradOp>(),
TypeID::get<TF::EluOp>(),
TypeID::get<TF::EqualOp>(),
TypeID::get<TF::ErfcOp>(),
TypeID::get<TF::ErfOp>(),
TypeID::get<TF::Expm1Op>(),
TypeID::get<TF::FloorDivOp>(),
TypeID::get<TF::FloorModOp>(),
TypeID::get<TF::GreaterOp>(),
TypeID::get<TF::GreaterEqualOp>(),
TypeID::get<TF::GatherNdOp>(),
TypeID::get<TF::InvOp>(),
TypeID::get<TF::GreaterEqualOp>(),
TypeID::get<TF::GreaterOp>(),
TypeID::get<TF::InvertOp>(),
TypeID::get<TF::InvOp>(),
TypeID::get<TF::LeakyReluGradOp>(),
TypeID::get<TF::LeakyReluOp>(),
TypeID::get<TF::LeftShiftOp>(),
TypeID::get<TF::LessOp>(),
TypeID::get<TF::LessEqualOp>(),
TypeID::get<TF::LessOp>(),
TypeID::get<TF::LgammaOp>(),
TypeID::get<TF::LogicalAndOp>(),
TypeID::get<TF::LogicalNotOp>(),
TypeID::get<TF::LogicalOrOp>(),
@@ -119,18 +138,34 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::PlaceholderWithDefaultOp>(),
TypeID::get<TF::PowOp>(),
TypeID::get<TF::RealDivOp>(),
TypeID::get<TF::ReciprocalGradOp>(),
TypeID::get<TF::Relu6GradOp>(),
TypeID::get<TF::RightShiftOp>(),
TypeID::get<TF::SinOp>(),
TypeID::get<TF::RintOp>(),
TypeID::get<TF::RoundOp>(),
TypeID::get<TF::SelectV2Op>(),
TypeID::get<TF::SubOp>(),
TypeID::get<TF::SeluGradOp>(),
TypeID::get<TF::SeluOp>(),
TypeID::get<TF::SigmoidGradOp>(),
TypeID::get<TF::SinhOp>(),
TypeID::get<TF::SinOp>(),
TypeID::get<TF::SoftplusGradOp>(),
TypeID::get<TF::SoftsignGradOp>(),
TypeID::get<TF::SoftsignOp>(),
TypeID::get<TF::SqrtGradOp>(),
TypeID::get<TF::SquareOp>(),
TypeID::get<TF::SubOp>(),
TypeID::get<TF::TanOp>(),
TypeID::get<TF::TransposeOp>(),
TypeID::get<TF::TruncateDivOp>(),
TypeID::get<TF::TruncateModOp>(),
TypeID::get<TF::TruncatedNormalOp>(),
TypeID::get<TF::TruncateModOp>(),
TypeID::get<TF::UnpackOp>(),
TypeID::get<TF::XdivyOp>(),
TypeID::get<TF::XlaDotOp>(),
TypeID::get<TF::XlaPadOp>()
TypeID::get<TF::XlaPadOp>(),
TypeID::get<TF::Xlog1pyOp>(),
TypeID::get<TF::XlogyOp>()
};
// clang-format on

@@ -227,6 +227,28 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::CeilOp>(
loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ComplexOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<CreateComplexOp>{}(loc, result_types, args,
b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::RealOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<ReOp>{}(loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ImagOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<ImOp>{}(loc, result_types, args, b);
}
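Reviewer note: these specializations are what lets the pointwise lowering pick the right std/complex scalar instruction per LHLO op. A minimal sketch of a hypothetical call site (the names loc, fp32Ty, realArg, imagArg, and builder are illustrative, not from this file):

// Hypothetical use inside a linalg.generic region emitter (sketch).
Value cplx = MapLhloOpToStdScalarOp<xla_lhlo::ComplexOp>(
    loc, /*result_types=*/{ComplexType::get(fp32Ty)},
    /*args=*/{realArg, imagArg}, &builder);  // emits create_complex
Value re = MapLhloOpToStdScalarOp<xla_lhlo::RealOp>(
    loc, /*result_types=*/{fp32Ty}, /*args=*/{cplx}, &builder);  // emits re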

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ConvertOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,

@@ -84,7 +84,8 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
emitError(loc, "lhlo to linalg conversion expects ranked args");
return failure();
}
if (!argType.getElementType().isSignlessIntOrFloat()) {
auto elemTy = argType.getElementType();
if (!elemTy.isSignlessIntOrFloat() && !elemTy.template isa<ComplexType>()) {
return failure();
}

@@ -618,17 +619,20 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<xla_lhlo::AndOp>,
PointwiseToLinalgConverter<xla_lhlo::CeilOp>,
PointwiseToLinalgConverter<xla_lhlo::CompareOp>,
PointwiseToLinalgConverter<xla_lhlo::ComplexOp>,
PointwiseToLinalgConverter<xla_lhlo::ConvertOp>,
// TODO(ataei): Remove this pattern, CopyOp is folded away.
PointwiseToLinalgConverter<xla_lhlo::CopyOp>,
PointwiseToLinalgConverter<xla_lhlo::CosOp>,
PointwiseToLinalgConverter<xla_lhlo::DivOp>,
PointwiseToLinalgConverter<xla_lhlo::ExpOp>,
PointwiseToLinalgConverter<xla_lhlo::ImagOp>,
PointwiseToLinalgConverter<xla_lhlo::LogOp>,
PointwiseToLinalgConverter<xla_lhlo::MaxOp>,
PointwiseToLinalgConverter<xla_lhlo::MinOp>,
PointwiseToLinalgConverter<xla_lhlo::MulOp>,
PointwiseToLinalgConverter<xla_lhlo::NegOp>,
PointwiseToLinalgConverter<xla_lhlo::RealOp>,
PointwiseToLinalgConverter<xla_lhlo::RemOp>,
PointwiseToLinalgConverter<xla_lhlo::RsqrtOp>,
PointwiseToLinalgConverter<xla_lhlo::SelectOp>,

@@ -716,16 +720,19 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<xla_hlo::AndOp, false>,
PointwiseToLinalgConverter<xla_hlo::CeilOp, false>,
PointwiseToLinalgConverter<xla_hlo::CompareOp, false>,
PointwiseToLinalgConverter<xla_hlo::ComplexOp, false>,
PointwiseToLinalgConverter<xla_hlo::ConvertOp, false>,
PointwiseToLinalgConverter<xla_hlo::CopyOp, false>,
PointwiseToLinalgConverter<xla_hlo::CosOp, false>,
PointwiseToLinalgConverter<xla_hlo::DivOp, false>,
PointwiseToLinalgConverter<xla_hlo::ExpOp, false>,
PointwiseToLinalgConverter<xla_hlo::ImagOp, false>,
PointwiseToLinalgConverter<xla_hlo::LogOp, false>,
PointwiseToLinalgConverter<xla_hlo::MaxOp, false>,
PointwiseToLinalgConverter<xla_hlo::MinOp, false>,
PointwiseToLinalgConverter<xla_hlo::MulOp, false>,
PointwiseToLinalgConverter<xla_hlo::NegOp, false>,
PointwiseToLinalgConverter<xla_hlo::RealOp, false>,
PointwiseToLinalgConverter<xla_hlo::RemOp, false>,
PointwiseToLinalgConverter<xla_hlo::RsqrtOp, false>,
PointwiseToLinalgConverter<xla_hlo::SelectOp, false>,

@@ -470,6 +470,7 @@ tf_xla_py_test(
name = "concat_ops_test",
size = "medium",
srcs = ["concat_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"many_xla_args",
@@ -1342,6 +1343,7 @@ tf_xla_py_test(
name = "ternary_ops_test",
size = "medium",
srcs = ["ternary_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip

@@ -73,8 +73,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
self.assertAllCloseAccordingToType(
result[i], expected[i], rtol=rtol, atol=atol)

@test_util.disable_mlir_bridge(
"F16 type is not supported in CreateDenseElementsAttrFromLiteral")
def testFloatOps(self):
for dtype in self.float_types:
if dtype == dtypes.bfloat16.as_numpy_dtype:
@@ -1513,7 +1511,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
np.array([1, 0], dtype=np.int32),
expected=np.array([[1 + 1j, 3 + 3j], [2 - 2j, 4 - 4j]], dtype=dtype))

@test_util.disable_mlir_bridge("Enable tf.Cross Compilation")
def testCross(self):
for dtype in self.float_types:
self._testBinary(

@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gradients_impl
@@ -293,6 +294,7 @@ class ConcatTest(xla_test.XLATestCase):

# The purpose of this is to ensure that XLA on GPU will not run out of memory
# with too many arguments.
@test_util.disable_mlir_bridge("TODO(b/153895138): Debug.")
def testConcatLargeNumberOfTensors(self):
if "CPU" in self.device:
self.skipTest("This test can time out on CPU, so we will just allow "

@@ -24,6 +24,7 @@ import scipy.special as sps

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
@@ -47,6 +48,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
{'start': 1, 'end': 2, 'num': 1},
{'start': 1, 'end': 4, 'num': 3},
{'start': 0, 'end': 41, 'num': 42})
@test_util.disable_mlir_bridge('Requires dynamic shape handling')
def testLinspace(self, start, end, num):
expected = np.linspace(start, end, num, dtype=np.float32)
result = self._testTernary(
@@ -74,6 +76,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
np.int32(2),
expected=np.array([1, 3, 5], dtype=np.int32))

@test_util.disable_mlir_bridge('TODO(b/155949336)')
def testSelect(self):
for dtype in self.numeric_types:
self._testTernary(
@@ -179,6 +182,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
np.array([8, 9], dtype=dtype),
expected=np.array([[7, 9], [8, 7], [8, 9]], dtype=dtype))

@test_util.disable_mlir_bridge('TODO(b/155097273)')
def testSlice(self):
for dtype in self.numeric_types:
self._testTernary(
@@ -211,6 +215,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
upper,
expected=np.minimum(np.maximum(x, lower), upper))

@test_util.disable_mlir_bridge('Enable tf.Betainc Compilation')
def testBetaincSanity(self):
# This operation is only supported for float32 and float64.
for dtype in self.numeric_types & {np.float32, np.float64}:
@@ -248,6 +253,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
'atol': 2e-4
},
)
@test_util.disable_mlir_bridge('Enable tf.Betainc Compilation')
def testBetainc(self, sigma, rtol, atol):
# This operation is only supported for float32 and float64.
for dtype in self.numeric_types & {np.float32, np.float64}:

@@ -186,8 +186,6 @@ class UnaryOpsTest(xla_test.XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.cos, x, expected=np.cos(x), rtol=tol, atol=1e-5)

@test_util.disable_mlir_bridge(
"TODO(b/153812660): Handle tf.Softmax compilation")
def testFloatOps(self):
for dtype in self.float_types:
x = np.arange(-0.90, 0.90, 0.25)
@@ -514,6 +512,11 @@ class UnaryOpsTest(xla_test.XLATestCase):
],
dtype=dtype))

@test_util.disable_mlir_bridge(
"TODO(b/153812660): Handle tf.QuantizeAndDequantize compilation")
def testQuantizeAndDequantize(self):
for dtype in self.float_types:

def quantize_and_dequantize_v2(x):
return array_ops.quantize_and_dequantize_v2(
x, -127, 127, signed_input=True, num_bits=8)

@@ -72,7 +72,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
np.array([7, 11], dtype=dtype)),
expected=np.array([[8, 13], [10, 15]], dtype=dtype))

@test_util.disable_mlir_bridge('Not supported yet')
def testBroadcast(self):
for dtype in self.numeric_types:
v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2])

@@ -17,7 +17,6 @@ package_group(
"//tensorflow/compiler/...",
"//tensorflow/python/tpu/...",
"//third_party/py/jax/...",
"//third_party/tf_runtime/tools/tf_kernel_gen/...",
],
)

@@ -36,6 +36,7 @@ bool CanBeLoopFused(const HloInstruction& hlo) {
hlo.opcode() == HloOpcode::kGather ||
hlo.opcode() == HloOpcode::kIota || hlo.opcode() == HloOpcode::kPad ||
hlo.opcode() == HloOpcode::kReduce ||
hlo.opcode() == HloOpcode::kReduceWindow ||
hlo.opcode() == HloOpcode::kReshape ||
hlo.opcode() == HloOpcode::kReverse ||
hlo.opcode() == HloOpcode::kSlice ||

@@ -945,6 +945,34 @@ ENTRY main {
EXPECT_TRUE(fused_something);
EXPECT_THAT(module->entry_computation()->root_instruction(), op::Fusion());
}

TEST_F(InstructionFusionTest, FuseReduceWindow) {
absl::string_view module_string = R"(
HloModule module

add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}

ENTRY main {
a = f32[50,60]{1,0} parameter(0)
b = f32[50,60]{1,0} parameter(1)
c = f32[50,60]{1,0} multiply(a, b)
init = f32[] constant(0)
ROOT r = f32[50,60] reduce-window(c, init), window={size=2x3 pad=0_1x1_1},
to_apply=add
}
)";

TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(module_string));
TF_ASSERT_OK_AND_ASSIGN(bool fused_something,
CpuInstructionFusion().Run(module.get()));
EXPECT_TRUE(fused_something);
EXPECT_THAT(module->entry_computation()->root_instruction(), op::Fusion());
}
} // namespace
} // namespace cpu
} // namespace xla

@@ -125,7 +125,8 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
return ir_emitter_->EmitElementalReduceWindow(
Cast<HloReduceWindowInstruction>(hlo),
operand_to_generator.at(hlo->operand(0)), index);
operand_to_generator.at(hlo->operand(0)),
operand_to_generator.at(hlo->operand(1)), index);
};
case HloOpcode::kConvolution:
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {

@@ -705,6 +705,7 @@ llvm::Value* IrEmitter::EmitElementalMap(
StatusOr<llvm::Value*> IrEmitter::EmitElementalReduceWindow(
const HloReduceWindowInstruction* reduce_window,
const llvm_ir::ElementGenerator& input_generator,
const llvm_ir::ElementGenerator& initial_value_generator,
const llvm_ir::IrArray::Index& index) {
const HloInstruction* operand = reduce_window->operand(0);
const Window& window = reduce_window->window();
@@ -716,8 +717,10 @@ StatusOr<llvm::Value*> IrEmitter::EmitElementalReduceWindow(
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
"reduce_window_accumulator_address", &b_,
MinimumAlignmentForPrimitiveType(operand_element_type));
Store(Load(GetEmittedValueFor(reduce_window->operand(1))),
accumulator_address);
TF_ASSIGN_OR_RETURN(
llvm::Value* const initial_value,
initial_value_generator(llvm_ir::IrArray::Index(index.GetType())));
Store(initial_value, accumulator_address);

llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_);
std::vector<int64> window_size;

@@ -122,6 +122,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
StatusOr<llvm::Value*> EmitElementalReduceWindow(
const HloReduceWindowInstruction* reduce_window,
const llvm_ir::ElementGenerator& input_generator,
const llvm_ir::ElementGenerator& initial_value_generator,
const llvm_ir::IrArray::Index& index);
// Emit code to emit the element at `index` for a convolution instruction.
StatusOr<llvm::Value*> EmitElementalConvolution(

@@ -53,12 +53,18 @@ ENTRY main {

MatchOptimizedHlo(hlo_text,
R"(
; CHECK-LABEL: %fused_computation (param_0.1: f32[32]) -> f32[] {
; CHECK-NEXT: %param_0.1 = f32[32]{0} parameter(0)
; CHECK-NEXT: %zero.1 = f32[] constant(0)
; CHECK-NEXT: %reduce-window.2 = f32[1]{0} reduce-window(%param_0.1, %zero.1), window={size=32 stride=32}, to_apply=%add
; CHECK-NEXT: ROOT %reshape.1 = f32[] reshape(%reduce-window.2)
; CHECK-NEXT: }

; CHECK-LABEL: ENTRY %main (input: f32[1000]) -> f32[] {
; CHECK-NEXT: %input = f32[1000]{0} parameter(0)
; CHECK-NEXT: %zero = f32[] constant(0)
; CHECK-NEXT: %reduce-window = f32[32]{0} reduce-window(%input, %zero)
; CHECK-NEXT: %reduce-window.1 = f32[1]{0} reduce-window(%reduce-window, %zero), window={size=32 stride=32}, to_apply=%add
; CHECK-NEXT: ROOT %bitcast = f32[] bitcast(%reduce-window.1)
; CHECK-NEXT: ROOT %fusion = f32[] fusion(%reduce-window), kind=kLoop, calls=%fused_computation
)");
}

@@ -300,7 +300,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody(
}

StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) {
VLOG(2) << "HLO module before WhileLoopConstantSinking:";
VLOG(2) << "HLO module before WhileLoopInvariantCodeMotion:";
XLA_VLOG_LINES(2, module->ToString());

bool changed = false;
@@ -332,10 +332,10 @@ StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) {
}

if (changed) {
VLOG(2) << "HLO module after WhileLoopConstantSinking:";
VLOG(2) << "HLO module after WhileLoopInvariantCodeMotion:";
XLA_VLOG_LINES(2, module->ToString());
} else {
VLOG(2) << "HLO module unchanged after WhileLoopConstantSinking";
VLOG(2) << "HLO module unchanged after WhileLoopInvariantCodeMotion";
}

return changed;

@@ -1212,6 +1212,7 @@ cc_library(
":propagator_debug_utils",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/platform:hash",
"//tensorflow/core/profiler/lib:traceme",
],
)

@@ -549,7 +549,8 @@ BENCHMARK(BM_FeedInputFetchOutput);
//
// ...using the functional `WhileOp` (if `lower` is false) or the
// `Switch`/`Merge`-style of control flow (if `lower` is true).
static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {
static void BM_WhileLoopHelper(int iters, int loop_iters, int loop_vars,
bool lower) {
testing::StopTiming();
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

@@ -558,20 +559,44 @@ static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {

// Define the loop body as a function: `x = x + 1`.
const Tensor one_t = test::AsScalar<int32>(1);

std::vector<string> args;
args.reserve(loop_vars);
args.push_back("x: int32");
for (int i = 1; i < loop_vars; ++i) {
args.push_back(strings::StrCat("x", i, ": int32"));
}

std::vector<string> body_rets;
body_rets.reserve(loop_vars);
body_rets.push_back("y: int32");
for (int i = 1; i < loop_vars; ++i) {
body_rets.push_back(strings::StrCat("y", i, ": int32"));
}

std::vector<FunctionDefHelper::Node> body_nodes;
body_nodes.reserve(1 + loop_vars);
body_nodes.push_back(
{{"one"}, "Const", {}, {{"value", one_t}, {"dtype", DT_INT32}}});
body_nodes.push_back({{"y"}, "Add", {"x", "one"}, {{"T", DT_INT32}}});
for (int i = 1; i < loop_vars; ++i) {
body_nodes.push_back({{strings::StrCat("y", i)},
"Identity",
{strings::StrCat("x", i)},
{{"T", DT_INT32}}});
}

*f_lib_proto.add_function() = FunctionDefHelper::Define(
// Name
"XPlusOne",
// Args
{"x: int32"},
args,
// Return values
{"y: int32"},
body_rets,
// Attr def
{},
// Nodes
{
{{"one"}, "Const", {}, {{"value", one_t}, {"dtype", DT_INT32}}},
{{"y"}, "Add", {"x", "one"}, {{"T", DT_INT32}}},
});
body_nodes);

// Define the loop condition as a function: `x < loop_iters`.
const Tensor loop_iters_t = test::AsScalar<int32>(loop_iters);
@@ -579,7 +604,7 @@ static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {
// Name
"LessThanOrEqualToN",
// Args
{"x: int32"},
args,
// Return values
{"z: bool"},
// Attr def
@@ -594,7 +619,12 @@ static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
auto a = ops::Const(root.WithOpName("A"), 0, {});
Node* while_node;
std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
std::vector<NodeBuilder::NodeOut> inputs;
std::vector<DataType> input_types(loop_vars, DT_INT32);
inputs.reserve(loop_vars);
for (int i = 0; i < loop_vars; ++i) {
inputs.push_back(NodeBuilder::NodeOut(a.node()));
}
AttrValue int32_attr;
int32_attr.set_type(DT_INT32);
AttrValue cond_func;
@@ -604,7 +634,7 @@ static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {
TF_ASSERT_OK(
NodeBuilder("while", "While", &root.graph()->flib_def())
.Input(inputs)
.Attr("T", {DT_INT32})
.Attr("T", input_types)
.Attr("cond", cond_func)
.Attr("body", body_func)
.Attr("parallel_iterations", 100)
@@ -635,21 +665,33 @@ static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {
test::Benchmark("cpu", graph.release()).Run(iters);
}

static void BM_LoweredWhileLoop(int iters, int loop_iters) {
BM_WhileLoopHelper(iters, loop_iters, /* lower= */ true);
static void BM_LoweredWhileLoop(int iters, int loop_iters, int loop_vars) {
BM_WhileLoopHelper(iters, loop_iters, loop_vars, /* lower= */ true);
}
BENCHMARK(BM_LoweredWhileLoop)->Arg(0);
BENCHMARK(BM_LoweredWhileLoop)->Arg(1);
BENCHMARK(BM_LoweredWhileLoop)->Arg(10);
BENCHMARK(BM_LoweredWhileLoop)->Arg(100);
BENCHMARK(BM_LoweredWhileLoop)->Arg(1000);
BENCHMARK(BM_LoweredWhileLoop)
->ArgPair(0, 1)
->ArgPair(1, 1)
->ArgPair(10, 1)
->ArgPair(100, 1)
->ArgPair(1000, 1)
->ArgPair(0, 100)
->ArgPair(1, 100)
->ArgPair(10, 100)
->ArgPair(100, 100)
->ArgPair(1000, 100);

static void BM_FunctionalWhileLoop(int iters, int loop_iters) {
BM_WhileLoopHelper(iters, loop_iters, /* lower= */ false);
static void BM_FunctionalWhileLoop(int iters, int loop_iters, int loop_vars) {
BM_WhileLoopHelper(iters, loop_iters, loop_vars, /* lower= */ false);
}
BENCHMARK(BM_FunctionalWhileLoop)->Arg(0);
BENCHMARK(BM_FunctionalWhileLoop)->Arg(1);
BENCHMARK(BM_FunctionalWhileLoop)->Arg(10);
BENCHMARK(BM_FunctionalWhileLoop)->Arg(100);
BENCHMARK(BM_FunctionalWhileLoop)->Arg(1000);
BENCHMARK(BM_FunctionalWhileLoop)
->ArgPair(0, 1)
->ArgPair(1, 1)
->ArgPair(10, 1)
->ArgPair(100, 1)
->ArgPair(1000, 1)
->ArgPair(0, 100)
->ArgPair(1, 100)
->ArgPair(10, 100)
->ArgPair(100, 100)
->ArgPair(1000, 100);
} // namespace tensorflow

@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
@@ -39,9 +40,6 @@ ImmutableExecutorState::~ImmutableExecutorState() {
params_.delete_kernel(item->kernel);
}
}
for (auto fiter : frame_info_) {
delete fiter.second;
}
}

namespace {
@@ -71,11 +69,16 @@ void GetMaxPendingCounts(const Node* n, size_t* max_pending,

ImmutableExecutorState::FrameInfo* ImmutableExecutorState::EnsureFrameInfo(
const string& fname) {
auto slot = &frame_info_[fname];
if (*slot == nullptr) {
*slot = new FrameInfo;
auto iter = frame_info_.find(fname);
if (iter != frame_info_.end()) {
return iter->second.get();
} else {
auto frame_info = absl::make_unique<FrameInfo>(fname);
absl::string_view fname_view = frame_info->name;
auto emplace_result =
frame_info_.emplace(fname_view, std::move(frame_info));
return emplace_result.first->second.get();
}
return *slot;
}
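Reviewer note: the new EnsureFrameInfo keys the map by an absl::string_view that points into the FrameInfo the map itself owns, so the key stays valid for the entry's lifetime and lookups avoid copying the frame name. A self-contained sketch of the same pattern, with standalone illustrative names rather than the executor's:

#include <memory>
#include <string>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"

struct Info {
  explicit Info(std::string n) : name(std::move(n)) {}
  std::string name;  // The map key below views into this buffer.
};

// The key views into the heap-allocated value; it remains stable even if the
// map rehashes, because only the unique_ptr moves, not the Info it owns.
absl::flat_hash_map<absl::string_view, std::unique_ptr<Info>> table;

Info* FindOrCreate(const std::string& name) {
  auto it = table.find(name);
  if (it != table.end()) return it->second.get();
  auto info = std::make_unique<Info>(name);
  absl::string_view key = info->name;  // Must outlive the entry; it does.
  return table.emplace(key, std::move(info)).first->second.get();
}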

Status ImmutableExecutorState::Initialize(const Graph& graph) {
@@ -89,7 +92,7 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) {
EnsureFrameInfo(it)->nodes =
absl::make_unique<std::vector<const NodeItem*>>();
}
root_frame_info_ = frame_info_[""];
root_frame_info_ = frame_info_[""].get();

pending_ids_.resize(gview_.num_nodes());

@@ -157,6 +160,28 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) {
TF_RETURN_IF_ERROR(
GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter));
item->is_constant_enter = is_constant_enter;

string frame_name;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &frame_name));
FrameInfo* frame_info = frame_info_[frame_name].get();

int parallel_iterations;
TF_RETURN_IF_ERROR(
GetNodeAttr(n->attrs(), "parallel_iterations", &parallel_iterations));

if (frame_info->parallel_iterations == -1) {
frame_info->parallel_iterations = parallel_iterations;
} else if (frame_info->parallel_iterations != parallel_iterations) {
LOG(WARNING) << "Loop frame \"" << frame_name
<< "\" had two different values for parallel_iterations: "
<< frame_info->parallel_iterations << " vs. "
<< parallel_iterations << ".";
}

if (enter_frame_info_.size() <= id) {
enter_frame_info_.resize(id + 1);
}
enter_frame_info_[id] = frame_info;
} else {
item->is_constant_enter = false;
}

@@ -20,6 +20,7 @@ limitations under the License.
#include <memory>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/common_runtime/graph_view.h"
#include "tensorflow/core/common_runtime/local_executor_params.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
@@ -41,11 +42,16 @@ class Graph;
class ImmutableExecutorState {
public:
struct FrameInfo {
FrameInfo()
: input_count(0),
explicit FrameInfo(string name)
: name(std::move(name)),
input_count(0),
total_inputs(0),
pending_counts(nullptr),
nodes(nullptr) {}
nodes(nullptr),
parallel_iterations(-1) {}

// The name of the frame.
string name;

// The total number of inputs to a frame.
int input_count;
@@ -63,6 +69,9 @@ class ImmutableExecutorState {

// The nodes in a frame. Used only for debugging.
std::unique_ptr<std::vector<const NodeItem*>> nodes;

// The number of iterations of this frame that can execute concurrently.
int32 parallel_iterations;
};

explicit ImmutableExecutorState(const LocalExecutorParams& p)
@@ -83,17 +92,13 @@ class ImmutableExecutorState {
}
const std::vector<const NodeItem*>& root_nodes() const { return root_nodes_; }

const FrameInfo* get_frame_info(const string& frame_name) const {
auto it_frame_info = frame_info_.find(frame_name);
if (it_frame_info == frame_info_.end()) {
return nullptr;
} else {
return it_frame_info->second;
}
}

const FrameInfo& get_root_frame_info() const { return *root_frame_info_; }

const FrameInfo& get_enter_frame_info(const NodeItem& node_item) const {
DCHECK(node_item.is_enter);
return *enter_frame_info_[node_item.node_id];
}

bool requires_control_flow_support() const { return requires_control_flow_; }

// Copies the pending counts for nodes in this graph to the given array.
@@ -135,9 +140,14 @@ class ImmutableExecutorState {
// Mapping from frame name to static information about the frame.
// TODO(yuanbyu): We could cache it along with the graph so to avoid
// the overhead of constructing it for each executor instance.
gtl::FlatMap<string, FrameInfo*> frame_info_;
absl::flat_hash_map<absl::string_view, std::unique_ptr<FrameInfo>>
frame_info_;
const FrameInfo* root_frame_info_; // Not owned.

// If the graph contains any "Enter" or "RefEnter" nodes, this vector maps
// dense node IDs to the corresponding FrameInfo.
std::vector<FrameInfo*> enter_frame_info_;

// If `requires_control_flow_` is false, this points to an array of initial
// pending counts for the nodes in the graph, indexed by node ID.
std::unique_ptr<std::atomic<int32>[]> atomic_pending_counts_;

@@ -16,9 +16,11 @@ limitations under the License.
#include "tensorflow/core/common_runtime/propagator_state.h"

#include "tensorflow/core/common_runtime/graph_view.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/common_runtime/propagator_debug_utils.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/hash.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {
@@ -33,14 +35,14 @@ PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state,
// We assume root_frame_->frame_name.empty().
root_frame_ = new FrameState(immutable_state_, 1);
root_frame_->frame_id = 0; // must be 0
root_frame_->InitializeFrameInfo(root_frame_->frame_name);
root_frame_->InitializeFrameInfo(immutable_state_.get_root_frame_info());

// Initialize iteration 0.
root_frame_->SetIteration(
0, new PropagatorState::IterationState(0, root_frame_->pending_counts,
root_frame_->total_input_tensors));

outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
outstanding_frames_.emplace(root_frame_->frame_id, root_frame_);
}

PropagatorState::~PropagatorState() {
@@ -224,16 +226,16 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame,
const NodeItem& node_item,
FrameState** child) {
// Get the child frame name.
AttrSlice attrs(node_item.kernel->def());
const string& enter_name = GetNodeAttrString(attrs, "frame_name");
DCHECK(!enter_name.empty()) << "Could not find \"frame_name\" attr in node "
<< node_item.kernel->name();
const string child_name = strings::StrCat(
frame->frame_name, ";", iter_state->iter_num, ";", enter_name);
const ImmutableExecutorState::FrameInfo& frame_info =
immutable_state_.get_enter_frame_info(node_item);

const uint64 child_id = Hash64Combine(
frame->frame_id,
Hash64Combine(iter_state->iter_num, Hash64(frame_info.name)));

{
mutex_lock executor_lock(mu_);
auto it = outstanding_frames_.find(child_name);
tf_shared_lock executor_lock(mu_);
auto it = outstanding_frames_.find(child_id);
if (it != outstanding_frames_.end()) {
*child = it->second;
return;
@@ -242,20 +244,18 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame,

// Need to create a new frame instance.
// Note that this new frame instance is created without any locks.
if (vlog_) VLOG(2) << "Create frame: " << child_name;
if (vlog_) {
const string child_name = strings::StrCat(
frame->frame_name, ";", iter_state->iter_num, ";", frame_info.name);
VLOG(2) << "Create frame: " << child_name << " id: " << child_id;
}

int parallel_iters;
bool found_parallel_iters =
TryGetNodeAttr(attrs, "parallel_iterations", &parallel_iters);
DCHECK(found_parallel_iters)
<< "Could not find \"parallel_iterations\" attr in node "
<< node_item.kernel->name();
FrameState* temp = new FrameState(immutable_state_, parallel_iters);
temp->frame_name = child_name;
temp->frame_id = Hash64(child_name);
FrameState* temp =
new FrameState(immutable_state_, frame_info.parallel_iterations);
temp->frame_id = child_id;
temp->parent_frame = frame;
temp->parent_iter = iter_state;
temp->InitializeFrameInfo(enter_name);
temp->InitializeFrameInfo(frame_info);

// Initialize iteration 0.
{
@@ -266,13 +266,13 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame,

{
mutex_lock executor_lock(mu_);
auto it = outstanding_frames_.find(child_name);
auto it = outstanding_frames_.find(child_id);
if (it != outstanding_frames_.end()) {
*child = it->second;
} else {
mutex_lock frame_lock(frame->mu);
iter_state->outstanding_frame_count++;
outstanding_frames_[child_name] = temp;
outstanding_frames_[child_id] = temp;
*child = temp;
temp = nullptr;
}
@@ -349,11 +349,10 @@ void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
}

// Delete the frame.
const string& frame_name = frame->frame_name;
if (vlog_) VLOG(2) << "Delete frame " << frame_name;
if (vlog_) VLOG(2) << "Delete frame " << frame->frame_id;
{
mutex_lock executor_lock(mu_);
outstanding_frames_.erase(frame_name);
outstanding_frames_.erase(frame->frame_id);
}
delete frame;
}
@@ -655,14 +654,11 @@ bool PropagatorState::FrameState::CleanupIterations(IterationState* iter_state,
}

void PropagatorState::FrameState::InitializeFrameInfo(
const string& enter_name) {
const ImmutableExecutorState::FrameInfo* finfo =
immutable_state.get_frame_info(enter_name);
DCHECK_NE(finfo, nullptr);
pending_counts = finfo->pending_counts.get();
total_input_tensors = finfo->total_inputs;
num_pending_inputs = finfo->input_count;
nodes = finfo->nodes.get();
const ImmutableExecutorState::FrameInfo& finfo) {
pending_counts = finfo.pending_counts.get();
total_input_tensors = finfo.total_inputs;
num_pending_inputs = finfo.input_count;
nodes = finfo.nodes.get();
}

void PropagatorState::FrameState::SetIteration(int64 iter,

@@ -279,7 +279,7 @@ class PropagatorState {
// during structured traversal: parent_frame->mu < mu.
mutex mu;

void InitializeFrameInfo(const string& enter_name);
void InitializeFrameInfo(const ImmutableExecutorState::FrameInfo& finfo);

inline IterationState* GetIteration(int64 iter)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
@@ -447,12 +447,13 @@ class PropagatorState {
// The root frame in which the execution of this step is started.
FrameState* root_frame_;

// Mapping from frame name to outstanding frames. A new frame is created
// Mapping from frame ID to outstanding frames. A new frame is created
// at some iteration of an active frame. So the unique key for the new
// child frame is composed of the name of the parent frame, the iteration
// child frame is a hash composed of the ID of the parent frame, the iteration
// number at which the parent frame is creating the new frame, and the
// name of the new frame from nodedef.
gtl::FlatMap<string, FrameState*> outstanding_frames_ TF_GUARDED_BY(mu_);
absl::flat_hash_map<uint64, FrameState*> outstanding_frames_
TF_GUARDED_BY(mu_);

TF_DISALLOW_COPY_AND_ASSIGN(PropagatorState);
};
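Reviewer note: the FindOrCreateChildFrame rewrite above combines two ideas: the concatenated frame-name string is replaced by a 64-bit hashed key, and the common lookup takes only a shared lock, falling back to a double-checked insert under the exclusive lock. A condensed sketch of that protocol with hypothetical standalone names (Frame, frames_, mu_), not the executor's actual members:

#include <cstdint>
#include <unordered_map>
#include "tensorflow/core/platform/mutex.h"

struct Frame { /* payload */ };

tensorflow::mutex mu_;
std::unordered_map<uint64_t, Frame*> frames_;  // Guarded by mu_.

Frame* FindOrCreate(uint64_t key) {
  {
    tensorflow::tf_shared_lock l(mu_);  // Fast path: readers share the lock.
    auto it = frames_.find(key);
    if (it != frames_.end()) return it->second;
  }
  Frame* fresh = new Frame;             // Constructed without holding mu_.
  tensorflow::mutex_lock l(mu_);        // Slow path: re-check, then publish.
  auto result = frames_.emplace(key, fresh);
  if (!result.second) delete fresh;     // Another thread won the race.
  return result.first->second;
}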

@@ -259,7 +259,7 @@ class GpuSparse {
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-coo2csr.
Status Coo2csr(const int* cooRowInd, int nnz, int m, int* csrRowPtr) const;

#if CUDA_VERSION < 10020
#if (GOOGLE_CUDA && (CUDA_VERSION < 10020)) || TENSORFLOW_USE_ROCM
// Sparse-dense matrix multiplication C = alpha * op(A) * op(B) + beta * C,
// where A is a sparse matrix in CSR format, B and C are dense tall
// matrices. This routine allows transposition of matrix B, which
@@ -311,7 +311,7 @@ class GpuSparse {
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmv_mergepath
//
// **NOTE** This is an in-place operation for data in y.
#if CUDA_VERSION < 10020
#if (GOOGLE_CUDA && (CUDA_VERSION < 10020)) || TENSORFLOW_USE_ROCM
template <typename Scalar>
Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
const Scalar* alpha_host, const gpusparseMatDescr_t descrA,
@@ -366,7 +366,7 @@ class GpuSparse {
Scalar* csrSortedValC, int* csrSortedRowPtrC,
int* csrSortedColIndC, void* workspace);

#if CUDA_VERSION >= 10000
#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
// Computes sparse-sparse matrix multiplication of matrices
// stored in CSR format. This is part zero: calculate required workspace
// size.
@@ -383,7 +383,7 @@ class GpuSparse {
// output. csrSortedRowPtrC must be preallocated on device with
// m + 1 entries. See:
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
#if CUDA_VERSION < 10000
#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
Status CsrgemmNnz(gpusparseOperation_t transA, gpusparseOperation_t transB,
int m, int k, int n, const gpusparseMatDescr_t descrA,
int nnzA, const int* csrSortedRowPtrA,
@@ -408,7 +408,7 @@ class GpuSparse {
// addition. csrValC and csrColIndC must be allocated on the device
// with nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz). See:
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
#if CUDA_VERSION < 10000
#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
template <typename Scalar>
Status Csrgemm(gpusparseOperation_t transA, gpusparseOperation_t transB,
int m, int k, int n, const gpusparseMatDescr_t descrA,

@@ -435,8 +435,10 @@ class TensorListConcat : public OpKernel {
for (int i = 0; i < tensor_list->tensors().size(); i++) {
const Tensor& element_tensor = tensor_list->tensors()[i];
if (element_tensor.dtype() != DT_INVALID) {
inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
element_tensor.shaped<T, 2>({1, element_tensor.NumElements()})));
if (element_tensor.NumElements() > 0) {
inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
element_tensor.shaped<T, 2>({1, element_tensor.NumElements()})));
}
} else {
AllocatorAttributes attr;
if (element_dtype_ == DT_VARIANT) {

@@ -728,12 +728,14 @@ namespace {
template <typename T>
struct GPUDataType;

// GPUDataType templates are currently not instantiated in the ROCm flow
// So leaving out the #elif TENSORFLOW_USE_ROCM blocks for now
// hipblas library is not (yet) being pulled in via rocm_configure.bzl
// so cannot reference types from hipblas headers here
template <>
struct GPUDataType<Eigen::half> {
#if GOOGLE_CUDA
static constexpr cudaDataType_t type = CUDA_R_16F;
#elif TENSORFLOW_USE_ROCM
static constexpr hipblasDataType_t type = HIPBLAS_R_16F;
#endif
};

@@ -741,8 +743,6 @@ template <>
struct GPUDataType<float> {
#if GOOGLE_CUDA
static constexpr cudaDataType_t type = CUDA_R_32F;
#elif TENSORFLOW_USE_ROCM
static constexpr hipblasDataType_t type = HIPBLAS_R_32F;
#endif
};

@@ -750,8 +750,6 @@ template <>
struct GPUDataType<std::complex<float>> {
#if GOOGLE_CUDA
static constexpr cudaDataType_t type = CUDA_C_32F;
#elif TENSORFLOW_USE_ROCM
static constexpr hipblasDataType_t type = HIPBLAS_C_32F;
#endif
};

@@ -759,8 +757,6 @@ template <>
struct GPUDataType<double> {
#if GOOGLE_CUDA
static constexpr cudaDataType_t type = CUDA_R_64F;
#elif TENSORFLOW_USE_ROCM
static constexpr hipblasDataType_t type = HIPBLAS_R_64F;
#endif
};

@@ -768,8 +764,6 @@ template <>
struct GPUDataType<std::complex<double>> {
#if GOOGLE_CUDA
static constexpr cudaDataType_t type = CUDA_C_64F;
#elif TENSORFLOW_USE_ROCM
static constexpr hipblasDataType_t type = HIPBLAS_C_64F;
#endif
};
|
||||
|
||||
@ -957,7 +951,7 @@ class CSRSparseMatrixMatVec<GPUDevice, T> {
|
||||
const int n = a.dense_shape_host(1);
|
||||
const int nnz = a.values.size();
|
||||
DCHECK_EQ(nnz, a.col_ind.size());
|
||||
#if CUDA_VERSION >= 10020
|
||||
#if GOOGLE_CUDA && (CUDA_VERSION >= 10020)
|
||||
TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha,
|
||||
a.values.data(), a.row_ptr.data(),
|
||||
a.col_ind.data(), x, &beta, y));
|
||||
|
@ -417,7 +417,7 @@ class CSRSparseMatMulGPUOp : public OpKernel {
|
||||
}
|
||||
auto b_input_dense_shape = b_input_matrix->dense_shape().vec<int64>();
|
||||
|
||||
#if CUDA_VERSION >= 10000
|
||||
#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
|
||||
size_t maxWorkspaceSize = 0;
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
// Calculate maximum workspace size over batch.
|
||||
@ -558,7 +558,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
|
||||
initialized_(false),
|
||||
transpose_a_(transpose_a),
|
||||
adjoint_a_(adjoint_a),
|
||||
#if CUDA_VERSION < 10000
|
||||
#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
|
||||
transpose_b_(transpose_b) {
|
||||
#else
|
||||
transpose_b_(transpose_b),
|
||||
@ -573,7 +573,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
|
||||
: GPUSPARSE(OPERATION_NON_TRANSPOSE);
|
||||
}
|
||||
|
||||
#if CUDA_VERSION >= 10000
|
||||
#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
|
||||
~CSRSparseSparseMatrixMatMul() {
|
||||
if (initialized_) {
|
||||
cusparseDestroyCsrgemm2Info(info_);
|
||||
@ -591,7 +591,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
|
||||
TF_RETURN_IF_ERROR(descrA_.Initialize());
|
||||
TF_RETURN_IF_ERROR(descrB_.Initialize());
|
||||
TF_RETURN_IF_ERROR(descrC_.Initialize());
|
||||
#if CUDA_VERSION >= 10000
|
||||
#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsrgemm2Info(&info_));
|
||||
#endif
|
||||
initialized_ = true;
|
||||
@ -600,6 +600,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
|
||||
|
||||
Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
|
||||
const ConstCSRComponent<T>& b, size_t* bufferSize) {
|
||||
#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
|
||||
DCHECK(initialized_);
|
||||
const int m =
|
||||
a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 1 : 2));
|
||||
@ -621,6 +622,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
|
||||
m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
|
||||
descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(), info_,
|
||||
bufferSize));
|
||||
#endif
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
@ -650,7 +652,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
|
||||
|
||||
*output_nnz = -1;
|
||||
|
||||
#if CUDA_VERSION < 10000
|
||||
#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
|
||||
TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmNnz(
|
||||
transA_, transB_, m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(),
|
||||
a.col_ind.data(), descrB_.descr(), nnzB, b.row_ptr.data(),
|
||||
@ -693,7 +695,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
|
||||
b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1));
|
||||
DCHECK_EQ(n, c->dense_shape_host(c->dense_shape_host.size() - 1));
|
||||
|
||||
#if CUDA_VERSION < 10000
|
||||
#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
|
||||
TF_RETURN_IF_ERROR(cuda_sparse_.Csrgemm(
|
||||
transA_, transB_, m, k, n, descrA_.descr(), nnzA, a.values.data(),
|
||||
a.row_ptr.data(), a.col_ind.data(), descrB_.descr(), nnzB,
|
||||
@ -732,7 +734,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
|
||||
GpuSparseMatrixDescriptor descrC_;
|
||||
gpusparseOperation_t transA_;
|
||||
gpusparseOperation_t transB_;
|
||||
#if CUDA_VERSION >= 10000
|
||||
#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
|
||||
csrgemm2Info_t info_;
|
||||
#endif
|
||||
};
|
||||
|
@ -108,7 +108,7 @@ limitations under the License.

#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
#define TF_GRAPH_DEF_VERSION 394  // Updated: 2020/5/7
#define TF_GRAPH_DEF_VERSION 395  // Updated: 2020/5/8

// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
//
@ -1950,8 +1950,8 @@ func GatherV2BatchDims(value int64) GatherV2Attr {
// Gather slices from `params` axis `axis` according to `indices`.
//
// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D).
// Produces an output tensor with shape `params.shape[:axis] + indices.shape +
// params.shape[axis + 1:]` where:
// Produces an output tensor with shape `params.shape[:axis] +
// indices.shape[batch_dims:] + params.shape[axis + 1:]` where:
//
// ```python
// # Scalar indices (output is rank(params) - 1).
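The corrected shape rule can be checked with a small Python helper (illustrative only; the helper name is ours):

def gather_output_shape(params_shape, indices_shape, axis, batch_dims=0):
  # params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:]
  return (list(params_shape[:axis]) + list(indices_shape[batch_dims:]) +
          list(params_shape[axis + 1:]))

# With batch_dims=1: params [4, 5, 6], indices [4, 2], axis=1 -> [4, 2, 6].
assert gather_output_shape([4, 5, 6], [4, 2], axis=1, batch_dims=1) == [4, 2, 6]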
@ -702,6 +702,7 @@ def gen_model_coverage_test(src, model_name, data, failure_type, tags, size = "m
            "//tensorflow/lite/python:lite",
            "//tensorflow/python:client_testlib",
        ] + flex_dep(target_op_sets),
        timeout = "long",
    )

def if_tflite_experimental_runtime(if_eager, if_non_eager, if_none = []):
@ -2421,6 +2421,40 @@ class TransformLandmarksOperationParser : public TFLiteOperationParser {
 private:
};

class TransformLandmarksV2OperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
                                       /*runtime_inputs=*/2, /*outputs=*/1));
    return absl::OkStatus();
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    RETURN_IF_ERROR(reader->AddInput(node, 0));  // data
    RETURN_IF_ERROR(reader->AddInput(node, 1));  // bbox
    RETURN_IF_ERROR(reader->AddOutputs(node));
    std::string op_name = "transform_landmarks_v2";
    node->operation.type = op_name;
    BHWC output_shape;
    RETURN_IF_ERROR(
        ParseCustomAttributes(op_name, tflite_node->custom_initial_data,
                              tflite_node->custom_initial_data_size,
                              &(node->operation.attributes), &output_shape));

    auto output_value = graph->FindOutputs(node->id)[0];

    output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape;
    return absl::OkStatus();
  }

 private:
};

class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
@ -2672,6 +2706,9 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
      if (custom_name == "TransformLandmarks") {
        return std::make_unique<TransformLandmarksOperationParser>();
      }
      if (custom_name == "TransformLandmarksV2") {
        return std::make_unique<TransformLandmarksV2OperationParser>();
      }
      if (custom_name == "Landmarks2TransformMatrix") {
        return std::make_unique<Landmarks2TransformMatrixOperationParser>();
      }
@ -562,6 +562,62 @@ absl::Status CalculateOutputShape(const std::vector<BHWC>& input,
  return absl::OkStatus();
}

absl::Status CalculateOutputShape(const std::vector<BHWDC>& input,
                                  const ConcatAttributes& attr,
                                  BHWDC* output_shape) {
  BHWDC new_shape = input[0];
  switch (attr.axis) {
    case Axis::CHANNELS:
      for (int i = 1; i < input.size(); ++i) {
        if (input[i].h != new_shape.h || input[i].w != new_shape.w ||
            input[i].d != new_shape.d) {
          return absl::InvalidArgumentError(
              "Height, Width and Depth must be the same when concatenating "
              "by channels axis");
        }
        new_shape.c += input[i].c;
      }
      break;
    case Axis::HEIGHT:
      for (int i = 1; i < input.size(); ++i) {
        if (input[i].w != new_shape.w || input[i].c != new_shape.c ||
            input[i].d != new_shape.d) {
          return absl::InvalidArgumentError(
              "Width, Depth and Channels must be the same when concatenating "
              "by height axis");
        }
        new_shape.h += input[i].h;
      }
      break;
    case Axis::WIDTH:
      for (int i = 1; i < input.size(); ++i) {
        if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
            input[i].d != new_shape.d) {
          return absl::InvalidArgumentError(
              "Height, Depth and Channels must be the same when concatenating "
              "by width axis");
        }
        new_shape.w += input[i].w;
      }
      break;
    case Axis::DEPTH:
      for (int i = 1; i < input.size(); ++i) {
        if (input[i].w != new_shape.w || input[i].h != new_shape.h ||
            input[i].c != new_shape.c) {
          return absl::InvalidArgumentError(
              "Width, Height and Channels must be the same when concatenating "
              "by depth axis");
        }
        new_shape.d += input[i].d;
      }
      break;
    default:
      return absl::InvalidArgumentError("Invalid axis");
  }
  *output_shape = new_shape;
  return absl::OkStatus();
}

Padding2D CalculateSamePadding(const BHWC& input,
                               const Convolution2DAttributes& attr) {
  return MakeSamePadding(input, attr);
@ -206,6 +206,12 @@ absl::Status CalculateOutputShape(const std::vector<BHWC>& input,
                                  const ConcatAttributes& attr,
                                  BHWC* output_shape);

// @return shape of a tensor after Concat operation is applied to the given
// input.
absl::Status CalculateOutputShape(const std::vector<BHWDC>& input,
                                  const ConcatAttributes& attr,
                                  BHWDC* output_shape);

// @return padding for pooling operation to make sure output keeps the same
// shape as the given input.
Padding2D CalculateSamePadding(const BHWC& input,
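The rule implemented above, restated as a small Python sketch (our own illustration): every dimension other than the concat axis must match across inputs, and the concat-axis sizes are summed.

def concat_output_shape(shapes, axis):
  out = list(shapes[0])
  for s in shapes[1:]:
    for d in range(len(out)):
      if d != axis and s[d] != out[d]:
        raise ValueError('non-concat dimensions must match')
    out[axis] += s[axis]
  return out

# Two BHWDC tensors concatenated on the channels axis (index 4):
assert concat_output_shape([[1, 2, 2, 2, 3], [1, 2, 2, 2, 5]], axis=4) == [1, 2, 2, 2, 8]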
@ -73,6 +73,7 @@ cc_library(
cc_test(
    name = "add_test",
    srcs = ["add_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + [
        "notap",
        "tflite_not_portable_ios",
@ -102,6 +103,7 @@ cc_library(
cc_test(
    name = "concat_test",
    srcs = ["concat_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + [
        "notap",
        "tflite_not_portable_ios",
@ -136,6 +138,7 @@ cc_library(
cc_test(
    name = "conv_test",
    srcs = ["conv_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + [
        "notap",
        "tflite_not_portable_ios",
@ -176,6 +179,7 @@ cc_library(
cc_test(
    name = "depthwise_conv_test",
    srcs = ["depthwise_conv_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + [
        "notap",
        "tflite_not_portable_ios",
@ -205,6 +209,7 @@ cc_library(
cc_test(
    name = "elementwise_test",
    srcs = ["elementwise_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + [
        "notap",
        "tflite_not_portable_ios",
@ -235,6 +240,7 @@ cc_library(
cc_test(
    name = "fully_connected_test",
    srcs = ["fully_connected_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + [
        "notap",
        "tflite_not_portable_ios",
@ -263,6 +269,7 @@ cc_library(
cc_test(
    name = "lstm_test",
    srcs = ["lstm_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + [
        "notap",
        "tflite_not_portable_ios",
@ -292,6 +299,7 @@ cc_library(
cc_test(
    name = "max_unpooling_test",
    srcs = ["max_unpooling_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + [
        "notap",
        "tflite_not_portable_ios",
@ -322,6 +330,7 @@ cc_library(
cc_test(
    name = "mean_test",
    srcs = ["mean_test.cc"],
    linkstatic = True,
    tags = [
        "notap",
        "tflite_not_portable_ios",
@ -351,6 +360,7 @@ cc_library(
cc_test(
    name = "mul_test",
    srcs = ["mul_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + [
        "notap",
        "tflite_not_portable_ios",
@ -380,6 +390,7 @@ cc_library(
cc_test(
    name = "pad_test",
    srcs = ["pad_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + [
        "notap",
        "tflite_not_portable_ios",
@ -409,6 +420,7 @@ cc_library(
cc_test(
    name = "pooling_test",
    srcs = ["pooling_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + [
        "notap",
        "tflite_not_portable_ios",
@ -440,6 +452,7 @@ cc_library(
cc_test(
    name = "prelu_test",
    srcs = ["prelu_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + [
        "notap",
        "tflite_not_portable_ios",
@ -471,6 +484,7 @@ cc_library(
cc_test(
    name = "quantize_and_dequantize_test",
    srcs = ["quantize_and_dequantize_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + [
        "notap",
        "tflite_not_portable_ios",
@ -501,6 +515,7 @@ cc_library(
cc_test(
    name = "relu_test",
    srcs = ["relu_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + [
        "notap",
        "tflite_not_portable_ios",
@ -529,6 +544,7 @@ cc_library(
cc_test(
    name = "reshape_test",
    srcs = ["reshape_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + [
        "notap",
        "tflite_not_portable_ios",
@ -558,6 +574,7 @@ cc_library(
cc_test(
    name = "slice_test",
    srcs = ["slice_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + [
        "notap",
        "tflite_not_portable_ios",
@ -589,6 +606,7 @@ cc_library(
cc_test(
    name = "softmax_test",
    srcs = ["softmax_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + [
        "notap",
        "tflite_not_portable_ios",
@ -618,6 +636,7 @@ cc_library(
cc_test(
    name = "space_to_depth_test",
    srcs = ["space_to_depth_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + [
        "notap",
        "tflite_not_portable_ios",
@ -679,6 +698,7 @@ cc_library(
cc_test(
    name = "transpose_conv_test",
    srcs = ["transpose_conv_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + [
        "notap",
        "tflite_not_portable_ios",
@ -708,6 +728,7 @@ cc_library(
cc_test(
    name = "resize_test",
    srcs = ["resize_test.cc"],
    linkstatic = True,
    tags = tf_gpu_tests_tags() + [
        "notap",
        "tflite_not_portable_ios",
@ -33,7 +33,7 @@ Pod::Spec.new do |s|
    'HEADER_SEARCH_PATHS' =>
        '"${PODS_TARGET_SRCROOT}" ' +
        '"${PODS_TARGET_SRCROOT}/' + objc_dir + 'apis"',
    'VALID_ARCHS' => 'x86_64 armv7 arm64',
    'VALID_ARCHS' => 'i386 x86_64 armv7 arm64',
  }

  s.test_spec 'Tests' do |ts|

@ -33,7 +33,7 @@ Pod::Spec.new do |s|
    'HEADER_SEARCH_PATHS' =>
        '"${PODS_TARGET_SRCROOT}" ' +
        '"${PODS_TARGET_SRCROOT}/' + objc_dir + 'apis"',
    'VALID_ARCHS' => 'x86_64 armv7 arm64',
    'VALID_ARCHS' => 'i386 x86_64 armv7 arm64',
  }

  s.test_spec 'Tests' do |ts|

@ -33,7 +33,7 @@ Pod::Spec.new do |s|
    'HEADER_SEARCH_PATHS' =>
        '"${PODS_TARGET_SRCROOT}" ' +
        '"${PODS_TARGET_SRCROOT}/' + objc_dir + 'apis"',
    'VALID_ARCHS' => 'x86_64 armv7 arm64',
    'VALID_ARCHS' => 'i386 x86_64 armv7 arm64',
  }

  s.test_spec 'Tests' do |ts|
@ -101,7 +101,7 @@
    },
    "outputs": [],
    "source": [
     "!pip install git+git://github.com/tensorflow/examples.git#egg=tensorflow-examples[model_maker]"
     "!pip install git+https://github.com/tensorflow/examples.git#egg=tensorflow-examples[model_maker]"
    ]
   },
   {

@ -101,7 +101,7 @@
    },
    "outputs": [],
    "source": [
     "!pip install git+git://github.com/tensorflow/examples.git#egg=tensorflow-examples[model_maker]"
     "!pip install git+https://github.com/tensorflow/examples.git#egg=tensorflow-examples[model_maker]"
    ]
   },
   {
@ -157,6 +157,7 @@ py_test(
    name = "lite_v2_test",
    srcs = ["lite_v2_test.py"],
    python_version = "PY3",
    shard_count = 4,
    srcs_version = "PY2AND3",
    tags = [
        "no_windows",
@ -20,6 +20,8 @@ from __future__ import division
from __future__ import print_function

import enum
import shutil
import tempfile
import warnings

from absl import logging
@ -413,7 +415,7 @@ class TFLiteConverterBase(object):
class TFLiteConverterBaseV2(TFLiteConverterBase):
  """Converter subclass to share functionality between V2 converters."""

  def _convert(self, graph_def, input_tensors, output_tensors):
  def convert(self, graph_def, input_tensors, output_tensors):
    """Converts a TensorFlow GraphDef based on instance variables.

    Args:
@ -570,7 +572,115 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
        graph.get_tensor_by_name(signature_def.outputs[key].name)
        for key in signature_def.outputs
    ]
    return self._convert(meta_graph.graph_def, input_tensors, output_tensors)
    return super(TFLiteSavedModelConverterV2,
                 self).convert(meta_graph.graph_def, input_tensors,
                               output_tensors)


class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
  """Converts the given Keras model into TensorFlow Lite model."""

  def __init__(self, keras_model, trackable_obj=None):
    """Constructor for TFLiteConverter.

    Args:
      keras_model: tf.Keras.Model.
      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
        reference to this object needs to be maintained so that Variables do not
        get garbage collected since functions have a weak reference to
        Variables. This is only required when the tf.AutoTrackable object is not
        maintained by the user (e.g. `from_saved_model`).
    """
    super(TFLiteKerasModelConverterV2, self).__init__()
    self._keras_model = keras_model
    self._trackable_obj = trackable_obj

  def convert(self):
    """Converts a keras model based on instance variables.

    Returns:
      The converted data in serialized format.

    Raises:
      ValueError:
        Multiple concrete functions are specified.
        Input shape is not specified.
        Invalid quantization parameters.
    """
    temp_dir = tempfile.mkdtemp()
    try:
      self._keras_model.save(temp_dir, save_format="tf")
      self.saved_model_dir = temp_dir
      self._saved_model_tags = set([_tag_constants.SERVING])
      self._saved_model_exported_names = [
          _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
      ]
      self._parse_saved_model_args()
      if self.saved_model_dir:
        graph = _ops.Graph()
        saved_model = _loader_impl.SavedModelLoader(self.saved_model_dir)
        saved_model.load_graph(graph, tags=self._saved_model_tags)
        meta_graph = saved_model.get_meta_graph_def_from_tags(
            self._saved_model_tags)
        signature_def = meta_graph.signature_def[
            _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
        input_tensors = [
            graph.get_tensor_by_name(signature_def.inputs[key].name)
            for key in signature_def.inputs
        ]
        output_tensors = [
            graph.get_tensor_by_name(signature_def.outputs[key].name)
            for key in signature_def.outputs
        ]
        self._trackable_obj = _load(self.saved_model_dir,
                                    self._saved_model_tags)
        return super(TFLiteKerasModelConverterV2,
                     self).convert(meta_graph.graph_def, input_tensors,
                                   output_tensors)
    finally:
      shutil.rmtree(temp_dir, True)

    input_signature = None
    # If the model's call is not a `tf.function`, then we need to first get its
    # input signature from `model_input_signature` method. We can't directly
    # call `trace_model_call` because otherwise the batch dimension is set
    # to None.
    # Once we have better support for dynamic shapes, we can remove this.
    if not isinstance(self._keras_model.call, _def_function.Function):
      # Pass `keep_original_batch_size=True` will ensure that we get an input
      # signature including the batch dimension specified by the user.
      input_signature = _saving_utils.model_input_signature(
          self._keras_model, keep_original_batch_size=True)

    func = _saving_utils.trace_model_call(self._keras_model, input_signature)
    concrete_func = func.get_concrete_function()
    self._funcs = [concrete_func]

    frozen_func, graph_def = (
        _convert_to_constants.convert_variables_to_constants_v2_as_graph(
            self._funcs[0], lower_control_flow=False))

    input_tensors = [
        tensor for tensor in frozen_func.inputs
        if tensor.dtype != _dtypes.resource
    ]
    output_tensors = frozen_func.outputs

    # Run a Grappler pass.
    grappler_config = self._grappler_config()
    # Skip running grappler when there are no optimizers to run. If not,
    # grappler will run with the default optimizer set and it will lead to
    # causing an unexpected behavior.
    if grappler_config.graph_options.rewrite_options.optimizers:
      graph_def = _run_graph_optimizations(
          graph_def,
          input_tensors,
          output_tensors,
          config=grappler_config,
          graph=frozen_func.graph)

    return super(TFLiteKerasModelConverterV2,
                 self).convert(graph_def, input_tensors, output_tensors)


class TFLiteFrozenGraphConverterV2(TFLiteConverterBaseV2):
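With this change, `from_keras_model` routes through the new Keras converter above, which serializes the model to a temporary SavedModel when possible. Typical usage is unchanged; a minimal sketch (the toy model is ours):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[3])])

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
open('model.tflite', 'wb').write(tflite_model)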
@ -638,7 +748,8 @@ class TFLiteFrozenGraphConverterV2(TFLiteConverterBaseV2):
          config=grappler_config,
          graph=frozen_func.graph)

    return self._convert(graph_def, input_tensors, output_tensors)
    return super(TFLiteFrozenGraphConverterV2,
                 self).convert(graph_def, input_tensors, output_tensors)


@_tf_export("lite.TFLiteConverter", v1=[])
@ -790,21 +901,7 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
    Returns:
      TFLiteConverter object.
    """
    input_signature = None
    # If the model's call is not a `tf.function`, then we need to first get its
    # input signature from `model_input_signature` method. We can't directly
    # call `trace_model_call` because otherwise the batch dimension is set
    # to None.
    # Once we have better support for dynamic shapes, we can remove this.
    if not isinstance(model.call, _def_function.Function):
      # Pass `keep_original_batch_size=True` will ensure that we get an input
      # signature including the batch dimension specified by the user.
      input_signature = _saving_utils.model_input_signature(
          model, keep_original_batch_size=True)

    func = _saving_utils.trace_model_call(model, input_signature)
    concrete_func = func.get_concrete_function()
    return cls([concrete_func])
    return TFLiteKerasModelConverterV2(model)

  # pylint: disable=useless-super-delegation
  def convert(self):
@ -964,7 +1061,7 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
      raise ValueError("std_dev and mean must be defined when inference_type "
                       "or inference_input_type is QUANTIZED_UINT8 or INT8.")

  def _convert(self):
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
@ -1247,8 +1344,86 @@ class TFLiteSavedModelConverter(TFLiteConverterBaseV1):
    self._output_tensors = result[2]
    self._parse_saved_model_args()


class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
  """Converts the given Keras model file into TensorFlow Lite model."""

  def __init__(self,
               model_file,
               input_arrays=None,
               input_shapes=None,
               output_arrays=None,
               custom_objects=None):
    """Constructor for TFLiteConverter.

    Args:
      model_file: Full filepath of HDF5 file containing the tf.keras model.
      input_arrays: List of input tensors to freeze graph with. Uses input
        arrays from SignatureDef when none are provided. (default None)
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when input shapes is None (e.g., {"foo" :
        None}). (default None)
      output_arrays: List of output tensors to freeze graph with. Uses output
        arrays from SignatureDef when none are provided. (default None)
      custom_objects: Dict mapping names (strings) to custom classes or
        functions to be considered during model deserialization. (default None)

    Raises:
      ValueError: Invalid arguments.
    """
    super(TFLiteKerasModelConverter,
          self).__init__(experimental_debug_info_func=None)
    # Handles Keras when Eager mode is enabled.
    if context.executing_eagerly():
      if input_arrays or output_arrays:
        raise ValueError("`input_arrays` and `output_arrays` are unsupported "
                         "with Eager mode. If your model requires any of these "
                         "parameters, please use disable_eager_execution().")

      _keras.backend.set_learning_phase(False)
      keras_model = _keras.models.load_model(model_file, custom_objects)

      function = _saving_utils.trace_model_call(keras_model)
      concrete_func = function.get_concrete_function()

      frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
          concrete_func, lower_control_flow=False)
      _set_tensor_shapes(frozen_func.inputs, input_shapes)
      self._keras_model = keras_model
      self._graph_def = frozen_func.graph.as_graph_def()
      self._input_tensors = frozen_func.inputs
      self._output_tensors = frozen_func.outputs
      self._debug_info_func = _build_debug_info_func(frozen_func.graph)
      return

    # Handles Keras when Eager mode is disabled.
    _keras.backend.clear_session()
    _keras.backend.set_learning_phase(False)
    keras_model = _keras.models.load_model(model_file, custom_objects)
    sess = _keras.backend.get_session()

    # Get input and output tensors.
    if input_arrays:
      input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
    else:
      input_tensors = keras_model.inputs

    if output_arrays:
      output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays)
    else:
      output_tensors = keras_model.outputs
    _set_tensor_shapes(input_tensors, input_shapes)

    graph_def = _freeze_graph(sess, input_tensors, output_tensors)
    self._keras_model = keras_model
    self._graph_def = graph_def
    self._input_tensors = input_tensors
    self._output_tensors = output_tensors
    self._debug_info_func = _build_debug_info_func(sess.graph)

  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.
    """Converts a Keras model based on instance variables.

    Returns:
      The converted data in serialized format. Either a TFLite Flatbuffer or a
@ -1259,7 +1434,28 @@ class TFLiteSavedModelConverter(TFLiteConverterBaseV1):
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    return self._convert()
    temp_dir = tempfile.mkdtemp()
    try:
      self._keras_model.save(temp_dir, save_format="tf")
      tag_set = set([_tag_constants.SERVING])
      signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
      result = _freeze_saved_model(temp_dir, None, None, None, tag_set,
                                   signature_key)

      self.saved_model_dir = temp_dir
      self._saved_model_tags = tag_set
      self._saved_model_exported_names = [signature_key]
      self._parse_saved_model_args()
      if self.saved_model_dir:
        self._graph_def = result[0]
        self._input_tensors = result[1]
        self._output_tensors = result[2]
        self._debug_info_func = _build_debug_info_func(result[3])
        return super(TFLiteKerasModelConverter, self).convert()
    finally:
      shutil.rmtree(temp_dir, True)

    return super(TFLiteKerasModelConverter, self).convert()


class TFLiteFrozenGraphConverter(TFLiteConverterBaseV1):
@ -1308,20 +1504,6 @@ class TFLiteFrozenGraphConverter(TFLiteConverterBaseV1):
    self._input_arrays_with_shape = input_arrays_with_shape
    self._output_arrays = output_arrays

  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format. Either a TFLite Flatbuffer or a
      Graphviz graph depending on value in `output_format`.

    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    return self._convert()


@_tf_export(v1=["lite.TFLiteConverter"])
class TFLiteConverter(TFLiteFrozenGraphConverter):
@ -1649,53 +1831,8 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
    Returns:
      TFLiteConverter class.
    """
    # Handles Keras when Eager mode is enabled.
    if context.executing_eagerly():
      if input_arrays or output_arrays:
        raise ValueError("`input_arrays` and `output_arrays` are unsupported "
                         "with Eager mode. If your model requires any of these "
                         "parameters, please use disable_eager_execution().")

      _keras.backend.set_learning_phase(False)
      keras_model = _keras.models.load_model(model_file, custom_objects)

      function = _saving_utils.trace_model_call(keras_model)
      concrete_func = function.get_concrete_function()

      frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
          concrete_func, lower_control_flow=False)
      _set_tensor_shapes(frozen_func.inputs, input_shapes)
      return cls(
          frozen_func.graph.as_graph_def(),
          frozen_func.inputs,
          frozen_func.outputs,
          experimental_debug_info_func=_build_debug_info_func(
              frozen_func.graph))

    # Handles Keras when Eager mode is disabled.
    _keras.backend.clear_session()
    _keras.backend.set_learning_phase(False)
    keras_model = _keras.models.load_model(model_file, custom_objects)
    sess = _keras.backend.get_session()

    # Get input and output tensors.
    if input_arrays:
      input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
    else:
      input_tensors = keras_model.inputs

    if output_arrays:
      output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays)
    else:
      output_tensors = keras_model.outputs
    _set_tensor_shapes(input_tensors, input_shapes)

    graph_def = _freeze_graph(sess, input_tensors, output_tensors)
    return cls(
        graph_def,
        input_tensors,
        output_tensors,
        experimental_debug_info_func=_build_debug_info_func(sess.graph))
    return TFLiteKerasModelConverter(model_file, input_arrays, input_shapes,
                                     output_arrays, custom_objects)

  # pylint: disable=useless-super-delegation
  def convert(self):
@ -1895,7 +1895,7 @@ class FromKerasFile(TestModels, parameterized.TestCase):

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('dense_input', input_details[0]['name'])
    self.assertEndsWith(input_details[0]['name'], 'dense_input')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertTrue(([1, 3] == input_details[0]['shape']).all())
    self.assertEqual((0., 0.), input_details[0]['quantization'])
@ -1990,7 +1990,7 @@ class FromKerasFile(TestModels, parameterized.TestCase):

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('dense_input', input_details[0]['name'])
    self.assertEndsWith(input_details[0]['name'], 'dense_input')
    self.assertTrue(([2, 3] == input_details[0]['shape']).all())

  def testSequentialModelOutputArray(self):
@ -2109,12 +2109,12 @@ class FromKerasFile(TestModels, parameterized.TestCase):

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEqual('input_a', input_details[0]['name'])
    self.assertEndsWith(input_details[0]['name'], 'input_a')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertTrue(([1, 3] == input_details[0]['shape']).all())
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertEqual('input_b', input_details[1]['name'])
    self.assertEndsWith(input_details[1]['name'], 'input_b')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertTrue(([1, 3] == input_details[1]['shape']).all())
    self.assertEqual((0., 0.), input_details[1]['quantization'])
@ -2165,7 +2165,7 @@ class FromKerasFile(TestModels, parameterized.TestCase):

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('dense_input', input_details[0]['name'])
    self.assertEndsWith(input_details[0]['name'], 'dense_input')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertTrue(([1, 3] == input_details[0]['shape']).all())
    self.assertEqual((0., 0.), input_details[0]['quantization'])
@ -213,9 +213,11 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
        self.units = units

      def build(self, input_shape):
        self.w = self.add_weight(shape=(input_shape[-1], self.units),
                                 initializer='random_normal',
                                 trainable=True)
        self.w = self.add_weight(
            'weight',
            shape=(input_shape[-1], self.units),
            initializer='random_normal',
            trainable=True)
        self.min_var = self.add_weight(
            'min',
            initializer=tf.keras.initializers.Constant(-6.0),
@ -748,7 +750,10 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
    input_data = tf.constant(
        np.array(np.random.random_sample((1, 10, 10)), dtype=np.float32))
    rnn_obj = rnn_layer(units=10, input_shape=(10, 10))
    model = tf.keras.models.Sequential([rnn_obj])
    model = tf.keras.models.Sequential([
        tf.keras.layers.Input(batch_size=1, shape=(10, 10), name='input'),
        rnn_obj,
    ])

    # Convert model.
    converter = lite.TFLiteConverterV2.from_keras_model(model)
@ -787,6 +792,7 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
    input_data = tf.constant(
        np.array(np.random.random_sample((1, 10, 10)), dtype=np.float32))
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Input(batch_size=1, shape=(10, 10), name='input'))
    model.add(
        tf.keras.layers.Bidirectional(
            recurrent_v2.LSTM(units=10, return_sequences=True),
@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export
# This value changes every day with an automatic CL. It can be modified in code
# via `forward_compatibility_horizon()` or with the environment variable
# TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 5, 7)
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 5, 8)
_FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
_FORWARD_COMPATIBILITY_DATE_NUMBER = None
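For reference, a sketch of how the horizon is consumed (these are the real APIs of this module; the dates below are arbitrary):

from tensorflow.python.compat import compat

# True once the (possibly env-var-shifted) horizon has passed this date.
print(compat.forward_compatible(2020, 5, 7))

# Tests can temporarily move the horizon forward:
with compat.forward_compatibility_horizon(2020, 6, 8):
  assert compat.forward_compatible(2020, 6, 7)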
@ -41,9 +41,9 @@ class CsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

  def _setup_files(self, inputs, linebreak='\n', compression_type=None):
    filenames = []
    for i, ip in enumerate(inputs):
    for i, file_rows in enumerate(inputs):
      fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i)
      contents = linebreak.join(ip).encode('utf-8')
      contents = linebreak.join(file_rows).encode('utf-8')
      if compression_type is None:
        with open(fn, 'wb') as f:
          f.write(contents)
@ -580,6 +580,13 @@ class CsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
        inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
        record_defaults=record_defaults)

  def testCsvDataset_immutableParams(self):
    inputs = [['a,b,c', '1,2,3', '4,5,6']]
    filenames = self._setup_files(inputs)
    select_cols = ['a', 'c']
    _ = readers.make_csv_dataset(
        filenames, batch_size=1, select_columns=select_cols)
    self.assertAllEqual(select_cols, ['a', 'c'])

if __name__ == '__main__':
  test.main()
@ -183,24 +183,30 @@ def _get_sorted_col_indices(select_columns, column_names):
  """Transforms select_columns argument into sorted column indices."""
  names_to_indices = {n: i for i, n in enumerate(column_names)}
  num_cols = len(column_names)
  for i, v in enumerate(select_columns):

  results = []
  for v in select_columns:
    # If value is already an int, check if it's valid.
    if isinstance(v, int):
      if v < 0 or v >= num_cols:
        raise ValueError(
            "Column index %d specified in select_columns out of valid range." %
            v)
      continue
    if v not in names_to_indices:
      results.append(v)
    # Otherwise, check that it's a valid column name and convert to the
    # relevant column index.
    elif v not in names_to_indices:
      raise ValueError(
          "Value '%s' specified in select_columns not a valid column index or "
          "name." % v)
    select_columns[i] = names_to_indices[v]
    else:
      results.append(names_to_indices[v])

  # Sort and ensure there are no duplicates
  result = sorted(set(select_columns))
  if len(result) != len(select_columns):
  results = sorted(set(results))
  if len(results) != len(select_columns):
    raise ValueError("select_columns contains duplicate columns")
  return result
  return results


def _maybe_shuffle_and_repeat(
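A quick sketch of the fixed behavior (illustrative call to the private helper defined above): the result is built in a fresh list, so the caller's argument is no longer mutated in place.

column_names = ['a', 'b', 'c']
select_columns = ['c', 0]
indices = _get_sorted_col_indices(select_columns, column_names)
# indices == [0, 2], while select_columns is still ['c', 0].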
@ -1647,3 +1647,25 @@ py_test(
        "@absl_py//absl/testing:parameterized",
    ],
)

cuda_py_test(
    name = "strategy_common_test",
    srcs = ["strategy_common_test.py"],
    tags = [
        "multi_and_single_gpu",
        # TODO(b/155301154): Enable this test on multi-gpu guitar once multi process
        # runner can run on guitar.
        "noguitar",
    ],
    xla_enable_strict_auto_jit = True,
    deps = [
        ":combinations",
        ":reduce_util",
        ":strategy_combinations",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python/eager:def_function",
        "@absl_py//absl/testing:parameterized",
    ],
)
@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function

import collections
import threading

import enum
import six
@ -31,7 +32,7 @@ from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import executor
from tensorflow.python.framework import kernels
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
@ -948,6 +949,20 @@ class CollectiveAllReduce(CrossDeviceOps):
    self._collective_keys = (collective_keys or
                             cross_device_utils.CollectiveKeys())
    self._communication = communication
    # In a multi threaded eager program we need to ensure different groups of
    # collectives don't interleave each other, otherwise there will be deadlock.
    self._lock = threading.Lock()

    # Collective ops requires all devices to participate and is blocking. In
    # eager, we need one async executor for each device to be able to launch
    # them altogether. Note that async doesn't imply concurrency. Within an
    # async executor operations are still executed sequentially. In graph or
    # function building, the executors are not used.
    self._executors = []
    for _ in range(self._num_gpus_per_worker or 1):
      # If num_gpus_per_worker is zero, we assume there's only one device (CPU).
      self._executors.append(executor.new_executor(enable_async=True))

    super(CollectiveAllReduce, self).__init__()

  @property
@ -1059,33 +1074,26 @@ class CollectiveAllReduce(CrossDeviceOps):
        "num_workers = %d, communication_hint = %s, num_packs = %d" %
        (batch_size, self._num_workers, communication, len(packs)), 10)

    def batch_fn():
      """Wrapper function around batched all-reduce calls."""
      reduced_values = []
      for pack in packs:
        # By placing all CollectiveReduce ops in a pack under single name scope,
        # we ensure they will be picked up by the `ScopedAllocator` grappler
        # optimizer and packed into a single all-reduce.
        with ops.name_scope("allreduce"):
          for per_replica in pack:
            # Add control dependencies per device from the last gradients to the
            # current set, in order to serialize NCCL launches.
            if (communication == CollectiveCommunication.NCCL.value and
                reduced_values):
              control_inputs = [g for g in reduced_values[-1]]
            else:
              control_inputs = None
            reduced_values.append(
                cross_device_utils.build_collective_reduce(
                    per_replica.values, self._num_workers,
                    self._collective_keys, "Add", "Id", communication,
                    control_inputs))
      return reduced_values
    reduced_values = []
    for pack in packs:
      # By placing all CollectiveReduce ops in a pack under single name scope,
      # we ensure they will be picked up by the `ScopedAllocator` grappler
      # optimizer and packed into a single all-reduce.
      with self._lock, ops.name_scope("allreduce"):
        for per_replica in pack:
          # Add control dependencies per device from the last gradients to the
          # current set, in order to serialize NCCL launches.
          if (communication == CollectiveCommunication.NCCL.value and
              reduced_values):
            control_inputs = list(reduced_values[-1])
          else:
            control_inputs = None
          reduced_values.append(
              cross_device_utils.build_collective_reduce(
                  per_replica.values, self._num_workers,
                  self._collective_keys, "Add", "Id", communication,
                  control_inputs, executors=self._executors))

    if context.executing_eagerly():
      batch_fn = def_function.function(batch_fn)

    reduced_values = batch_fn()
    mirrored = []
    # Reverse the order of reduced value to recover the order in the input.
    for value in reversed(reduced_values):
@ -1134,6 +1142,12 @@ class CollectiveAllReduce(CrossDeviceOps):
      mirrored.append(value_lib.regroup(value, wrap_class=value_lib.Mirrored))
    return mirrored

  def __deepcopy__(self, memo):
    # distribute_coordinator deep-copies the strategy object, so
    # CollectiveAllReduce needs to support deep copy as well.
    return CollectiveAllReduce(self._num_workers, self._num_gpus_per_worker,
                               self._collective_keys, self._communication)


def choose_the_best(devices, session_config=None):
  """Find the best CrossDeviceOps locally given a `tf.compat.v1.ConfigProto`.
@ -19,6 +19,9 @@ from __future__ import division
from __future__ import print_function

import itertools
import os
import threading
import time

from absl.testing import parameterized
import numpy as np
@ -39,6 +42,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import kernels
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables

@ -835,6 +839,64 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
        variable_length=variable_length,
        local_mode=True)

  @combinations.generate(
      combinations.combine(
          required_gpus=2,
          mode="eager",
          communication=[
              CollectiveCommunication.NCCL, CollectiveCommunication.RING
          ]))
  def testEagerMultiThread(self, communication):
    collective, devices, _ = self._get_test_objects(
        None,
        None,
        num_gpus=2,
        communication=communication,
        use_strategy_object=False,
        local_mode=True)

    # We would like to simulate the following sequence:
    #   thread-0  device0                    device1
    #   thread-1           device0  device1
    # If the kernel launch sequence is as-is the program will deadlock since
    # NCCL requires the launch order to be same on each device.
    v0 = _make_per_replica([1.0 for _ in devices], devices)
    v1 = _make_per_replica([2.0 for _ in devices], devices)

    # Add a delay to collective_ops.all_reduce according to the input tensor's
    # index in `sequence`.
    sequence = [v0.values[0], v1.values[0], v1.values[1], v0.values[1]]
    all_reduce = collective_ops.all_reduce

    def delayed_all_reduce(input_tensor, *args, **kwargs):
      for idx, v in enumerate(sequence):
        if input_tensor is v:
          time.sleep(idx)
          break
      return all_reduce(input_tensor, *args, **kwargs)

    with test.mock.patch.object(collective_ops, "all_reduce",
                                delayed_all_reduce):
      # We only use NCCL for batch reduce with two or more values, so we use
      # two values here.

      def thread_fn():
        reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v0, v0),
                                                                     (v0, v0)])
        self.assertAllEqual(reduced[0].values, [2.0, 2.0])
        self.assertAllEqual(reduced[1].values, [2.0, 2.0])

      t = threading.Thread(target=thread_fn)
      t.start()
      reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v1, v1),
                                                                   (v1, v1)])
      self.assertAllEqual(reduced[0].values, [4.0, 4.0])
      self.assertAllEqual(reduced[1].values, [4.0, 4.0])
      t.join()


if __name__ == "__main__":
  # Set default inter op thread pool size to one to ensure we don't exhaust the
  # thread pool with the additional executors to run collectives in eager.
  os.environ["TF_NUM_INTEROP_THREADS"] = "1"
  test.main()
@ -337,10 +337,12 @@ def build_collective_reduce(input_tensors,
                            reduction_op='Add',
                            unary_op='Id',
                            communication_hint='AUTO',
                            control_inputs=None):
                            control_inputs=None,
                            executors=None):
  """Build a subgraph that does one full all-reduce, using the collective Op.

  This method must be called in graph mode or inside a tf.function.
  If called in eager mode, it's required to supply a list of async executors for
  each input Tensor.

  Args:
    input_tensors: tensors within a single worker graph that are to be reduced
@ -355,6 +357,7 @@ def build_collective_reduce(input_tensors,
      implementation.
    control_inputs: if not None, add control edges between control_inputs and
      (index-wise) corresponding collective_reduce tensors
    executors: a list of async executors. Required for eager execution.

  Returns:
    An array of final tensors, one per device, computed by the full reduction.
@ -362,9 +365,11 @@ def build_collective_reduce(input_tensors,
  Raises:
    ValueError: There must be at least two tensors over all the workers.
  """
  assert not context.executing_eagerly(), (
      'build_collective_reduce can only be called in graph mode or inside '
      'tf.function')
  if context.executing_eagerly():
    if (not executors or len(executors) != len(input_tensors) or
        not all(e.is_async() for e in executors)):
      raise ValueError(
          'collectives requires async executors for each device in eager mode')

  group_size = len(input_tensors) * num_workers
  if group_size < 2:
@ -375,15 +380,19 @@ def build_collective_reduce(input_tensors,

  out_tensors = []
  for idx, input_tensor in enumerate(input_tensors):
    with ops.device(input_tensor.device):
      with ops.control_dependencies(
          _control_input(input_tensors, control_inputs, idx)):
        out_tensor = collective_ops.all_reduce(input_tensor, group_size,
                                               group_key, instance_key,
                                               reduction_op, unary_op,
                                               subdiv_offsets,
                                               communication_hint)
      out_tensors.append(out_tensor)
    if context.executing_eagerly():
      executor_scope = context.executor_scope(executors[idx])
    else:
      executor_scope = ops.NullContextmanager()
    with executor_scope, \
         ops.device(input_tensor.device), \
         ops.control_dependencies(
             _control_input(input_tensors, control_inputs, idx)):
      out_tensor = collective_ops.all_reduce(input_tensor, group_size,
                                             group_key, instance_key,
                                             reduction_op, unary_op,
                                             subdiv_offsets, communication_hint)
    out_tensors.append(out_tensor)
  return out_tensors
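A minimal sketch of the new eager-mode contract (using only calls that appear in this change): one async executor per input tensor, passed via `executors=`.

from tensorflow.python.eager import executor

executors = [executor.new_executor(enable_async=True) for _ in input_tensors]
out_tensors = build_collective_reduce(
    input_tensors, num_workers, collective_keys, executors=executors)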
@ -1912,9 +1912,8 @@ class StrategyExtendedV2(object):

  def _reduce(self, reduce_op, value):
    # Default implementation until we have an implementation for each strategy.
    return self._local_results(
        self.reduce_to(reduce_op, value,
                       device_util.current() or "/device:CPU:0"))[0]
    dst = device_util.current() or self._default_device or "/device:CPU:0"
    return self._local_results(self.reduce_to(reduce_op, value, dst))[0]

  def reduce_to(self, reduce_op, value, destinations, experimental_hints=None):
    """Combine (via e.g. sum or mean) values across replicas.
tensorflow/python/distribute/strategy_common_test.py (new file, 65 lines)
@ -0,0 +1,65 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for common methods in strategy classes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class StrategyReduceTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          strategy=[strategy_combinations.multi_worker_mirrored_two_workers] +
          strategy_combinations.strategies_minus_tpu,
          mode=['eager']))
  def testSimpleReduce(self, strategy):

    def fn_eager():

      def replica_fn():
        return array_ops.ones((), dtypes.float32)

      per_replica_value = strategy.run(replica_fn)
      return strategy.reduce(
          reduce_util.ReduceOp.SUM, value=per_replica_value, axis=None)

    fn_graph = def_function.function(fn_eager)

    # Run reduce under the strategy scope to explicitly enter
    # strategy default_device scope.
    with strategy.scope():
      self.assertEqual(fn_eager().numpy(), 1.0 * strategy.num_replicas_in_sync)
      self.assertEqual(fn_graph().numpy(), 1.0 * strategy.num_replicas_in_sync)

    # Run reduce without a strategy scope to implicitly enter
    # strategy default_device scope.
    self.assertEqual(fn_eager().numpy(), 1.0 * strategy.num_replicas_in_sync)
    self.assertEqual(fn_graph().numpy(), 1.0 * strategy.num_replicas_in_sync)


if __name__ == '__main__':
  combinations.main()
@ -110,14 +110,12 @@ def run_benchmark(func, num_iters, execution_mode=None):
class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):

  def __init__(self):
    # TODO(b/153054118): Add tf.RandomUniform
    if not context.is_tfrt_enabled():
      # used for multiply benchmarks
      self._m_2 = random_ops.random_uniform([2])
    # used for multiply benchmarks
    self._m_2 = random_ops.random_uniform([2])

      # used for matmul benchmarks
      self._m_2_by_2 = random_ops.random_uniform((2, 2))
      self._m_100_by_784 = random_ops.random_uniform((100, 784))
    # used for matmul benchmarks
    self._m_2_by_2 = random_ops.random_uniform((2, 2))
    self._m_100_by_784 = random_ops.random_uniform((100, 784))

    self._num_iters_2_by_2 = 30000
    self._num_iters_100_by_784 = 30000
@ -319,17 +317,16 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
    func = lambda: math_ops.multiply(m, m)
    self._run(func, num_iters)

  @test_util.disable_tfrt("random ops not supported")
  @test_util.disable_tfrt("numpy() not supported")
  def benchmark_np_multiply(self):
    self._benchmark_np_multiply(self._m_2, 30000)

  @test_util.disable_tfrt("random ops not supported")
  def benchmark_tf_multiply_CPU(self):
    with context.device(CPU):
      m = self._m_2.cpu()
      self._benchmark_tf_multiply(m, 30000)

  @test_util.disable_tfrt("random ops not supported")
  @test_util.disable_tfrt("copy to GPU not supported")
  def benchmark_tf_multiply_GPU(self):
    if not context.num_gpus():
      return
@ -337,13 +334,12 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
      m = self._m_2.gpu()
      self._benchmark_tf_multiply(m, 30000)

  @test_util.disable_tfrt("random ops not supported")
  def benchmark_tf_multiply_op_CPU(self):
    with context.device(CPU):
      m = self._m_2.cpu()
      self._benchmark_tf_multiply_op(m, 30000)

  @test_util.disable_tfrt("random ops not supported")
  @test_util.disable_tfrt("copy to GPU not supported")
  def benchmark_tf_multiply_op_GPU(self):
    if not context.num_gpus():
      return
@ -351,7 +347,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
      m = self._m_2.gpu()
      self._benchmark_tf_multiply_op(m, 30000)

  @test_util.disable_tfrt("random ops not supported")
  def benchmark_tf_identity(self):
    m = self._m_2
    self._run(lambda: gen_array_ops.identity(m), 30000)
@ -360,7 +355,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
  def benchmark_slowpath_tf_identity(self):
    self._run(lambda: gen_array_ops.identity(1), 30000)

  @test_util.disable_tfrt("random ops not supported")
  def benchmark_tfe_py_execute_identity(self):
    m = self._m_2
    ctx_handle = context.context()._handle
@ -498,19 +492,17 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
    self._run(m.value, num_iters)

  # Benchmarks for A^2, A of dimension 2 by 2.
  @test_util.disable_tfrt("random ops not supported")
  def benchmark_np_matmul_2_by_2(self):
    self._benchmark_np_matmul(
        self._m_2_by_2, transpose_b=False, num_iters=self._num_iters_2_by_2)

  @test_util.disable_tfrt("random ops not supported")
  def benchmark_tf_matmul_2_by_2_CPU(self):
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_tf_matmul(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  @test_util.disable_tfrt("random ops not supported")
  @test_util.disable_tfrt("async not supported")
  def benchmark_tf_matmul_2_by_2_CPU_async(self):
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
@ -520,35 +512,32 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
          num_iters=self._num_iters_2_by_2,
          execution_mode=context.ASYNC)

  @test_util.disable_tfrt("random ops not supported")
  def benchmark_gen_math_ops_matmul_2_by_2_CPU(self):
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_gen_math_ops_matmul(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  @test_util.disable_tfrt("random ops not supported")
  def benchmark_tfe_py_fastpath_execute_matmul_2_by_2_CPU(self):
|
||||
with context.device(CPU):
|
||||
m = self._m_2_by_2.cpu()
|
||||
self._benchmark_tfe_py_fastpath_execute_matmul(
|
||||
m, transpose_b=False, num_iters=self._num_iters_2_by_2)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_tfe_py_execute_matmul_2_by_2_CPU(self):
|
||||
with context.device(CPU):
|
||||
m = self._m_2_by_2.cpu()
|
||||
self._benchmark_tfe_py_execute_matmul(
|
||||
m, transpose_b=False, num_iters=self._num_iters_2_by_2)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("Mutex corrupt: waiting writer with no waiters")
|
||||
def benchmark_defun_matmul_2_by_2_CPU(self):
|
||||
with context.device(CPU):
|
||||
m = self._m_2_by_2.cpu()
|
||||
self._benchmark_defun_matmul(
|
||||
m, transpose_b=False, num_iters=self._num_iters_2_by_2)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("async not supported")
|
||||
def benchmark_defun_matmul_2_by_2_CPU_async(self):
|
||||
with context.device(CPU):
|
||||
m = self._m_2_by_2.cpu()
|
||||
@ -558,14 +547,14 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
num_iters=self._num_iters_2_by_2,
|
||||
execution_mode=context.ASYNC)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("Mutex corrupt: waiting writer with no waiters")
|
||||
def benchmark_defun_matmul_forward_backward_2_by_2_CPU(self):
|
||||
with context.device(CPU):
|
||||
m = self._m_2_by_2.cpu()
|
||||
self._benchmark_defun_matmul_forward_backward(
|
||||
m, transpose_b=False, num_iters=self._num_iters_2_by_2)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("async not supported")
|
||||
def benchmark_defun_matmul_forward_backward_2_by_2_CPU_async(self):
|
||||
with context.device(CPU):
|
||||
m = self._m_2_by_2.cpu()
|
||||
@ -575,7 +564,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
num_iters=self._num_iters_2_by_2,
|
||||
execution_mode=context.ASYNC)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("copy to GPU not supported")
|
||||
def benchmark_tf_matmul_2_by_2_GPU(self):
|
||||
if not context.num_gpus():
|
||||
return
|
||||
@ -584,7 +573,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
self._benchmark_tf_matmul(
|
||||
m, transpose_b=False, num_iters=self._num_iters_2_by_2)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("async not supported")
|
||||
def benchmark_tf_matmul_2_by_2_GPU_async(self):
|
||||
if not context.num_gpus():
|
||||
return
|
||||
@ -596,7 +585,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
num_iters=self._num_iters_2_by_2,
|
||||
execution_mode=context.ASYNC)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("copy to GPU not supported")
|
||||
def benchmark_gen_math_ops_matmul_2_by_2_GPU(self):
|
||||
if not context.num_gpus():
|
||||
return
|
||||
@ -605,7 +594,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
self._benchmark_gen_math_ops_matmul(
|
||||
m, transpose_b=False, num_iters=self._num_iters_2_by_2)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("copy to GPU not supported")
|
||||
def benchmark_tfe_py_execute_matmul_2_by_2_GPU(self):
|
||||
if not context.num_gpus():
|
||||
return
|
||||
@ -614,7 +603,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
self._benchmark_tfe_py_execute_matmul(
|
||||
m, transpose_b=False, num_iters=self._num_iters_2_by_2)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("defun not supported")
|
||||
def benchmark_defun_matmul_2_by_2_GPU(self):
|
||||
if not context.num_gpus():
|
||||
return
|
||||
@ -623,7 +612,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
self._benchmark_defun_matmul(
|
||||
m, transpose_b=False, num_iters=self._num_iters_2_by_2)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("async not supported")
|
||||
def benchmark_defun_matmul_2_by_2_GPU_async(self):
|
||||
if not context.num_gpus():
|
||||
return
|
||||
@ -635,28 +624,26 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
num_iters=self._num_iters_2_by_2,
|
||||
execution_mode=context.ASYNC)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("function not supported")
|
||||
def benchmark_nested_defun_matmul_2_by_2(self):
|
||||
m = self._m_2_by_2.cpu()
|
||||
self._benchmark_nested_defun_matmul(
|
||||
m, transpose_b=False, num_iters=self._num_iters_2_by_2)
|
||||
|
||||
# Benchmarks for AA.T, A of dimension 100 by 784.
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_np_matmul_100_by_784(self):
|
||||
self._benchmark_np_matmul(
|
||||
self._m_100_by_784,
|
||||
transpose_b=True,
|
||||
num_iters=self._num_iters_100_by_784)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_tf_matmul_100_by_784_CPU(self):
|
||||
with context.device(CPU):
|
||||
m = self._m_100_by_784.cpu()
|
||||
self._benchmark_tf_matmul(
|
||||
m, transpose_b=True, num_iters=self._num_iters_100_by_784)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("async not supported")
|
||||
def benchmark_tf_matmul_100_by_784_CPU_async(self):
|
||||
with context.device(CPU):
|
||||
m = self._m_100_by_784.cpu()
|
||||
@ -666,35 +653,33 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
num_iters=self._num_iters_100_by_784,
|
||||
execution_mode=context.ASYNC)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_gen_math_ops_matmul_100_by_784_CPU(self):
|
||||
with context.device(CPU):
|
||||
m = self._m_100_by_784.cpu()
|
||||
self._benchmark_gen_math_ops_matmul(
|
||||
m, transpose_b=True, num_iters=self._num_iters_100_by_784)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_tfe_py_fastpath_execute_matmul_100_by_784_CPU(self):
|
||||
with context.device(CPU):
|
||||
m = self._m_100_by_784.cpu()
|
||||
self._benchmark_tfe_py_fastpath_execute_matmul(
|
||||
m, transpose_b=True, num_iters=self._num_iters_100_by_784)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("copy to GPU not supported")
|
||||
def benchmark_tfe_py_execute_matmul_100_by_784_CPU(self):
|
||||
with context.device(CPU):
|
||||
m = self._m_100_by_784.cpu()
|
||||
self._benchmark_tfe_py_execute_matmul(
|
||||
m, transpose_b=True, num_iters=self._num_iters_100_by_784)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("function not supported")
|
||||
def benchmark_defun_matmul_100_by_784_CPU(self):
|
||||
with context.device(CPU):
|
||||
m = self._m_100_by_784.cpu()
|
||||
self._benchmark_defun_matmul(
|
||||
m, transpose_b=True, num_iters=self._num_iters_100_by_784)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("copy to GPU not supported")
|
||||
def benchmark_tf_matmul_100_by_784_GPU(self):
|
||||
if not context.num_gpus():
|
||||
return
|
||||
@ -703,7 +688,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
self._benchmark_tf_matmul(
|
||||
m, transpose_b=True, num_iters=self._num_iters_100_by_784)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("async not supported")
|
||||
def benchmark_tf_matmul_100_by_784_GPU_async(self):
|
||||
if not context.num_gpus():
|
||||
return
|
||||
@ -715,7 +700,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
num_iters=self._num_iters_100_by_784,
|
||||
execution_mode=context.ASYNC)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("copy to GPU not supported")
|
||||
def benchmark_gen_math_ops_matmul_100_by_784_GPU(self):
|
||||
if not context.num_gpus():
|
||||
return
|
||||
@ -724,7 +709,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
self._benchmark_gen_math_ops_matmul(
|
||||
m, transpose_b=True, num_iters=self._num_iters_100_by_784)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("copy to GPU not supported")
|
||||
def benchmark_tfe_py_execute_matmul_100_by_784_GPU(self):
|
||||
if not context.num_gpus():
|
||||
return
|
||||
@ -733,7 +718,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
self._benchmark_tfe_py_execute_matmul(
|
||||
m, transpose_b=True, num_iters=self._num_iters_100_by_784)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("defun not supported")
|
||||
def benchmark_defun_matmul_100_by_784_GPU(self):
|
||||
if not context.num_gpus():
|
||||
return
|
||||
@ -742,7 +727,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
self._benchmark_defun_matmul(
|
||||
m, transpose_b=True, num_iters=self._num_iters_100_by_784)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("defun not supported")
|
||||
def benchmark_nested_defun_matmul_100_by_784(self):
|
||||
m = self._m_100_by_784.gpu()
|
||||
self._benchmark_nested_defun_matmul(
|
||||
@ -815,35 +800,35 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
func()
|
||||
self._run(func, 3000)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("defun not supported")
|
||||
def benchmark_forwardprop_matmul_256_by_2096_CPU(self):
|
||||
self._benchmark_forwardprop_matmul_CPU(shape=(256, 2096))
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("defun not supported")
|
||||
def benchmark_forwardprop_in_defun_matmul_256_by_2096_CPU(self):
|
||||
self._benchmark_forwardprop_in_defun_matmul_CPU(shape=(256, 2096))
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("defun not supported")
|
||||
def benchmark_forwardprop_in_defun_of_defun_matmul_256_by_2096_CPU(self):
|
||||
self._benchmark_forwardprop_in_defun_of_defun_matmul_CPU(shape=(256, 2096))
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("defun not supported")
|
||||
def benchmark_forwardprop_of_defun_matmul_256_by_2096_CPU(self):
|
||||
self._benchmark_forwardprop_of_defun_matmul_CPU(shape=(256, 2096))
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("defun not supported")
|
||||
def benchmark_forwardprop_matmul_100_by_784_CPU(self):
|
||||
self._benchmark_forwardprop_matmul_CPU(shape=(100, 784))
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("defun not supported")
|
||||
def benchmark_forwardprop_in_defun_matmul_100_by_784_CPU(self):
|
||||
self._benchmark_forwardprop_in_defun_matmul_CPU(shape=(100, 784))
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("defun not supported")
|
||||
def benchmark_forwardprop_in_defun_of_defun_matmul_100_by_784_CPU(self):
|
||||
self._benchmark_forwardprop_in_defun_of_defun_matmul_CPU(shape=(100, 784))
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("defun not supported")
|
||||
def benchmark_forwardprop_of_defun_matmul_100_by_784_CPU(self):
|
||||
self._benchmark_forwardprop_of_defun_matmul_CPU(shape=(100, 784))
|
||||
|
||||
@ -988,25 +973,20 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
func = lambda: array_ops.zeros_like(m)
|
||||
self._run(func, 3000)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_tf_zeros_like_CPU(self):
|
||||
self._benchmark_tf_zeros_like(self._m_2_by_2)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_tf_zeros_like_GPU(self):
|
||||
self._benchmark_tf_zeros_like(self._m_2_by_2, device=GPU)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_tf_zeros_like_variable_CPU(self):
|
||||
m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
|
||||
self._benchmark_tf_zeros_like(m)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_tf_zeros_like_variable_GPU(self):
|
||||
m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
|
||||
self._benchmark_tf_zeros_like(m, device=GPU)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def _benchmark_tf_random_uniform_2_by_2(self,
|
||||
shape=(2, 2),
|
||||
dtype=dtypes.int32,
|
||||
@ -1018,30 +998,24 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
|
||||
self._run(func, num_iters=self._num_iters_2_by_2)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_tf_random_uniform_2_by_2_integer_CPU(self):
|
||||
self._benchmark_tf_random_uniform_2_by_2()
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_tf_random_uniform_2_by_2_integer_GPU(self):
|
||||
self._benchmark_tf_random_uniform_2_by_2(device=GPU)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_tf_random_uniform_2_by_2_float_CPU(self):
|
||||
self._benchmark_tf_random_uniform_2_by_2(dtype=dtypes.float32)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_tf_random_uniform_2_by_2_float_GPU(self):
|
||||
self._benchmark_tf_random_uniform_2_by_2(
|
||||
dtype=dtypes.float32, device=GPU)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_tf_random_uniform_2_by_2_default_setting_CPU(self):
|
||||
with context.device(CPU):
|
||||
func = lambda: random_ops.random_uniform((2, 2))
|
||||
self._run(func, num_iters=self._num_iters_2_by_2)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_tf_random_uniform_2_by_2_default_setting_GPU(self):
|
||||
with context.device(GPU):
|
||||
func = lambda: random_ops.random_uniform((2, 2))
|
||||
@ -1063,19 +1037,15 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
|
||||
self._run(func, num_iters=self._num_iters_2_by_2)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_tf_dropout_scalar_rate_2_by_2_CPU(self):
|
||||
self._benchmark_tf_dropout_2_by_2(is_rate_tensor=False)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_tf_dropout_scalar_rate_2_by_2_GPU(self):
|
||||
self._benchmark_tf_dropout_2_by_2(is_rate_tensor=False, device=GPU)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_tf_dropout_2_by_2_CPU(self):
|
||||
self._benchmark_tf_dropout_2_by_2()
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_tf_dropout_2_by_2_GPU(self):
|
||||
self._benchmark_tf_dropout_2_by_2(device=GPU)
|
||||
|
||||
@ -1088,25 +1058,25 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
func = lambda: array_ops.transpose(m, perm, conjugate)
|
||||
self._run(func, num_iters, execution_mode=execution_mode)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("ConvertToEagerTensorUncached error")
|
||||
def benchmark_tf_transpose_2_by_2_CPU(self):
|
||||
with context.device(CPU):
|
||||
m = self._m_2_by_2.cpu()
|
||||
self._benchmark_transpose(m, num_iters=self._num_iters_2_by_2)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("copy to GPU not supported")
|
||||
def benchmark_tf_transpose_2_by_2_GPU(self):
|
||||
with context.device(GPU):
|
||||
m = self._m_2_by_2.gpu()
|
||||
self._benchmark_transpose(m, num_iters=self._num_iters_2_by_2)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("ConvertToEagerTensorUncached error")
|
||||
def benchmark_tf_transpose_variable_2_by_2_CPU(self):
|
||||
with context.device(CPU):
|
||||
m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
|
||||
self._benchmark_transpose(m, num_iters=self._num_iters_2_by_2)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("Cannot convert array to EagerTensor of dtype int32")
|
||||
def benchmark_tf_transpose_variable_2_by_2_GPU(self):
|
||||
with context.device(GPU):
|
||||
m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
|
||||
@ -1164,26 +1134,23 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
return defined(t1=t, t2=t, t3=t, t4=t, t5=t, t6=t, t7=t, t8=t)
|
||||
self._run(signature_computation, 30000)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_matmul_read_variable_op_2_by_2_CPU(self):
|
||||
with context.device(CPU):
|
||||
m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
|
||||
self._benchmark_matmul_read_variable(m, num_iters=self._num_iters_2_by_2)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_matmul_read_variable_op_with_tape_2_by_2_CPU(self):
|
||||
with context.device(CPU):
|
||||
m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
|
||||
self._benchmark_matmul_read_variable_with_tape(
|
||||
m, num_iters=self._num_iters_2_by_2)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_read_variable_op_2_by_2_CPU(self):
|
||||
with context.device(CPU):
|
||||
m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
|
||||
self._benchmark_read_variable(m, num_iters=self._num_iters_2_by_2)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("copy to GPU not supported")
|
||||
def benchmark_read_variable_op_2_by_2_GPU(self):
|
||||
if not context.num_gpus():
|
||||
return
|
||||
@ -1191,14 +1158,13 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
m = resource_variable_ops.ResourceVariable(self._m_2_by_2.gpu())
|
||||
self._benchmark_read_variable(m, num_iters=self._num_iters_2_by_2)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
def benchmark_read_variable_op_with_tape_2_by_2_CPU(self):
|
||||
with context.device(CPU):
|
||||
m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
|
||||
self._benchmark_read_variable_with_tape(
|
||||
m, num_iters=self._num_iters_2_by_2)
|
||||
|
||||
@test_util.disable_tfrt("random ops not supported")
|
||||
@test_util.disable_tfrt("copy to GPU not supported")
|
||||
def benchmark_read_variable_op_with_tape_2_by_2_GPU(self):
|
||||
if not context.num_gpus():
|
||||
return
|
||||
@ -1228,7 +1194,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
|
||||
self._run(scan, 100)
|
||||
|
||||
@test_util.disable_tfrt("add not supported, only add_v2")
|
||||
def benchmark_fastpath_conversion_type_inference(self):
|
||||
c = constant_op.constant(1., dtype=dtypes.float32)
|
||||
|
||||
@ -1268,7 +1233,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
xs = [[[np.linspace(0, 1, 21).tolist()] * 20] * 20]
|
||||
self._run(lambda: constant_op.constant(xs, dtype=dtypes.float64), 10000)
|
||||
|
||||
@test_util.disable_tfrt("tf.fill not supported")
|
||||
def benchmark_list_of_zeros_to_np_array(self):
|
||||
values = []
|
||||
for _ in range(1000):
|
||||
@ -1286,11 +1250,11 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
resources.append(resource_variable_ops.ResourceVariable(self._m_2))
|
||||
self._run(lambda: add_all(resources), num_iters)
|
||||
|
||||
@test_util.disable_tfrt("Random uniform needs fallback")
|
||||
@test_util.disable_tfrt("funtion not supported")
|
||||
def benchmarkFunctionWithFiveResourceInputs(self):
|
||||
self._benchmarkFunctionWithResourceInputs(5, 1000)
|
||||
|
||||
@test_util.disable_tfrt("Random uniform needs fallback")
|
||||
@test_util.disable_tfrt("funtion not supported")
|
||||
def benchmarkFunctionWithFiveHundredResourceInputs(self):
|
||||
self._benchmarkFunctionWithResourceInputs(500, 100)
|
||||
|
||||
@ -1325,15 +1289,15 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
with context.device(CPU):
|
||||
self._run(benchmark_fn, 10)
|
||||
|
||||
@test_util.disable_tfrt("VarHandleOp needs fallback")
|
||||
@test_util.disable_tfrt("funtion not supported")
|
||||
def benchmarkTenThousandResourceReadsInCondInInnerFunc(self):
|
||||
self._benchmarkResourceReadsInCondInInnerFunc(10000)
|
||||
|
||||
@test_util.disable_tfrt("VarHandleOp needs fallback")
|
||||
@test_util.disable_tfrt("funtion not supported")
|
||||
def benchmarkHundredResourceReadsInCondInInnerFunc(self):
|
||||
self._benchmarkResourceReadsInCondInInnerFunc(100)
|
||||
|
||||
@test_util.disable_tfrt("VarHandleOp needs fallback")
|
||||
@test_util.disable_tfrt("funtion not supported")
|
||||
def benchmarkTenResourceReadsInCondInInnerFunc(self):
|
||||
self._benchmarkResourceReadsInCondInInnerFunc(10)
|
||||
|
||||
|
@ -6261,10 +6261,12 @@ def add_to_collection(name, value):
  Args:
    name: The key for the collection. For example, the `GraphKeys` class
      contains many standard names for collections.
    value: The value to add to the collection. @compatibility(eager)
      Collections are only supported in eager when variables are created inside
      an EagerVariableStore (e.g. as part of a layer or template).
      @end_compatibility
    value: The value to add to the collection.

  @compatibility(eager)
  Collections are only supported in eager when variables are created inside
  an EagerVariableStore (e.g. as part of a layer or template).
  @end_compatibility
  """
  get_default_graph().add_to_collection(name, value)
@ -6279,10 +6281,12 @@ def add_to_collections(names, value):
  Args:
    names: The key for the collections. The `GraphKeys` class contains many
      standard names for collections.
    value: The value to add to the collections. @compatibility(eager)
      Collections are only supported in eager when variables are created inside
      an EagerVariableStore (e.g. as part of a layer or template).
      @end_compatibility
    value: The value to add to the collections.

  @compatibility(eager)
  Collections are only supported in eager when variables are created inside
  an EagerVariableStore (e.g. as part of a layer or template).
  @end_compatibility
  """
  get_default_graph().add_to_collections(names, value)
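
A minimal usage sketch of the API these docstrings document. Illustrative
only, not part of the change; TF1-style graph mode is assumed, per the eager
caveat above:

import tensorflow.compat.v1 as tf

v = tf.Variable(1.0)
tf.add_to_collection('my_vars', v)      # add to a single collection
tf.add_to_collections(['a', 'b'], v)    # add the same value to two collections
assert v in tf.get_collection('my_vars')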
@ -362,30 +362,6 @@ cuda_py_test(
    ],
)

cuda_py_test(
    name = "multi_worker_callback_tf1_test",
    srcs = ["multi_worker_callback_tf1_test.py"],
    # TODO(b/132384649): Enable for guitar and oss tests.
    shard_count = 24,
    tags = [
        "manual",
        "no_oss",
        "noguitar",
        "notap",
    ],
    deps = [
        ":distribute",
        ":multi_worker_testing_utils",
        "//tensorflow/python:platform",
        "//tensorflow/python/distribute:collective_all_reduce_strategy",
        "//tensorflow/python/distribute:combinations",
        "//tensorflow/python/distribute:distribute_config",
        "//tensorflow/python/distribute:distribute_coordinator",
        "//tensorflow/python/distribute:multi_worker_test_base",
        "//tensorflow/python/keras",
    ],
)

py_test(
    name = "multi_worker_callback_tf2_test",
    srcs = ["multi_worker_callback_tf2_test.py"],
@ -454,6 +430,11 @@ py_test(
    srcs = ["multi_worker_tutorial_test.py"],
    python_version = "PY3",
    shard_count = 5,
    tags = [
        "noasan",
        "nomsan",
        "notsan",
    ],  # TODO(b/156029134)
    deps = [
        "//tensorflow/python:platform",
        "//tensorflow/python/data/ops:dataset_ops",
@ -1,597 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras callbacks in multi-worker training with TF1."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import tempfile
import threading

from absl.testing import parameterized

from tensorflow.python import keras
from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_test_base as test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.distribute import multi_worker_testing_utils
from tensorflow.python.keras.distribute import multi_worker_training_state as training_state
from tensorflow.python.platform import test


def get_strategy_object(strategy_cls):
  if strategy_cls == mirrored_strategy.MirroredStrategy:
    return strategy_cls(mirrored_strategy.all_local_devices())
  else:
    # CollectiveAllReduceStrategy and ParameterServerStrategy.
    return strategy_cls()


def generate_callback_test_function(custom_callable):
  """Generic template for callback tests using the MNIST synthetic dataset."""

  @combinations.generate(
      combinations.combine(
          mode=['graph'],
          strategy_cls=[collective_strategy.CollectiveAllReduceStrategy],
          required_gpus=[0, 1],
          file_format=['h5', 'tf']))
  def test_template(self, strategy_cls, file_format):
    num_workers = 2
    num_epoch = 2

    cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
    self._barrier = dc._Barrier(2)

    def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
      """Simulates an Independent Worker inside of a thread."""
      with test.mock.patch.object(dc, '_run_std_server',
                                  self._make_mock_run_std_server()):
        strategy = get_strategy_object(strategy_cls)
        batch_size = 64
        steps = 2
        train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
            batch_size, steps)
        with strategy.scope():
          model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))

        custom_callable(
            model,
            self,
            train_ds,
            num_epoch,
            steps,
            strategy,
            saving_filepath=kwargs['saving_filepath'],
            barrier=kwargs['barrier'],
            threading_local=kwargs['threading_local'])

    # Pass saving_filepath from the parent thread to ensure every worker has
    # the same filepath to save.
    saving_filepath = os.path.join(self.get_temp_dir(),
                                   'checkpoint.' + file_format)
    barrier = dc._Barrier(2)
    threading_local = threading.local()
    threads = self.run_multiple_tasks_in_threads(
        _independent_worker_fn,
        cluster_spec,
        saving_filepath=saving_filepath,
        barrier=barrier,
        threading_local=threading_local)
    self.assertFalse(training_state.checkpoint_exists(saving_filepath))

    threads_to_join = []
    strategy = get_strategy_object(strategy_cls)
    if strategy.extended.experimental_between_graph:
      for ts in threads.values():
        threads_to_join.extend(ts)
    else:
      threads_to_join = [threads['worker'][0]]
    self.join_independent_workers(threads_to_join)

  return test_template


class KerasMultiWorkerCallbackTest(test_base.IndependentWorkerTestBase,
                                   parameterized.TestCase):
  """KerasMultiWorkerCallbackTest for TF1.

  TODO(rchao): Migrate all tests in this class to
  `multi_worker_callback_tf2_test`.
  """

  # The callables of the actual testing content to be run go below.
  @staticmethod
  def callableForTestChiefOnlyCallback(model, test_obj, train_ds, num_epoch,
                                       steps, strategy, saving_filepath,
                                       **kwargs):

    class ChiefOnly(keras.callbacks.Callback):

      def __init__(self):
        self._chief_worker_only = True
        self.filtered_correctly = True

      def on_train_begin(self, logs):
        if not multi_worker_util.is_chief():
          # Non-chief workers shouldn't run this callback.
          self.filtered_correctly = False

    cb = ChiefOnly()
    model.fit(
        x=train_ds, epochs=num_epoch, steps_per_epoch=steps, callbacks=[cb])

    test_obj.assertTrue(cb.filtered_correctly)

  @staticmethod
  def callableForTestModelCheckpointSavesOnChiefButNotOtherwise(
      model, test_obj, train_ds, num_epoch, steps, strategy, saving_filepath,
      **kwargs):

    extension = os.path.splitext(saving_filepath)[1]

    # Incorporate type/index information and thread id in saving_filepath to
    # ensure every worker has a unique path. Note that in normal use cases the
    # saving_filepath will be the same for all workers, but we use different
    # ones here just to test that the chief saves the checkpoint and non-chief
    # workers don't.
    saving_filepath = os.path.join(
        test_obj.get_temp_dir(), 'checkpoint_%s_%d%s' %
        (test_base.get_task_type(), test_base.get_task_index(), extension))

    # The saving_filepath shouldn't exist at the beginning (as it's unique).
    test_obj.assertFalse(training_state.checkpoint_exists(saving_filepath))

    model.fit(
        x=train_ds,
        epochs=num_epoch,
        steps_per_epoch=steps,
        callbacks=[callbacks.ModelCheckpoint(filepath=saving_filepath)])

    # If it's chief, the model should be saved; if not, the model shouldn't.
    test_obj.assertEqual(
        training_state.checkpoint_exists(saving_filepath), test_base.is_chief())

  @staticmethod
  def initialFitting(test_obj, model, train_ds, num_epoch, steps,
                     saving_filepath):
    # The saving_filepath shouldn't exist at the beginning.
    test_obj.assertFalse(training_state.checkpoint_exists(saving_filepath))

    model.fit(
        x=train_ds,
        epochs=num_epoch,
        steps_per_epoch=steps,
        callbacks=[
            callbacks.ModelCheckpoint(
                filepath=saving_filepath, save_weights_only=True)
        ])

    # The saving_filepath should exist after fitting with the callback, and
    # both the chief and non-chief workers should see that it exists (it was
    # saved only by the chief).
    test_obj.assertTrue(training_state.checkpoint_exists(saving_filepath))

    history_after_one_more_epoch = model.fit(
        x=train_ds, epochs=1, steps_per_epoch=steps)

    # The saving_filepath should continue to exist (if it did) after fitting
    # without the callback.
    test_obj.assertTrue(training_state.checkpoint_exists(saving_filepath))

    return saving_filepath, history_after_one_more_epoch

  @staticmethod
  def callableForTestLoadWeightFromModelCheckpoint(model, test_obj, train_ds,
                                                   num_epoch, steps, strategy,
                                                   saving_filepath, **kwargs):
    filepaths = []
    real_mkstemp = tempfile.mkstemp
    def mocked_mkstemp():
      # Only non-chief should call tempfile.mkstemp() inside fit() in sync
      # training.
      assert not test_base.is_chief()
      file_handle, temp_file_name = real_mkstemp()
      extension = os.path.splitext(saving_filepath)[1]
      temp_filepath = temp_file_name + extension
      filepaths.append(temp_filepath)
      return file_handle, temp_file_name

    # Mock tempfile.mkstemp() so the filepaths can be stored and verified
    # later.
    with test.mock.patch.object(tempfile, 'mkstemp', mocked_mkstemp):
      saving_filepath, history_after_one_more_epoch = \
          KerasMultiWorkerCallbackTest.initialFitting(
              test_obj, model, train_ds, num_epoch, steps, saving_filepath)

      with strategy.scope():
        model.load_weights(saving_filepath)

      history_after_loading_weight_and_one_more_epoch = model.fit(
          x=train_ds, epochs=1, steps_per_epoch=steps)

      test_obj.assertAllClose(
          history_after_one_more_epoch.history,
          history_after_loading_weight_and_one_more_epoch.history,
          rtol=5e-5)

    # Verify the temp files are indeed removed (no trace left behind).
    for filepath in filepaths:
      assert not training_state.checkpoint_exists(filepath)

  @staticmethod
  def callableForTestModelRestoreCallback(model, test_obj, train_ds, num_epoch,
                                          steps, strategy, saving_filepath,
                                          **kwargs):

    saving_filepath, history_after_one_more_epoch = \
        KerasMultiWorkerCallbackTest.initialFitting(
            test_obj, model, train_ds, num_epoch, steps, saving_filepath)

    # The model should get restored to the weights previously saved, by
    # adding a ModelCheckpoint callback (which results in a
    # _ModelRestoreCallback being added), with load_weights_on_restart=True.
    history_after_model_restoring_and_one_more_epoch = model.fit(
        x=train_ds,
        epochs=1,
        steps_per_epoch=steps,
        callbacks=[
            callbacks.ModelCheckpoint(
                filepath=saving_filepath,
                save_weights_only=True,
                load_weights_on_restart=True)
        ])

    # Assert that the history one epoch after initial fitting and one epoch
    # after restoring are close.
    test_obj.assertAllClose(
        history_after_one_more_epoch.history,
        history_after_model_restoring_and_one_more_epoch.history,
        rtol=5e-5)

    history_one_more_epoch_without_model_restoring = model.fit(
        x=train_ds, epochs=1, steps_per_epoch=steps)

    # Ensure that training for another epoch gives a different result.
    test_obj.assertNotAllClose(
        history_after_model_restoring_and_one_more_epoch.history,
        history_one_more_epoch_without_model_restoring.history,
        rtol=5e-5)

  @staticmethod
  def callableForTestBackupModelRemoved(model, test_obj, train_ds, num_epoch,
                                        steps, strategy, saving_filepath,
                                        **kwargs):

    # The `barrier` object needs to be passed in from the parent thread so
    # both threads refer to the same object.
    barrier = kwargs['barrier']

    num_epoch = 3

    # Testing the backup filepath `multi_worker_training_state` uses.
    _, backup_filepath = training_state._get_backup_filepath(saving_filepath)

    # The backup_filepath shouldn't exist at the beginning.
    test_obj.assertFalse(training_state.checkpoint_exists(backup_filepath))

    # Callback to verify that the backup file exists in the middle of
    # training.
    class BackupFilepathVerifyingCallback(callbacks.Callback):

      def on_epoch_begin(self, epoch, logs=None):
        if epoch > 1:
          # Assert that after the first two epochs the backup file exists.
          test_obj.assertTrue(training_state.checkpoint_exists(backup_filepath))

    model.fit(
        x=train_ds,
        epochs=num_epoch,
        steps_per_epoch=steps,
        callbacks=[
            callbacks.ModelCheckpoint(
                filepath=saving_filepath, save_weights_only=True),
            BackupFilepathVerifyingCallback()
        ])

    # Sync on the two threads so we make sure the backup file is removed
    # before we move on.
    barrier.wait()

    # The backup file should not exist at successful exit of `model.fit()`.
    test_obj.assertFalse(training_state.checkpoint_exists(backup_filepath))

  @staticmethod
  def callableForTestBackupModelNotRemovedIfInterrupted(model, test_obj,
                                                        train_ds, num_epoch,
                                                        steps, strategy,
                                                        saving_filepath,
                                                        **kwargs):

    # The `barrier` object needs to be passed in from the parent thread so
    # both threads refer to the same object.
    barrier = kwargs['barrier']

    num_epoch = 4

    # Testing the backup filepath `multi_worker_training_state` uses.
    _, backup_filepath = training_state._get_backup_filepath(saving_filepath)

    # The backup_filepath shouldn't exist at the beginning.
    test_obj.assertFalse(training_state.checkpoint_exists(backup_filepath))

    # Callback to interrupt in the middle of training.
    class InterruptingCallback(callbacks.Callback):

      def on_epoch_begin(self, epoch, logs=None):
        if epoch == 2:
          raise RuntimeError('Interrupting!')

    try:
      model.fit(
          x=train_ds,
          epochs=num_epoch,
          steps_per_epoch=steps,
          callbacks=[
              callbacks.ModelCheckpoint(
                  filepath=saving_filepath, save_weights_only=True),
              InterruptingCallback()
          ])
    except RuntimeError as e:
      if 'Interrupting!' not in e.message:
        raise

    # Sync on the two threads.
    barrier.wait()

    # The backup file should exist after interruption of `model.fit()`.
    test_obj.assertTrue(training_state.checkpoint_exists(backup_filepath))

  @staticmethod
  def callableForTestUnmatchedModelFile(model, test_obj, train_ds, num_epoch,
                                        steps, strategy, saving_filepath,
                                        **kwargs):

    # The saving_filepath shouldn't exist at the beginning.
    test_obj.assertFalse(training_state.checkpoint_exists(saving_filepath))

    model.fit(
        x=train_ds,
        epochs=num_epoch,
        steps_per_epoch=steps,
        callbacks=[
            callbacks.ModelCheckpoint(
                filepath=saving_filepath, save_weights_only=True)
        ])

    (train_ds, _), (_, _) = testing_utils.get_test_data(
        train_samples=10, test_samples=10, input_shape=(3,), num_classes=2)

    # Switch to a model of different structure.
    with strategy.scope():
      model = keras.models.Sequential()
      model.add(keras.layers.Dense(5, input_dim=3, activation='relu'))
      model.add(keras.layers.Dense(2, activation='softmax'))
      model.compile(
          loss='categorical_crossentropy', optimizer='rmsprop', metrics=['acc'])

    test_obj.assertTrue(training_state.checkpoint_exists(saving_filepath))

    if saving_filepath.endswith('.tf'):
      test_obj.skipTest('Loading mismatched TF checkpoint would cause Fatal '
                        'Python error: Aborted. Skipping.')

    # Unmatched format. Should raise ValueError.
    with test_obj.assertRaisesRegexp(ValueError, 'Error loading file from'):
      model.fit(
          x=train_ds,
          epochs=num_epoch,
          batch_size=8,
          callbacks=[
              callbacks.ModelCheckpoint(
                  filepath=saving_filepath,
                  save_weights_only=True,
                  load_weights_on_restart=True)
          ])

  @staticmethod
  def callableForTestReduceLROnPlateau(model, test_obj, train_ds, num_epoch,
                                       steps, strategy, saving_filepath,
                                       **kwargs):

    cbks = [
        callbacks.ReduceLROnPlateau(
            monitor='loss',
            factor=0.1,
            min_delta=1,
            patience=1,
            cooldown=5,
            verbose=1)
    ]

    # It is expected that the learning rate would drop by `factor` within
    # 3 epochs with `min_delta=1`.
    model.fit(x=train_ds, epochs=3, steps_per_epoch=steps, callbacks=cbks)
    test_obj.assertAllClose(
        float(K.get_value(model.optimizer.lr)), 0.0001, atol=1e-8)

    # It is expected that the learning rate would drop by another `factor`
    # within 3 epochs with `min_delta=1`.
    model.fit(x=train_ds, epochs=3, steps_per_epoch=steps, callbacks=cbks)
    test_obj.assertAllClose(
        float(K.get_value(model.optimizer.lr)), 0.00001, atol=1e-8)

  @staticmethod
  def callableForTestEarlyStopping(model, test_obj, train_ds, num_epoch, steps,
                                   strategy, saving_filepath, **kwargs):

    class EpochCounterCallback(callbacks.Callback):

      def on_epoch_begin(self, epoch, logs):
        self.last_epoch = epoch

    epoch_counter_cbk = EpochCounterCallback()
    cbks = [
        callbacks.EarlyStopping(
            monitor='loss', min_delta=0.05, patience=1, verbose=1),
        epoch_counter_cbk
    ]

    # Empirically, it is expected that `model.fit()` terminates around the
    # 22nd epoch. Assert that it stops before the 50th epoch to avoid
    # flakiness and be more predictable.
    model.fit(x=train_ds, epochs=100, steps_per_epoch=steps, callbacks=cbks)
    test_obj.assertLess(epoch_counter_cbk.last_epoch, 50)

  @staticmethod
  def callableForTestLearningRateScheduler(model, test_obj, train_ds, num_epoch,
                                           steps, strategy, saving_filepath,
                                           **kwargs):

    cbks = [
        callbacks.LearningRateScheduler(
            schedule=lambda x: 1. / (1. + x), verbose=1)
    ]

    # It is expected that with `epochs=2`, the learning rate would drop to
    # 1 / (1 + 1) = 0.5, since the schedule is last applied at epoch index 1.
    model.fit(x=train_ds, epochs=2, steps_per_epoch=steps, callbacks=cbks)
    test_obj.assertAllClose(
        float(K.get_value(model.optimizer.lr)), 0.5, atol=1e-8)

    # It is expected that with `epochs=4`, the learning rate would drop to
    # 1 / (1 + 3) = 0.25, since the schedule is last applied at epoch index 3.
    model.fit(x=train_ds, epochs=4, steps_per_epoch=steps, callbacks=cbks)
    test_obj.assertAllClose(
        float(K.get_value(model.optimizer.lr)), 0.25, atol=1e-8)

  # pylint: disable=g-doc-args
  @staticmethod
  def callableForTestIntermediateDirForFTAreRemoved(model, test_obj, train_ds,
                                                    num_epoch, steps, strategy,
                                                    saving_filepath, **kwargs):
    """Tests that the temporary directories are removed.

    Some temporary directories are created for the purpose of fault tolerance.
    This test ensures that such directories have been removed by the time
    `model.fit()` finishes successfully.
    """

    # The `threading_local` and `barrier` objects have to be passed in from
    # the parent thread so both threads refer to the same objects.
    threading_local = kwargs['threading_local']
    barrier = kwargs['barrier']

    # Each of the two threads has its own copy of the
    # `temp_dirs_supposed_to_be_removed` list.
    threading_local.temp_dirs_supposed_to_be_removed = []

    callbacks_list = [
        callbacks.ModelCheckpoint(
            filepath=saving_filepath,
            save_weights_only=True,
            load_weights_on_restart=True),
    ]

    # Keep the references to the real function objects.
    real_os_path_join = os.path.join
    real_tempfile_mkdtemp = tempfile.mkdtemp

    # Make an `os.path.join` wrapper, which will be patched onto the real
    # function, so the temporary directories can be tracked.
    def wrapper_os_path_join(path, *paths):
      join_result = real_os_path_join(path, *paths)
      if len(paths) == 1 and paths[0] == 'backup':
        threading_local.temp_dirs_supposed_to_be_removed.append(join_result)
      return join_result

    # Likewise for `tempfile.mkdtemp`.
    def wrapper_tempfile_mkdtemp():
      result = real_tempfile_mkdtemp()
      threading_local.temp_dirs_supposed_to_be_removed.append(result)
      return result

    # Now the two threads must sync here: if they are out of sync, one thread
    # can go ahead and patch `os.path.join` while the other has not even
    # assigned the real `os.path.join` to `real_os_path_join`. If this
    # happened, the "real" `os.path.join` the slower thread would see is
    # actually the wrapper of the other.
    barrier.wait()

    # Note that `os.path.join` will respect the second patch (there are two
    # patches because of the two threads). Both threads will refer to the same
    # copy of `wrapper_os_path_join` because of the `barrier` preceding
    # `model.fit()`. Likewise for `wrapper_tempfile_mkdtemp`.
    os.path.join = wrapper_os_path_join
    tempfile.mkdtemp = wrapper_tempfile_mkdtemp

    barrier.wait()
    model.fit(
        x=train_ds,
        epochs=num_epoch,
        steps_per_epoch=steps,
        callbacks=callbacks_list)

    # Sync before un-patching to prevent either thread from accessing the real
    # functions. Also to make sure `model.fit()` is done on both threads (so
    # we can safely assert the directories are removed).
    barrier.wait()
    os.path.join = real_os_path_join
    tempfile.mkdtemp = real_tempfile_mkdtemp

    # There should be directory (names) that are supposed to be removed.
    test_obj.assertTrue(threading_local.temp_dirs_supposed_to_be_removed)
    for temp_dir_supposed_to_be_removed in (
        threading_local.temp_dirs_supposed_to_be_removed):
      # They should have been removed and thus don't exist.
      test_obj.assertFalse(os.path.exists(temp_dir_supposed_to_be_removed))

  # The actual testing methods go here.
  test_chief_only_callback = generate_callback_test_function(
      callableForTestChiefOnlyCallback.__func__)
  test_model_checkpoint_saves_on_chief_but_not_otherwise = \
      generate_callback_test_function(
          callableForTestModelCheckpointSavesOnChiefButNotOtherwise.__func__)
  test_load_weight_from_model_checkpoint = generate_callback_test_function(
      callableForTestLoadWeightFromModelCheckpoint.__func__)
  test_model_restore_callback = generate_callback_test_function(
      callableForTestModelRestoreCallback.__func__)
  test_unmatched_model_file = generate_callback_test_function(
      callableForTestUnmatchedModelFile.__func__)
  test_reduce_lr_on_plateau = generate_callback_test_function(
      callableForTestReduceLROnPlateau.__func__)
  test_early_stopping = generate_callback_test_function(
      callableForTestEarlyStopping.__func__)
  test_learning_rate_scheduler = generate_callback_test_function(
      callableForTestLearningRateScheduler.__func__)
  test_intermediate_dir_for_ft_are_removed = generate_callback_test_function(
      callableForTestIntermediateDirForFTAreRemoved.__func__)
  test_backup_model_removed = generate_callback_test_function(
      callableForTestBackupModelRemoved.__func__)
  test_backup_model_not_removed_if_interrupted = \
      generate_callback_test_function(
          callableForTestBackupModelNotRemovedIfInterrupted.__func__)


if __name__ == '__main__':
  with test.mock.patch.object(sys, 'exit', os._exit):
    test.main()
@ -676,16 +676,8 @@ def enable_v2_dtype_behavior():
  float32) instead of None. In addition, layers will automatically cast
  floating-point inputs to the layer's dtype.

  >>> tf.compat.v1.keras.layers.disable_v2_dtype_behavior()
  >>> x = tf.ones((4, 4, 4, 4), dtype='float64')
  >>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
  >>> print(layer.dtype)  # None since V2 behavior is disabled
  None
  >>> y = layer(x)  # Doesn't cast inputs since V2 dtype behavior is disabled
  >>> print(y.dtype.name)
  float64
  >>> tf.compat.v1.keras.layers.enable_v2_dtype_behavior()
  >>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
  >>> print(layer.dtype)  # float32 since V2 dtype behavior is enabled
  float32
  >>> y = layer(x)  # Layer casts inputs since V2 dtype behavior is enabled
@ -41,7 +41,7 @@ _VARIANCE_NAME = 'variance'
class Normalization(CombinerPreprocessingLayer):
  """Feature-wise normalization of the data.

  This layer will coerce its inputs into a normal distribution centered around
  This layer will coerce its inputs into a distribution centered around
  0 with standard deviation 1. It accomplishes this by precomputing the mean and
  variance of the data, and calling (input-mean)/sqrt(var) at runtime.
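
In numbers, the computation the docstring describes is plain standardization.
A quick sketch, illustrative only and not part of the change; the
`experimental.preprocessing` namespace is an assumption for this TF era:

import numpy as np
import tensorflow as tf

data = np.array([[1.], [2.], [3.]], dtype='float32')
layer = tf.keras.layers.experimental.preprocessing.Normalization()
layer.adapt(data)   # precomputes mean=2 and variance=2/3 for the feature
out = layer(data)   # computes (input - mean) / sqrt(var)
# out is approximately [[-1.2247], [0.], [1.2247]]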
@ -121,16 +121,19 @@ class RMSprop(optimizer_v2.OptimizerV2):
        Setting this to `True` may help with training, but is slightly more
        expensive in terms of computation and memory. Defaults to `False`.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "RMSprop". @compatibility(eager) When eager
        execution is enabled, `learning_rate`, `decay`, `momentum`, and
        `epsilon` can each be a callable that takes no arguments and returns the
        actual value to use. This can be useful for changing these values across
        different invocations of optimizer functions. @end_compatibility
        gradients. Defaults to "RMSprop".
      **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
        `decay`}. `clipnorm` clips gradients by norm; `clipvalue` clips
        gradients by value; `decay` is included for backward compatibility to
        allow time inverse decay of learning rate. `lr` is included for backward
        compatibility; `learning_rate` is recommended instead.

    @compatibility(eager)
    When eager execution is enabled, `learning_rate`, `decay`, `momentum`, and
    `epsilon` can each be a callable that takes no arguments and returns the
    actual value to use. This can be useful for changing these values across
    different invocations of optimizer functions.
    @end_compatibility
    """
    super(RMSprop, self).__init__(name, **kwargs)
    self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
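
The compatibility block above is easiest to see in a tiny sketch. Illustrative
only, not part of the change; assumes TF 2.x eager execution:

import tensorflow as tf

lr = tf.Variable(0.01)
# A zero-argument callable is re-evaluated on each use, so mutating `lr`
# changes the effective rate without rebuilding the optimizer.
opt = tf.keras.optimizers.RMSprop(learning_rate=lambda: lr.value())
lr.assign(0.001)  # later apply_gradients() calls will use the new rate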
@ -1353,7 +1353,7 @@ def split_compile_and_replicate(computation,

  def custom_getter(getter, name, *args, **kwargs):
    """Variables on TPU have a few restrictions."""
    partitioner = kwargs["partitioner"]
    partitioner = kwargs.get("partitioner", None)
    if partitioner is not None:
      kwargs["partitioner"] = None
      logging.warning(
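
The one-line change above swaps a hard key lookup for a defaulted one. A
standalone illustration of the difference; hypothetical code, not part of the
change:

def custom_getter(getter, name, *args, **kwargs):
  # kwargs["partitioner"] raises KeyError when the caller never passed the
  # argument; kwargs.get("partitioner", None) quietly yields None instead.
  partitioner = kwargs.get("partitioner", None)
  if partitioner is not None:
    kwargs["partitioner"] = None  # strip it before delegating
  return getter(name, *args, **kwargs)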
@ -92,11 +92,14 @@ class AdamOptimizer(optimizer.Optimizer):
        Section 2.1), not the epsilon in Algorithm 1 of the paper.
      use_locking: If True use locks for update operations.
      name: Optional name for the operations created when applying gradients.
        Defaults to "Adam". @compatibility(eager) When eager execution is
        enabled, `learning_rate`, `beta1`, `beta2`, and `epsilon` can each be a
        callable that takes no arguments and returns the actual value to use.
        This can be useful for changing these values across different
        invocations of optimizer functions. @end_compatibility
        Defaults to "Adam".

    @compatibility(eager)
    When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
    `epsilon` can each be a callable that takes no arguments and returns the
    actual value to use. This can be useful for changing these values across
    different invocations of optimizer functions.
    @end_compatibility
    """
    super(AdamOptimizer, self).__init__(use_locking, name)
    self._lr = learning_rate
@@ -222,11 +222,11 @@ cc_library(
     hdrs = if_gpu_is_configured(["asm_compiler.h"]),
     copts = tf_copts(),
     visibility = [
+        "//tensorflow/compiler/mlir/tools/kernel_gen:__subpackages__",
         "//tensorflow/compiler/xla/service/gpu:__subpackages__",
         "//tensorflow/compiler/xla/service/mlir_gpu:__subpackages__",
         "//tensorflow/core/kernels:__subpackages__",
         "//tensorflow/stream_executor:__subpackages__",
-        "//third_party/tf_runtime/tools/tf_kernel_gen:__subpackages__",
     ],
     deps = if_gpu_is_configured([
         ":gpu_asm_opts",
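For context on the visibility edit: a `:__subpackages__` label grants visibility to a package and everything beneath it, whereas `:__pkg__` would cover only that single package. A hypothetical Starlark sketch, not part of this patch (target and file names illustrative):

```python
# Hypothetical BUILD rule showing the same visibility pattern.
cc_library(
    name = "example_helper",  # illustrative target
    srcs = ["example_helper.cc"],
    # Visible to //tensorflow/compiler/mlir/tools/kernel_gen and all of its
    # subpackages, mirroring the entry added above.
    visibility = ["//tensorflow/compiler/mlir/tools/kernel_gen:__subpackages__"],
)
```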
@ -250,11 +250,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
|
||||
name = "eigen_archive",
|
||||
build_file = clean_dep("//third_party:eigen.BUILD"),
|
||||
patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"),
|
||||
sha256 = "d96aa8eda6dbf80e313c992a59e9e9451f420a6b9f58ef30aa41bffdc9df2f1b", # SHARED_EIGEN_SHA
|
||||
strip_prefix = "eigen-1e41406c362788057b3adcd9a25b73f43e6e6492",
|
||||
sha256 = "2c7c0aec4271dfca6b8a7707e2112f67c4cb3bdf7c89c0e98d3fcd39707c4468", # SHARED_EIGEN_SHA
|
||||
strip_prefix = "eigen-49f1aeb60d9f759859fce0d16aa5d1ecc7168d51",
|
||||
urls = [
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/1e41406c362788057b3adcd9a25b73f43e6e6492/eigen-1e41406c362788057b3adcd9a25b73f43e6e6492.tar.gz",
|
||||
"https://gitlab.com/libeigen/eigen/-/archive/1e41406c362788057b3adcd9a25b73f43e6e6492/eigen-1e41406c362788057b3adcd9a25b73f43e6e6492.tar.gz",
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/49f1aeb60d9f759859fce0d16aa5d1ecc7168d51/eigen-49f1aeb60d9f759859fce0d16aa5d1ecc7168d51.tar.gz",
|
||||
"https://gitlab.com/libeigen/eigen/-/archive/49f1aeb60d9f759859fce0d16aa5d1ecc7168d51/eigen-49f1aeb60d9f759859fce0d16aa5d1ecc7168d51.tar.gz",
|
||||
],
|
||||
)
|
||||
|
||||
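The `sha256`, `strip_prefix`, and `urls` fields must stay in sync: the fetcher downloads from the first reachable URL and rejects the archive if its digest differs from the pin. A rough standalone sketch of that check, in plain Python with values taken from the hunk above (illustrative only, not how Bazel is invoked):

```python
import hashlib
import urllib.request

URL = ("https://gitlab.com/libeigen/eigen/-/archive/"
       "49f1aeb60d9f759859fce0d16aa5d1ecc7168d51/"
       "eigen-49f1aeb60d9f759859fce0d16aa5d1ecc7168d51.tar.gz")
EXPECTED_SHA256 = "2c7c0aec4271dfca6b8a7707e2112f67c4cb3bdf7c89c0e98d3fcd39707c4468"

data = urllib.request.urlopen(URL).read()
digest = hashlib.sha256(data).hexdigest()
if digest != EXPECTED_SHA256:
    raise ValueError("checksum mismatch: got %s" % digest)
```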
@ -679,8 +679,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
|
||||
)
|
||||
|
||||
# Check out LLVM and MLIR from llvm-project.
|
||||
LLVM_COMMIT = "307cfdf5338641e3a895857ef02dc9da35cd0eb6"
|
||||
LLVM_SHA256 = "5e75125ecadee4f91e07c20bf6612d740913a677348fd33c7264ee8fe7d12b17"
|
||||
LLVM_COMMIT = "91087153210132a4c2d3cf19a4526d8f395cb5a4"
|
||||
LLVM_SHA256 = "b2e2314ce2d4a7f0da436063c922d716171415d1b5e85889235d9eab1ecb98c1"
|
||||
LLVM_URLS = [
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
|
||||
"https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
|
||||
|
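Note how `LLVM_URLS` keeps the mirror and upstream in lockstep by substituting the same pinned commit into each template, so only `LLVM_COMMIT` and `LLVM_SHA256` change on an integrate. A plain-Python sketch of the expansion (names as in the hunk above):

```python
LLVM_COMMIT = "91087153210132a4c2d3cf19a4526d8f395cb5a4"
LLVM_URLS = [
    template.format(commit = LLVM_COMMIT)
    for template in (
        "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz",
        "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz",
    )
]
# URLs are tried in order: the mirror first, upstream GitHub as the fallback.
```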
third_party/mlir/BUILD (vendored): 31 changed lines
@@ -657,6 +657,7 @@ gentbl(
     td_file = "include/mlir/Dialect/Shape/IR/ShapeOps.td",
     td_srcs = [
         ":StdOpsTdFiles",
         "include/mlir/Dialect/Shape/IR/ShapeBase.td",
+        "include/mlir/Interfaces/InferTypeOpInterface.td",
     ],
 )
@@ -715,24 +716,35 @@ cc_library(
     ],
 )

+gentbl(
+    name = "StandardOpsTransformsPassIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [(
+        "-gen-pass-decls",
+        "include/mlir/Dialect/StandardOps/Transforms/Passes.h.inc",
+    )],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/StandardOps/Transforms/Passes.td",
+    td_srcs = [":PassBaseTdFiles"],
+)
+
 cc_library(
     name = "StandardOpsTransforms",
-    srcs = glob(
-        [
-            "lib/Dialect/StandardOps/Transforms/*.cpp",
-            "lib/Dialect/StandardOps/Transforms/*.h",
-        ],
-    ),
-    hdrs = glob([
-        "include/mlir/Dialect/StandardOps/Transforms/*.h",
+    srcs = glob([
+        "lib/Dialect/StandardOps/Transforms/*.cpp",
+        "lib/Dialect/StandardOps/Transforms/*.h",
     ]),
+    hdrs = glob(["include/mlir/Dialect/StandardOps/Transforms/*.h"]),
     includes = ["include"],
     deps = [
         ":Analysis",
         ":ControlFlowInterfaces",
         ":IR",
         ":Pass",
         ":StandardOps",
+        ":StandardOpsTransformsPassIncGen",
         ":Support",
         ":Transforms",
         "@llvm-project//llvm:support",
     ],
 )
@@ -2471,6 +2483,7 @@ cc_library(
         ":NVVMDialect",
         ":Parser",
         ":Pass",
+        ":StandardOpsTransforms",
         ":StandardToSPIRVConversions",
         ":StandardToStandard",
         ":Support",
@@ -2568,6 +2581,8 @@ cc_library(
         ":SPIRVPassIncGen",
         ":Shape",
         ":StandardOps",
+        ":StandardOpsTransforms",
+        ":StandardOpsTransformsPassIncGen",
         ":StandardToSPIRVConversions",
         ":StandardToStandard",
         ":Transforms",