Commit 81adaff8a6 ("sync to master"), parent ed158956df.
@@ -137,6 +137,10 @@ This release contains contributions from many people at Google, as well as:
 
 <INSERT>, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
 
+# Release 2.4.1
+
+* This release removes the AVX2 requirement from TF 2.4.0.
+
 # Release 2.3.2
 
 ## Bug Fixes and Other Changes
@@ -185,6 +185,9 @@ class ImmediateExecutionContext : public AbstractContext {
 
   virtual std::vector<std::string> GetLoggedOpsTestonly() { return {}; }
 
+  // Get a list of the names of functions that have been registered.
+  virtual std::vector<string> ListFunctionNames() = 0;
+
   //===--------------------------------------------------------------------===//
   // Distributed runtime related functions.
   //===--------------------------------------------------------------------===//
@@ -127,6 +127,7 @@ add_mlir_library(MhloLhloToLinalg
 
   LINK_LIBS PUBLIC
   MhloDialect
+  MLIRComplex
   MLIRIR
   MLIRPass
 )
@@ -372,3 +372,112 @@ func @testNoDilatedConvWhenGivenInputIsNonFloatType(%arg0: tensor<1x128x128x3xi3
 // CHECK-NEXT: [[RESULT:%.*]] = "tf.BatchToSpaceND"
 // CHECK-NEXT: return [[RESULT]]
 }
+
+func @testDilatedConv1DExpandH(%arg0: tensor<1x128x3xf32>, %arg1: tensor<1x5x3x8xf32>) -> tensor<1x128x8xf32> {
+  %cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
+  %cst_0 = "tf.Const"() {value = dense<-3> : tensor<i32>} : () -> tensor<i32>
+  %cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %cst_2 = "tf.Const"() {value = dense<4> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
+  %0 = "tf.SpaceToBatchND"(%arg0, %cst_1, %cst_2) : (tensor<1x128x3xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<2x68x3xf32>
+  %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<2x68x3xf32>, tensor<i32>) -> tensor<2x1x68x3xf32>
+  %2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<2x1x68x3xf32>, tensor<1x5x3x8xf32>) -> tensor<2x1x64x8xf32>
+  %3 = "tf.Squeeze"(%2) {squeeze_dims = [-3]} : (tensor<2x1x64x8xf32>) -> tensor<2x64x8xf32>
+  %4 = "tf.BatchToSpaceND"(%3, %cst_1, %cst) : (tensor<2x64x8xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x128x8xf32>
+  return %4 : tensor<1x128x8xf32>
+
+  // CHECK-LABEL: testDilatedConv1DExpandH
+  // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x3xf32>, [[FILTER:%.*]]: tensor<1x5x3x8xf32>)
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<-3> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x3xf32>, tensor<i32>) -> tensor<1x1x128x3xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 1, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x1x128x3xf32>, tensor<1x5x3x8xf32>) -> tensor<1x1x128x8xf32>
+  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-3]} : (tensor<1x1x128x8xf32>) -> tensor<1x128x8xf32>
+  // CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
+}
+
+func @testDilatedConv1DExpandHWithBiasAdd(%arg0: tensor<1x128x3xf32>, %arg1: tensor<1x5x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x8xf32> {
+  %cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
+  %cst_0 = "tf.Const"() {value = dense<-3> : tensor<i32>} : () -> tensor<i32>
+  %cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %cst_2 = "tf.Const"() {value = dense<4> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
+  %0 = "tf.SpaceToBatchND"(%arg0, %cst_1, %cst_2) : (tensor<1x128x3xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<2x68x3xf32>
+  %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<2x68x3xf32>, tensor<i32>) -> tensor<2x1x68x3xf32>
+  %2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<2x1x68x3xf32>, tensor<1x5x3x8xf32>) -> tensor<2x1x64x8xf32>
+  %3 = "tf.Squeeze"(%2) {squeeze_dims = [-3]} : (tensor<2x1x64x8xf32>) -> tensor<2x64x8xf32>
+  %4 = "tf.BatchToSpaceND"(%3, %cst_1, %cst) : (tensor<2x64x8xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x128x8xf32>
+  %5 = "tf.BiasAdd"(%4, %arg2) : (tensor<1x128x8xf32>, tensor<8xf32>) -> tensor<1x128x8xf32>
+  return %5 : tensor<1x128x8xf32>
+
+  // CHECK-LABEL: testDilatedConv1DExpandHWithBiasAdd
+  // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x3xf32>, [[FILTER:%.*]]: tensor<1x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<-3> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x3xf32>, tensor<i32>) -> tensor<1x1x128x3xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 1, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x1x128x3xf32>, tensor<1x5x3x8xf32>) -> tensor<1x1x128x8xf32>
+  // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-3]} : (tensor<1x1x128x8xf32>) -> tensor<1x128x8xf32>
+  // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x8xf32>, tensor<8xf32>) -> tensor<1x128x8xf32>
+  // CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
+}
+
+func @testDilatedConv1DExpandW(%arg0: tensor<1x128x3xf32>, %arg1: tensor<5x1x3x8xf32>) -> tensor<1x128x8xf32> {
+  %cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
+  %cst_0 = "tf.Const"() {value = dense<-2> : tensor<i32>} : () -> tensor<i32>
+  %cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %cst_2 = "tf.Const"() {value = dense<4> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
+  %0 = "tf.SpaceToBatchND"(%arg0, %cst_1, %cst_2) : (tensor<1x128x3xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<2x68x3xf32>
+  %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<2x68x3xf32>, tensor<i32>) -> tensor<2x68x1x3xf32>
+  %2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<2x68x1x3xf32>, tensor<5x1x3x8xf32>) -> tensor<2x64x1x8xf32>
+  %3 = "tf.Squeeze"(%2) {squeeze_dims = [-2]} : (tensor<2x64x1x8xf32>) -> tensor<2x64x8xf32>
+  %4 = "tf.BatchToSpaceND"(%3, %cst_1, %cst) : (tensor<2x64x8xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x128x8xf32>
+  return %4 : tensor<1x128x8xf32>
+
+  // CHECK-LABEL: testDilatedConv1DExpandW
+  // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x3xf32>, [[FILTER:%.*]]: tensor<5x1x3x8xf32>)
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<-2> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x3xf32>, tensor<i32>) -> tensor<1x128x1x3xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x1x3xf32>, tensor<5x1x3x8xf32>) -> tensor<1x128x1x8xf32>
+  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-2]} : (tensor<1x128x1x8xf32>) -> tensor<1x128x8xf32>
+  // CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
+}
+
+func @testDilatedConv1DExpandWWithBiasAdd(%arg0: tensor<1x128x3xf32>, %arg1: tensor<5x1x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x8xf32> {
+  %cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
+  %cst_0 = "tf.Const"() {value = dense<-2> : tensor<i32>} : () -> tensor<i32>
+  %cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %cst_2 = "tf.Const"() {value = dense<4> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
+  %0 = "tf.SpaceToBatchND"(%arg0, %cst_1, %cst_2) : (tensor<1x128x3xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<2x68x3xf32>
+  %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<2x68x3xf32>, tensor<i32>) -> tensor<2x68x1x3xf32>
+  %2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<2x68x1x3xf32>, tensor<5x1x3x8xf32>) -> tensor<2x64x1x8xf32>
+  %3 = "tf.Squeeze"(%2) {squeeze_dims = [-2]} : (tensor<2x64x1x8xf32>) -> tensor<2x64x8xf32>
+  %4 = "tf.BatchToSpaceND"(%3, %cst_1, %cst) : (tensor<2x64x8xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x128x8xf32>
+  %5 = "tf.BiasAdd"(%4, %arg2) : (tensor<1x128x8xf32>, tensor<8xf32>) -> tensor<1x128x8xf32>
+  return %5 : tensor<1x128x8xf32>
+
+  // CHECK-LABEL: testDilatedConv1DExpandWWithBiasAdd
+  // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x3xf32>, [[FILTER:%.*]]: tensor<5x1x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<-2> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x3xf32>, tensor<i32>) -> tensor<1x128x1x3xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x1x3xf32>, tensor<5x1x3x8xf32>) -> tensor<1x128x1x8xf32>
+  // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-2]} : (tensor<1x128x1x8xf32>) -> tensor<1x128x8xf32>
+  // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x8xf32>, tensor<8xf32>) -> tensor<1x128x8xf32>
+  // CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
+}
+
+func @testDilatedConv1DWithMixedPostiveAndNegativeAxis(%arg0: tensor<1x128x3xf32>, %arg1: tensor<1x5x3x8xf32>) -> tensor<1x128x8xf32> {
+  %cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
+  %cst_0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  %cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %cst_2 = "tf.Const"() {value = dense<4> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
+  %0 = "tf.SpaceToBatchND"(%arg0, %cst_1, %cst_2) : (tensor<1x128x3xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<2x68x3xf32>
+  %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<2x68x3xf32>, tensor<i32>) -> tensor<2x1x68x3xf32>
+  %2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<2x1x68x3xf32>, tensor<1x5x3x8xf32>) -> tensor<2x1x64x8xf32>
+  %3 = "tf.Squeeze"(%2) {squeeze_dims = [-3]} : (tensor<2x1x64x8xf32>) -> tensor<2x64x8xf32>
+  %4 = "tf.BatchToSpaceND"(%3, %cst_1, %cst) : (tensor<2x64x8xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x128x8xf32>
+  return %4 : tensor<1x128x8xf32>
+
+  // CHECK-LABEL: testDilatedConv1DWithMixedPostiveAndNegativeAxis
+  // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x3xf32>, [[FILTER:%.*]]: tensor<1x5x3x8xf32>)
+  // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x3xf32>, tensor<i32>) -> tensor<1x1x128x3xf32>
+  // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 1, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x1x128x3xf32>, tensor<1x5x3x8xf32>) -> tensor<1x1x128x8xf32>
+  // CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-3]} : (tensor<1x1x128x8xf32>) -> tensor<1x128x8xf32>
+  // CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
+}
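The new test functions above all exercise the same 1-D dilated-convolution lowering. For context, a minimal Python sketch (not part of the commit; it assumes TF 2.x and that a dilated 1-D convolution is lowered through SpaceToBatchND, which is the situation this converter pattern targets) of the kind of model that produces the SpaceToBatchND -> ExpandDims -> Conv2D -> Squeeze -> BatchToSpaceND sequence:

```python
# Sketch only: a dilated 1-D convolution whose graph lowering can produce the
# pattern matched by the tests above (shapes mirror the test cases).
import tensorflow as tf

@tf.function
def dilated_conv1d(x, filters):
  # x: [batch, width, channels]; filters: [kernel_width, in_channels, out_channels].
  return tf.nn.conv1d(x, filters, stride=1, padding="SAME", dilations=2)

x = tf.random.normal([1, 128, 3])
w = tf.random.normal([5, 3, 8])
print(dilated_conv1d(x, w).shape)  # (1, 128, 8)
```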
@@ -70,7 +70,7 @@ class ConvertTFDilatedConvOp : public OpRewritePattern<Conv2dOpTy> {
 
   // Extract the dilation factor from `block_shape` and pack it in an ArrayAttr.
   llvm::Optional<ArrayAttr> ExtractDilationsAttrFromBlockShape(
-      Value stb_block_shape, Value bts_block_shape,
+      Value stb_block_shape, Value bts_block_shape, int64_t expand_axis,
       PatternRewriter& rewriter) const;
 
  public:
@@ -111,7 +111,7 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
 
   TF::ExpandDimsOp expand_op;
   TF::SqueezeOp squeeze_op;
-  int64_t expand_axis;
+  int64_t expand_axis = -1;
   // Expand + Squeeze op.
   if (llvm::isa<TF::ExpandDimsOp>(prev_op)) {
     if (!llvm::isa<TF::SqueezeOp>(next_op)) {
@@ -127,13 +127,26 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
       expand_axis =
           (*const_op.value().cast<DenseElementsAttr>().getIntValues().begin())
               .getSExtValue();
+      // Canonicalize axis. Some TF python functions, such as
+      // `tf.nn.convolution`, use negative axis.
+      if (expand_axis < 0) {
+        // Always expand 3D input to 4D input.
+        expand_axis += 4;
+      }
     } else {
       return failure();
     }
     // Make sure that the `squeeze_dims` is equal to `expand_axis`.
     auto squeeze_dims = squeeze_op.squeeze_dims();
-    if (squeeze_dims.size() != 1 ||
-        squeeze_dims[0].cast<IntegerAttr>().getInt() != expand_axis) {
+    if (squeeze_dims.size() != 1) {
+      return failure();
+    }
+    int64_t squeeze_axis = squeeze_dims[0].cast<IntegerAttr>().getInt();
+    if (squeeze_axis < 0) {
+      // Always squeeze 4D input to 3D input.
+      squeeze_axis += 4;
+    }
+    if (squeeze_axis != expand_axis) {
       return failure();
     }
 
@@ -183,7 +196,7 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
   }
 
   llvm::Optional<ArrayAttr> dilations_attr = ExtractDilationsAttrFromBlockShape(
-      stb_op.block_shape(), bts_op.block_shape(), rewriter);
+      stb_op.block_shape(), bts_op.block_shape(), expand_axis, rewriter);
   if (!dilations_attr.hasValue()) return failure();
 
   if (expand_op) {
@@ -259,13 +272,24 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
     auto expand_result_type = RankedTensorType::get(
         expand_shape, getElementTypeOrSelf(stb_op.input()));
     expand_op.getResult().setType(expand_result_type);
-    op.getResult().setType(expand_result_type);
+
+    // Update the conv op's output shape.
+    auto bts_output_shape =
+        bts_op.output().getType().cast<ShapedType>().getShape();
+    SmallVector<int64_t, 4> conv_result_shape(bts_output_shape.begin(),
+                                              bts_output_shape.end());
+    conv_result_shape.insert(conv_result_shape.begin() + expand_axis, 1);
+    auto conv_result_type = RankedTensorType::get(
+        conv_result_shape, getElementTypeOrSelf(stb_op.input()));
+    op.getResult().setType(conv_result_type);
 
     squeeze_op.getResult().setType(bts_op.output().getType());
 
     // Connect `biasadd_op` with the output of `squeeze_op`.
+    if (biasadd_op) {
       biasadd_op.setOperand(0, squeeze_op.output());
       biasadd_op.output().setType(squeeze_op.output().getType());
+    }
   } else {
     if (biasadd_op) biasadd_op.setOperand(0, op.output());
     op.setOperand(0, stb_op.input());
@@ -283,7 +307,7 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
 template <typename Conv2dOpTy>
 llvm::Optional<ArrayAttr>
 ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
-    Value stb_block_shape, Value bts_block_shape,
+    Value stb_block_shape, Value bts_block_shape, int64_t expand_axis,
     PatternRewriter& rewriter) const {
   ElementsAttr stb_bs_attr, bts_bs_attr;
   if (!matchPattern(stb_block_shape, m_Constant(&stb_bs_attr)) ||
@@ -297,12 +321,31 @@ ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
     if (stb_bs_attr.getValue({i}) != bts_bs_attr.getValue({i})) return {};
   }
 
+  int dilation_h_factor = -1, dilation_w_factor = -1;
   // Set dilation factor.
-  if (stb_bs_attr.getNumElements() < 2) return {};
-  int dilation_h_factor =
-      stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
-  int dilation_w_factor =
-      stb_bs_attr.getValue({1}).cast<IntegerAttr>().getInt();
+  if (stb_bs_attr.getNumElements() >= 2) {
+    dilation_h_factor = stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
+    dilation_w_factor = stb_bs_attr.getValue({1}).cast<IntegerAttr>().getInt();
+  } else if (stb_bs_attr.getNumElements() == 1) {
+    // For 1d conv, `tf.nn.convolution` expands NWC to NHWC format after
+    // `SpaceToBatchND`. Therefore, `block_shape` of `stb_op` only has one
+    // dilation factor of W dim, and dilation factor of H dim is set to 1.
+    if (expand_axis == 1) {
+      // NWC -> NHWC
+      dilation_h_factor = 1;
+      dilation_w_factor =
+          stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
+    } else if (expand_axis == 2) {
+      // NHC -> NHWC
+      dilation_h_factor =
+          stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
+      dilation_w_factor = 1;
+    }
+  }
+
+  if (dilation_h_factor == -1 || dilation_w_factor == -1) {
+    return {};
+  }
 
   return rewriter.getI64ArrayAttr({1, dilation_h_factor, dilation_w_factor, 1});
 }
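To restate the logic of the `ExtractDilationsAttrFromBlockShape` change above in compact form, here is a small Python illustration (not the commit's code, just the same mapping under the shapes used in the tests) of how a 1-D `block_shape` plus the canonicalized expand axis becomes a 4-D `dilations` attribute:

```python
# Illustration only: map SpaceToBatchND block_shape and the ExpandDims axis to
# Conv2D dilations, mirroring the C++ logic above.
def dilations_from_block_shape(block_shape, expand_axis):
  # Canonicalize a negative expand axis against the 4-D (NHWC) result.
  if expand_axis < 0:
    expand_axis += 4
  if len(block_shape) >= 2:            # 2-D convolution: [h_factor, w_factor]
    h, w = block_shape[0], block_shape[1]
  elif len(block_shape) == 1:          # 1-D convolution: a single factor
    if expand_axis == 1:               # NWC -> NHWC, factor applies to W
      h, w = 1, block_shape[0]
    elif expand_axis == 2:             # NHC -> NHWC, factor applies to H
      h, w = block_shape[0], 1
    else:
      return None                      # corresponds to returning {} in C++
  else:
    return None
  return [1, h, w, 1]

assert dilations_from_block_shape([2], expand_axis=-3) == [1, 1, 2, 1]
assert dilations_from_block_shape([2], expand_axis=-2) == [1, 2, 1, 1]
```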
@@ -1049,7 +1049,6 @@ tf_xla_py_test(
     shard_count = 5,
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
-        "no_rocm",
         "optonly",
     ],
     deps = [
@@ -609,7 +609,6 @@ xla_test(
     name = "logdet_test",
     srcs = ["logdet_test.cc"],
     tags = [
-        "no_rocm",
        "optonly",
    ],
    deps = [
@@ -1787,7 +1787,7 @@ cc_library(
 tf_cc_test(
     name = "buffer_comparator_test",
     srcs = if_cuda_is_configured(["buffer_comparator_test.cc"]),
-    tags = ["no_rocm"] + tf_cuda_tests_tags(),
+    tags = tf_cuda_tests_tags(),
     deps = [
         "//tensorflow/core:test_main",
         "//tensorflow/compiler/xla:shape_util",
@@ -65,6 +65,12 @@ class ReductionDegenerateDimRemoverVisitor : public DfsHloRewriteVisitor {
       }
     }
 
+    if (updated_reduced_dimensions.empty()) {
+      std::unique_ptr<HloInstruction> reshape =
+          HloInstruction::CreateBitcast(reduce_shape, reduced_op);
+      return ReplaceWithNewInstruction(instr, std::move(reshape));
+    }
+
     HloInstruction *input_reshape = instr->parent()->AddInstruction(
         HloInstruction::CreateBitcast(canonical_input_shape, reduced_op));
 
@@ -177,7 +177,7 @@ tf_cc_test(
     srcs = [
         "tree_reduction_rewriter_test.cc",
     ],
-    tags = tf_cuda_tests_tags() + ["no_rocm"],
+    tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_codegen_test",
         "//tensorflow/compiler/xla:debug_options_flags",
@@ -258,7 +258,7 @@ tf_cc_test(
     srcs = [
         "parallel_reduction_test.cc",
     ],
-    tags = tf_cuda_tests_tags() + ["no_rocm"],
+    tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_codegen_test",
         "//tensorflow/compiler/xla/service:gpu_plugin",
@@ -297,7 +297,7 @@ tf_cc_test(
     srcs = [
         "gpu_copy_alone_test.cc",
     ],
-    tags = tf_cuda_tests_tags() + ["no_rocm"],
+    tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_codegen_test",
         "//tensorflow/compiler/xla/service:hlo",
@@ -521,9 +521,7 @@ tf_cc_test(
     srcs = [
         "sorting_test.cc",
     ],
-    tags = tf_cuda_tests_tags() + [
-        "no_rocm",
-    ],
+    tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_codegen_test",
         "//tensorflow/compiler/xla:debug_options_flags",
@@ -69,6 +69,38 @@ ENTRY main {
 )");
 }
 
+TEST_F(ReductionDegenerateDimRemoverTest, DegenerateWithEmptyDimension) {
+  const char* hlo_text = R"(
+HloModule ReduceWithDegenerateDimensions
+
+add {
+  accum = f32[] parameter(0)
+  op = f32[] parameter(1)
+  ROOT out = f32[] add(accum, op)
+}
+
+ENTRY main {
+  input = f32[1,3,1,4,1,5,1] parameter(0)
+  zero = f32[] constant(0)
+
+  ROOT out = f32[3,4,5,1] reduce(input, zero), dimensions={0,2,4}, to_apply=add
+}
+
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  // Copy instruction is added after bitcast because of copy-insertion pass,
+  // so we check the entire hlo module to verify there is no reduce instruction
+  // in this case.
+  MatchOptimizedHloWithShapes(hlo_text,
+                              R"(
+// CHECK: ENTRY %main (input: f32[1,3,1,4,1,5,1]) -> f32[3,4,5,1] {
+// CHECK:   %input = f32[1,3,1,4,1,5,1]{6,5,4,3,2,1,0} parameter(0)
+// CHECK:   %bitcast{{.+}} = f32[3,4,5,1]{3,2,1,0} bitcast(f32[1,3,1,4,1,5,1]{6,5,4,3,2,1,0} %input)
+// CHECK:   ROOT %copy{{.+}} = f32[3,4,5,1]{3,2,1,0} copy(f32[3,4,5,1]{3,2,1,0} %bitcast{{.+}})
+)");
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
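The new DegenerateWithEmptyDimension test covers the case where every reduced dimension has size 1, so no elements are actually combined. A short NumPy sketch (illustration only, not part of the commit) of why the pass can replace such a reduce with a bitcast/reshape:

```python
# When only size-1 dimensions are reduced, the reduction preserves every
# element, so the result is just the input viewed with a different shape.
import numpy as np

x = np.random.rand(1, 3, 1, 4, 1, 5, 1).astype(np.float32)
reduced = x.sum(axis=(0, 2, 4))        # shape (3, 4, 5, 1), no real accumulation
reshaped = x.reshape(3, 4, 5, 1)       # same bytes, new shape
np.testing.assert_allclose(reduced, reshaped)
```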
@@ -1159,7 +1159,6 @@ xla_test(
     ],
     shard_count = 50,
     tags = [
-        "no_rocm",
         "optonly",
     ],
     deps = CONVOLUTION_TEST_DEPS + [
@@ -1212,9 +1211,6 @@ xla_test(
     backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]},
     backends = ["gpu"],
     shard_count = 25,
-    tags = [
-        "no_rocm",
-    ],
     deps = CONVOLUTION_TEST_DEPS + [
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
@@ -1228,9 +1224,6 @@ xla_test(
     backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]},
     backends = ["gpu"],
     shard_count = 25,
-    tags = [
-        "no_rocm",
-    ],
     deps = CONVOLUTION_TEST_DEPS + [
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
@@ -760,6 +760,10 @@ const FunctionDef* EagerContext::GetFunctionDef(const string& function_name) {
   return func_lib_def_.Find(function_name);
 }
 
+std::vector<string> EagerContext::ListFunctionNames() {
+  return func_lib_def_.ListFunctionNames();
+}
+
 Status EagerContext::RemoveFunction(const string& func) {
   bool is_last_ref = false;
   {
@@ -226,6 +226,8 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
 
   const FunctionDef* GetFunctionDef(const string& function_name);
 
+  std::vector<string> ListFunctionNames() override;
+
   Status RemoveFunction(const string& func) override;
 
   // Wait for pending nodes to be finished in local executors (including context
@@ -1871,9 +1871,6 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     // it for MatMul as well, but in practice this pattern does not appear in
     // real Tensorflow graphs.
 
-    // TODO(penporn):
-    // Remove this once TF-MKL supports _FusedConv2D with these operations.
-#ifndef INTEL_MKL
     // Remap Conv2D+Squeeze+BiasAdd into the _FusedConv2D+Squeeze.
     ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias;
     if (allow_non_differentiable_rewrites &&
@@ -1884,6 +1881,9 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
       continue;
     }
 
+    // TODO(intel-tf):
+    // Remove this once TF-MKL supports _FusedConv2D with these operations.
+#ifndef INTEL_MKL
     // Remap Conv2D+FusedBatchNorm into the _FusedConv2D;
     ContractionWithBatchNorm contract_with_batch_norm;
     if (allow_non_differentiable_rewrites &&
@@ -932,6 +932,7 @@ TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) {
     test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
   }
 }
+#endif  // !INTEL_MKL
 
 TEST_F(RemapperTest, FuseConv2DWithSqueezeAndBias) {
   using ops::Placeholder;
@@ -1003,7 +1004,6 @@ TEST_F(RemapperTest, FuseConv2DWithSqueezeAndBias) {
   ASSERT_EQ(tensors.size(), 1);
   test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
 }
-#endif  // !INTEL_MKL
 
 }  // namespace grappler
 }  // namespace tensorflow
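The remapper hunks above move the `INTEL_MKL` guard so the Conv2D+Squeeze+BiasAdd remap is no longer excluded from MKL builds. As a hedged illustration (not from the commit, shapes are made up) of the op pattern that this rewrite targets, written against the public TF 2.x Python API:

```python
# Sketch of the Conv2D + Squeeze + BiasAdd pattern named in the comments above,
# which the grappler remapper can rewrite into a fused convolution.
import tensorflow as tf

@tf.function
def conv_squeeze_bias(x, w, b):
  y = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME")  # [1, 1, 32, 8]
  y = tf.squeeze(y, axis=1)                                     # [1, 32, 8]
  return tf.nn.bias_add(y, b)

x = tf.random.normal([1, 1, 32, 3])
w = tf.random.normal([1, 3, 3, 8])
b = tf.random.normal([8])
print(conv_squeeze_bias(x, w, b).shape)  # (1, 32, 8)
```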
@@ -20,7 +20,6 @@ limitations under the License.
 #include "tensorflow/core/kernels/data/dataset_utils.h"
 #include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/util/work_sharder.h"
 
@@ -211,7 +210,6 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
     Status GetNextInternal(IteratorContext* ctx,
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
-      tf_shared_lock l(mu_);
       return input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
                                   out_tensors, end_of_sequence);
     }
@@ -225,7 +223,6 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
 
     Status SaveInternal(SerializationContext* ctx,
                         IteratorStateWriter* writer) override {
-      mutex_lock l(mu_);
       DCHECK(input_impl_ != nullptr);
       TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
       return Status::OK();
@@ -233,7 +230,6 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
 
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
-      mutex_lock l(mu_);
       TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
       return Status::OK();
     }
@@ -249,8 +245,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
       return params;
     }
 
-    mutex mu_;
-    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
+    std::unique_ptr<IteratorBase> input_impl_;
   };
 
   const DatasetBase* const input_;
@@ -351,7 +346,6 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
       auto max_parallelism = dataset()->max_intra_op_parallelism_;
       params.runner =
          RunnerWithMaxParallelism(*ctx->runner(), max_parallelism);
-      tf_shared_lock l(mu_);
       return input_impl_->GetNext(IteratorContext{std::move(params)},
                                   out_tensors, end_of_sequence);
     }
@@ -365,7 +359,6 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
 
     Status SaveInternal(SerializationContext* ctx,
                         IteratorStateWriter* writer) override {
-      mutex_lock l(mu_);
       DCHECK(input_impl_ != nullptr);
       TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
       return Status::OK();
@@ -373,14 +366,12 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
 
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
-      mutex_lock l(mu_);
       TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
       return Status::OK();
     }
 
    private:
-    mutex mu_;
-    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
+    std::unique_ptr<IteratorBase> input_impl_;
   };
 
   const DatasetBase* const input_;
@@ -481,7 +472,6 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
         pool->Schedule(std::move(c));
       };
       params.runner_threadpool_size = dataset()->num_threads_;
-      tf_shared_lock l(mu_);
       return input_impl_->GetNext(IteratorContext{std::move(params)},
                                   out_tensors, end_of_sequence);
     }
@@ -495,7 +485,6 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
 
     Status SaveInternal(SerializationContext* ctx,
                         IteratorStateWriter* writer) override {
-      mutex_lock l(mu_);
       DCHECK(input_impl_ != nullptr);
       TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
       return Status::OK();
@@ -503,14 +492,12 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
 
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
-      mutex_lock l(mu_);
       TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
       return Status::OK();
     }
 
    private:
-    mutex mu_;
-    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
+    std::unique_ptr<IteratorBase> input_impl_;
   };
 
   const DatasetBase* const input_;
@@ -10,148 +10,52 @@ just a few minutes. All you need is a TensorFlow model [converted to TensorFlow
 Lite](../convert/). (If you don't have a model converted yet, you can experiment
 using the model provided with the example linked below.)
 
-## Install just the TensorFlow Lite interpreter
+## About the TensorFlow Lite runtime package
 
-To quickly run TensorFlow Lite models with Python, you can install just the
-TensorFlow Lite interpreter, instead of all TensorFlow packages.
+To quickly start executing TensorFlow Lite models with Python, you can install
+just the TensorFlow Lite interpreter, instead of all TensorFlow packages. We
+call this simplified Python package `tflite_runtime`.
 
-This interpreter-only package is a fraction the size of the full TensorFlow
+The `tflite_runtime` package is a fraction the size of the full `tensorflow`
 package and includes the bare minimum code required to run inferences with
-TensorFlow Lite—it includes only the
-[`tf.lite.Interpreter`](https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter)
+TensorFlow Lite—primarily the
+[`Interpreter`](https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter)
 Python class. This small package is ideal when all you want to do is execute
 `.tflite` models and avoid wasting disk space with the large TensorFlow library.
 
-Note: If you need access to other Python APIs, such as the [TensorFlow Lite
-Converter](../convert/python_api.md), you must install the [full TensorFlow
-package](https://www.tensorflow.org/install/).
+Note: If you need access to other Python APIs, such as the
+[TensorFlow Lite Converter](../convert/), you must install the
+[full TensorFlow package](https://www.tensorflow.org/install/).
 
-To install, run `pip3 install` and pass it the appropriate Python wheel URL from
-the following table.
+## Install TensorFlow Lite for Python
 
-For example, if you have a Raspberry Pi that's running Raspberry Pi OS 10 (which
-has Python 3.7), install the Python wheel as follows:
+To install the TensorFlow Lite runtime package, run this command:
+
+<pre class="devsite-terminal devsite-click-to-copy">
+pip3 install --extra-index-url https://google-coral.github.io/py-repo/ tflite_runtime
+</pre>
+
+If you're on a Raspberry Pi, this command might fail due to a known issue with
+the `extra-index-url` option
+([#4011](https://github.com/raspberrypi/linux/issues/4011)). So we suggest you
+specify one of the
+[`tflite_runtime` wheels](https://github.com/google-coral/pycoral/releases/)
+that matches your system. For example, if you're running Raspberry Pi OS 10
+(which has Python 3.7), instead use this command:
 
 <pre class="devsite-terminal devsite-click-to-copy">
 pip3 install https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp37-cp37m-linux_armv7l.whl
 </pre>
 
-<table>
-<tr><th>Platform</th><th>Python</th><th>URL</th></tr>
-<tr>
-  <td style="white-space:nowrap" rowspan="4">Linux (ARM 32)</td>
-  <td style="white-space:nowrap">3.5</td>
-  <td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp35-cp35m-linux_armv7l.whl</td>
-</tr>
-<tr>
-  <!-- ARM 32 -->
-  <td style="white-space:nowrap">3.6</td>
-  <td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp36-cp36m-linux_armv7l.whl</td>
-</tr>
-<tr>
-  <!-- ARM 32 -->
-  <td style="white-space:nowrap">3.7</td>
-  <td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp37-cp37m-linux_armv7l.whl</td>
-</tr>
-<tr>
-  <!-- ARM 32 -->
-  <td style="white-space:nowrap">3.8</td>
-  <td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp38-cp38-linux_armv7l.whl</td>
-</tr>
-
-<tr>
-  <td style="white-space:nowrap" rowspan="4">Linux (ARM 64)</td>
-  <td style="white-space:nowrap">3.5</td>
-  <td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp35-cp35m-linux_aarch64.whl</td>
-</tr>
-<tr>
-  <!-- ARM 64 -->
-  <td style="white-space:nowrap">3.6</td>
-  <td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp36-cp36m-linux_aarch64.whl</td>
-</tr>
-<tr>
-  <!-- ARM 64 -->
-  <td style="white-space:nowrap">3.7</td>
-  <td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp37-cp37m-linux_aarch64.whl</td>
-</tr>
-<tr>
-  <!-- ARM 64 -->
-  <td style="white-space:nowrap">3.8</td>
-  <td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp38-cp38-linux_aarch64.whl</td>
-</tr>
-
-<tr>
-  <td style="white-space:nowrap" rowspan="4">Linux (x86-64)</td>
-  <td style="white-space:nowrap">3.5</td>
-  <td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp35-cp35m-linux_x86_64.whl</td>
-</tr>
-<tr>
-  <!-- x86-64 -->
-  <td style="white-space:nowrap">3.6</td>
-  <td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp36-cp36m-linux_x86_64.whl</td>
-</tr>
-<tr>
-  <!-- x86-64 -->
-  <td style="white-space:nowrap">3.7</td>
-  <td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp37-cp37m-linux_x86_64.whl</td>
-</tr>
-<tr>
-  <!-- x86-64 -->
-  <td style="white-space:nowrap">3.8</td>
-  <td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp38-cp38-linux_x86_64.whl</td>
-</tr>
-
-<tr>
-  <td style="white-space:nowrap" rowspan="4">macOS 10.15</td>
-  <td style="white-space:nowrap">3.5</td>
-  <td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp35-cp35m-macosx_10_15_x86_64.whl</td>
-</tr>
-<tr>
-  <!-- Mac -->
-  <td style="white-space:nowrap">3.6</td>
-  <td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp36-cp36m-macosx_10_15_x86_64.whl</td>
-</tr>
-<tr>
-  <!-- Mac -->
-  <td style="white-space:nowrap">3.7</td>
-  <td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp37-cp37m-macosx_10_15_x86_64.whl</td>
-</tr>
-<tr>
-  <!-- Mac -->
-  <td style="white-space:nowrap">3.8</td>
-  <td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp38-cp38-macosx_10_15_x86_64.whl</td>
-</tr>
-
-<tr>
-  <td style="white-space:nowrap" rowspan="4">Windows 10</td>
-  <td style="white-space:nowrap">3.5</td>
-  <td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp35-cp35m-win_amd64.whl</td>
-</tr>
-<tr>
-  <!-- Win -->
-  <td style="white-space:nowrap">3.6</td>
-  <td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp36-cp36m-win_amd64.whl</td>
-</tr>
-<tr>
-  <!-- Win -->
-  <td style="white-space:nowrap">3.7</td>
-  <td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp37-cp37m-win_amd64.whl</td>
-</tr>
-<tr>
-  <!-- Win -->
-  <td style="white-space:nowrap">3.8</td>
-  <td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp38-cp38-win_amd64.whl</td>
-</tr>
-
-</table>
+Note: If you're on Debian Linux and using TensorFlow Lite with a Coral ML
+accelerator, using pip to install `tflite_runtime` may not be compatible with
+other Coral libraries. To ensure all your libraries are compatible, instead
+install `tflite_runtime` as a
+[Debian package from Coral](https://coral.ai/software/#debian-packages).
 
 ## Run an inference using tflite_runtime
 
-To distinguish this interpreter-only package from the full TensorFlow package
-(allowing both to be installed, if you choose), the Python module provided in
-the above wheel is named `tflite_runtime`.
-
-So instead of importing `Interpreter` from the `tensorflow` module, you need to
+Instead of importing `Interpreter` from the `tensorflow` module, you now need to
 import it from `tflite_runtime`.
 
 For example, after you install the package above, copy and run the
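To complement the revised installation instructions above, here is a minimal inference sketch with the `tflite_runtime` package. It is a hedged example rather than the documentation's own sample: the `model.tflite` path is a placeholder, and it assumes the package installed with the command above.

```python
# Minimal tflite_runtime inference sketch (model path is illustrative).
import numpy as np
from tflite_runtime.interpreter import Interpreter

interpreter = Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed random data with the model's expected input shape and dtype.
input_data = np.random.random_sample(
    input_details[0]["shape"]).astype(input_details[0]["dtype"])
interpreter.set_tensor(input_details[0]["index"], input_data)

interpreter.invoke()
print(interpreter.get_tensor(output_details[0]["index"]))
```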
@@ -3103,7 +3103,6 @@ cuda_py_test(
     tags = [
         "guitar",
         "multi_gpu",
-        "no_rocm",
         "no_windows",
     ],
     deps = [
@@ -1078,6 +1078,7 @@ cuda_py_test(
     tags = [
         "multi_and_single_gpu",
         "no_cuda_asan",  # times out
+        "no_rocm",
         "notsan",  # b/173031470
     ],
     deps = [
@@ -1741,6 +1742,7 @@ distribute_py_test(
     shard_count = 2,
     tags = [
         "multi_and_single_gpu",
+        "no_rocm",
         "notsan",  # TODO(b/160006974)
     ],
     xla_enable_strict_auto_jit = True,
@@ -1773,6 +1775,7 @@ distribute_py_test(
     tags = [
         "multi_and_single_gpu",
         "no_cuda_asan",  # times out
+        "no_rocm",
         "notsan",  # TODO(b/160006974)
     ],
     xla_enable_strict_auto_jit = True,
@@ -1846,6 +1849,7 @@ distribute_py_test(
     disable_mlir_bridge = False,
     tags = [
         "multi_and_single_gpu",
+        "no_rocm",
     ],
     deps = [
         ":combinations",
@@ -1186,6 +1186,15 @@ class Context(object):
     self.ensure_initialized()
     return pywrap_tfe.TFE_Py_PackEagerTensors(self._handle, tensors)
 
+  def list_function_names(self):
+    """Get a list of names of registered functions.
+
+    Returns:
+      A set of names of all registered functions for the context.
+    """
+    self.ensure_initialized()
+    return set(pywrap_tfe.TFE_ContextListFunctionNames(self._handle))
+
   def remove_function(self, name):
     """Remove a function from the context.
 
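A short usage sketch for the new `list_function_names()` method. It mirrors the unit test added later in this commit and assumes a TensorFlow build that includes this change, running with eager execution (the TF 2.x default):

```python
# List the functions currently registered with the eager context.
import tensorflow as tf
from tensorflow.python.eager import context

@tf.function
def f():
  return tf.constant(1.0)

concrete = f.get_concrete_function()
names = context.context().list_function_names()
assert concrete.name.decode() in names
```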
@@ -151,6 +151,16 @@ class ContextTest(test.TestCase):
     with self.assertRaisesRegex(ValueError, 'Multiple devices'):
       context.context().get_total_memory_usage('GPU')
 
+  def testListFunctionNames(self):
+
+    @def_function.function
+    def f():
+      return constant_op.constant(1.)
+
+    concrete = f.get_concrete_function()
+    self.assertIn(concrete.name.decode(),
+                  context.context().list_function_names())
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()
@@ -498,9 +498,17 @@ class _EagerDefinedFunction(object):
       function_callback(self)
 
   def add_to_graph(self, g=None):
+    """Add the function to the current context or a graph, if supplied.
+
+    Args:
+      g: the graph to add the function to. If not supplied, the function will
+        be added to the current context.
+    """
     # pylint: disable=protected-access
     if not g and context.executing_eagerly():
-      context.context().add_function_def(self.definition)
+      ctx = context.context()
+      if not ctx.has_function(self.name):
+        ctx.add_function_def(self.definition)
     else:
       if not g._is_function(self.name):
         g._add_function(self)
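With the `has_function` check added to `add_to_graph` above, retrieving the same concrete function repeatedly should not keep re-registering its FunctionDef in the eager context. A hedged Python sketch of the expected behaviour (assuming this commit's `list_function_names()` is available):

```python
# Repeated retrieval of a cached concrete function should not grow the set of
# registered function names.
import tensorflow as tf
from tensorflow.python.eager import context

@tf.function
def g(x):
  return x * 2.0

spec = tf.TensorSpec([], tf.float32)
g.get_concrete_function(spec)
before = context.context().list_function_names()
g.get_concrete_function(spec)  # same signature: served from the trace cache
after = context.context().list_function_names()
assert after == before
```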
@@ -4334,6 +4334,7 @@ EagerContextThreadLocalData* GetEagerContextThreadLocalData(
   }
 
   if (eager_context_thread_local_data_map == nullptr) {
+    absl::LeakCheckDisabler disabler;
     eager_context_thread_local_data_map = new EagerContextThreadLocalDataMap();
   }
   auto& thread_local_data =
@@ -660,7 +660,7 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
       # versions of python2.7.x.
       for _ in range(warmup_iters):
         f(self, *args, **kwargs)
-      # Since we aren't in the normal test lifecylce, we need to manually run
+      # Since we aren't in the normal test lifecycle, we need to manually run
       # cleanups to clear out their object references.
       self.doCleanups()
 
@@ -668,6 +668,10 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
       # create and save as a dummy variable to include it as a baseline.
       obj_count_by_type = _get_object_count_by_type()
       gc.collect()
+
+      # Make sure any registered functions are cleaned up in the C++ runtime.
+      registered_function_names = context.context().list_function_names()
+
       # unittest.doCleanups adds to self._outcome with each unwound call.
       # These objects are retained across gc collections so we exclude them
       # from the object count calculation.
@@ -682,7 +686,7 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
       }
       for _ in range(3):
         f(self, *args, **kwargs)
-      # Since we aren't in the normal test lifecylce, we need to manually run
+      # Since we aren't in the normal test lifecycle, we need to manually run
       # cleanups to clear out their object references.
       self.doCleanups()
       # Note that gc.get_objects misses anything that isn't subject to garbage
@@ -711,6 +715,14 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
                        exclude=gc.get_referents(self._outcome.errors,
                                                 self._outcome.skipped)) -
                    obj_count_by_type)
+
+      # There should be no newly registered functions hanging around.
+      leftover_functions = (
+          context.context().list_function_names() - registered_function_names)
+      assert not leftover_functions, (
+          "The following functions were newly created: %s" %
+          leftover_functions)
+
       # In some cases (specifically on MacOS), new_count is somehow
       # smaller than previous_count.
       # Using plain assert because not all classes using this decorator
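The test_util changes above extend the leak-check decorator so it also flags functions that remain registered in the C++ runtime after a test finishes. A simplified illustration of that check, reduced to its core (not the decorator's full code):

```python
# Snapshot the registered function names, run the test body, then fail if any
# new functions are still registered afterwards.
from tensorflow.python.eager import context

def check_no_leftover_functions(test_body):
  registered = context.context().list_function_names()
  test_body()
  leftover = context.context().list_function_names() - registered
  assert not leftover, (
      "The following functions were newly created: %s" % leftover)
```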
@@ -249,6 +249,7 @@ distribute_py_test(
     main = "custom_training_loop_metrics_test.py",
     tags = [
         "multi_and_single_gpu",
+        "no_rocm",
     ],
     deps = [
         ":strategy_combinations",
@@ -270,6 +271,7 @@ distribute_py_test(
     tags = [
         "multi_and_single_gpu",
         "no_cuda_asan",  # times out
+        "no_rocm",
         "notsan",  # TODO(b/170954243)
     ],
     tpu_tags = [
@@ -543,6 +545,7 @@ distribute_py_test(
     shard_count = 31,
     tags = [
         "multi_and_single_gpu",
+        "no_rocm",
         "no_windows_gpu",
         "noasan",  # TODO(b/337374867) fails with -fsanitize=null
         "notpu",  # TODO(b/153672562)
@@ -562,6 +565,7 @@ distribute_py_test(
     shard_count = 7,
     tags = [
         "multi_and_single_gpu",
+        "no_rocm",
     ],
     xla_tags = [
         "no_cuda_asan",  # times out
@@ -671,6 +671,7 @@ class Functional(training_lib.Model):
     Raises:
         ValueError: In case of improperly formatted config dict.
     """
+    with generic_utils.SharedObjectLoadingScope():
       input_tensors, output_tensors, created_layers = reconstruct_from_config(
           config, custom_objects)
       model = cls(inputs=input_tensors, outputs=output_tensors,
@@ -1346,6 +1347,8 @@ def get_network_config(network, serialize_layer_fn=None):
         node_conversion_map[node_key] = kept_nodes
         kept_nodes += 1
   layer_configs = []

+  with generic_utils.SharedObjectSavingScope():
     for layer in network.layers:  # From the earliest layers on.
       filtered_inbound_nodes = []
       for original_node_index, node in enumerate(layer._inbound_nodes):
@@ -80,7 +80,6 @@ cuda_py_test(
     name = "gradient_checkpoint_test",
     srcs = ["gradient_checkpoint_test.py"],
     python_version = "PY3",
-    tags = ["no_rocm"],
     deps = [
         "//tensorflow:tensorflow_py_no_contrib",
     ],
@@ -12,6 +12,7 @@ package(
         "//tensorflow/python/keras:__subpackages__",
         "//tensorflow/python/training/tracking:__pkg__",
         "//tensorflow/tools/pip_package:__pkg__",
+        "//tensorflow_models/official/vision/beta/projects/residual_mobilenet/modeling/backbones:__pkg__",
     ],
     licenses = ["notice"],  # Apache 2.0
 )
@@ -853,6 +854,7 @@ cuda_py_test(
     srcs = ["gru_v2_test.py"],
     python_version = "PY3",
     shard_count = 12,
+    tags = ["no_rocm"],
     deps = [
         "//tensorflow/python:client_testlib",
         "//tensorflow/python/keras",
@@ -114,13 +114,11 @@ class CategoryCrossing(base_preprocessing_layer.PreprocessingLayer):
     `[[b'1_X_2_X_3'], [b'4_X_5_X_6']]`
   """

-  def __init__(self, depth=None, name=None, separator=None, **kwargs):
+  def __init__(self, depth=None, name=None, separator='_X_', **kwargs):
     super(CategoryCrossing, self).__init__(name=name, **kwargs)
     base_preprocessing_layer.keras_kpl_gauge.get_cell(
         'CategoryCrossing').set(True)
     self.depth = depth
-    if separator is None:
-      separator = '_X_'
     self.separator = separator
     if isinstance(depth, (tuple, list)):
       self._depth_tuple = depth
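The `'_X_'` default simply moves from the constructor body into the signature, so crossing behavior is unchanged. A small illustrative sketch (the public export path is assumed, and the values are made up):

# Sketch only: default and custom separators for CategoryCrossing.
import tensorflow as tf

crossing = tf.keras.layers.experimental.preprocessing.CategoryCrossing()
print(crossing.separator)  # '_X_', now visible in the signature

dashed = tf.keras.layers.experimental.preprocessing.CategoryCrossing(
    separator='-')
print(dashed.separator)  # '-'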
@@ -393,6 +393,10 @@ def clone_model(model, input_tensors=None, clone_function=None):
   except that it creates new layers (and thus new weights) instead
   of sharing the weights of the existing layers.

+  `clone_model` will not preserve the uniqueness of shared objects within the
+  model (e.g. a single variable attached to two distinct layers will be
+  restored as two separate variables).
+
   Args:
     model: Instance of `Model`
         (could be a functional model or a Sequential model).
@@ -158,7 +158,6 @@ cuda_py_test(
     size = "medium",
     srcs = ["adadelta_test.py"],
     shard_count = 4,
-    tags = ["no_rocm"],
     # TODO(b/168527439): invalid resource variable reference on GPU for TFRT.
     deps = [
         ":optimizer_v2",
@@ -239,7 +238,6 @@ cuda_py_test(
     srcs = ["optimizer_v2_test.py"],
     shard_count = 8,
     tags = [
-        "no_rocm",
         "no_windows",
     ],
     deps = [
@@ -297,7 +295,6 @@ cuda_py_test(
     size = "medium",
     srcs = ["rmsprop_test.py"],
     shard_count = 2,
-    tags = ["no_rocm"],
     xla_tags = [
         "no_cuda_asan",  # times out
     ],
@@ -148,6 +148,7 @@ def save_model(model,
     hdf5_format.save_model_to_hdf5(
         model, filepath, overwrite, include_optimizer)
   else:
+    with generic_utils.SharedObjectSavingScope():
       saved_model_save.save(model, filepath, overwrite, include_optimizer,
                             signatures, options, save_traces)

@@ -194,6 +195,7 @@ def load_model(filepath, custom_objects=None, compile=True, options=None):  # py
     ImportError: if loading from an hdf5 file and h5py is not available.
     IOError: In case of an invalid savefile.
   """
+  with generic_utils.SharedObjectLoadingScope():
     with generic_utils.CustomObjectScope(custom_objects or {}):
       with load_context.load_context(options):
         if (h5py is not None and
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import collections
 import os
 import shutil
 import sys
@@ -25,12 +26,14 @@ import tempfile

 from absl.testing import parameterized
 import numpy as np
+from six import string_types

 from tensorflow.python import keras
 from tensorflow.python import tf2
 from tensorflow.python.eager import context
 from tensorflow.python.feature_column import feature_column_lib
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.keras import combinations
@@ -859,6 +862,125 @@ class TestWholeModelSaving(keras_parameterized.TestCase):
     self.assertAllEqual(loaded_model.predict(args, batch_size=batch_size),
                         expected)

+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_shared_objects(self):
+    class OuterLayer(keras.layers.Layer):
+
+      def __init__(self, inner_layer):
+        super(OuterLayer, self).__init__()
+        self.inner_layer = inner_layer
+
+      def call(self, inputs):
+        return self.inner_layer(inputs)
+
+      def get_config(self):
+        return {
+            'inner_layer': generic_utils.serialize_keras_object(
+                self.inner_layer)
+        }
+
+      @classmethod
+      def from_config(cls, config):
+        return cls(generic_utils.deserialize_keras_object(
+            config['inner_layer']))
+
+    class InnerLayer(keras.layers.Layer):
+
+      def __init__(self):
+        super(InnerLayer, self).__init__()
+        self.v = self.add_weight(name='v', shape=[], dtype=dtypes.float32)
+
+      def call(self, inputs):
+        return self.v + inputs
+
+      @classmethod
+      def from_config(cls, config):
+        return cls()
+
+    # Create a model with 2 output layers that share the same inner layer.
+    inner_layer = InnerLayer()
+    outer_layer_1 = OuterLayer(inner_layer)
+    outer_layer_2 = OuterLayer(inner_layer)
+    input_ = keras.Input(shape=(1,))
+    model = keras.Model(
+        inputs=input_, outputs=[outer_layer_1(input_), outer_layer_2(input_)])
+
+    # Changes to the shared layer should affect both outputs.
+    model.layers[1].inner_layer.v.assign(5)
+    self.assertAllEqual(model(1), [6.0, 6.0])
+    model.layers[1].inner_layer.v.assign(3)
+    self.assertAllEqual(model(1), [4.0, 4.0])
+
+    # After loading, changes to the shared layer should still affect both
+    # outputs.
+    def _do_assertions(loaded):
+      loaded.layers[1].inner_layer.v.assign(5)
+      self.assertAllEqual(loaded(1), [6.0, 6.0])
+      loaded.layers[1].inner_layer.v.assign(3)
+      self.assertAllEqual(loaded(1), [4.0, 4.0])
+      loaded.layers[2].inner_layer.v.assign(5)
+      self.assertAllEqual(loaded(1), [6.0, 6.0])
+      loaded.layers[2].inner_layer.v.assign(3)
+      self.assertAllEqual(loaded(1), [4.0, 4.0])
+
+    # We'd like to make sure we only attach shared object IDs when strictly
+    # necessary, so we'll recursively traverse the generated config to count
+    # whether we have the exact number we expect.
+    def _get_all_keys_recursive(dict_or_iterable):
+      if isinstance(dict_or_iterable, dict):
+        for key in dict_or_iterable.keys():
+          yield key
+        for key in _get_all_keys_recursive(dict_or_iterable.values()):
+          yield key
+      elif isinstance(dict_or_iterable, string_types):
+        return
+      else:
+        try:
+          for item in dict_or_iterable:
+            for key in _get_all_keys_recursive(item):
+              yield key
+        # Not an iterable or dictionary
+        except TypeError:
+          return
+
+    with generic_utils.CustomObjectScope({
+        'OuterLayer': OuterLayer, 'InnerLayer': InnerLayer}):
+
+      # Test saving and loading to disk
+      save_format = testing_utils.get_save_format()
+      saved_model_dir = self._save_model_dir()
+      keras.models.save_model(model, saved_model_dir, save_format=save_format)
+      loaded = keras.models.load_model(saved_model_dir)
+      _do_assertions(loaded)
+
+      # Test recreating directly from config
+      config = model.get_config()
+      key_count = collections.Counter(_get_all_keys_recursive(config))
+      self.assertEqual(key_count[generic_utils.SHARED_OBJECT_KEY], 2)
+      loaded = keras.Model.from_config(config)
+      _do_assertions(loaded)
+
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_shared_objects_wrapper(self):
+    """Tests that shared layers wrapped with `Wrapper` restore correctly."""
+    input_ = keras.Input(shape=(1,))
+    unwrapped = keras.layers.Layer(name='unwrapped')
+    wrapped = keras.layers.Wrapper(unwrapped, name='wrapped')
+    model = keras.Model(inputs=input_,
+                        outputs=[unwrapped(input_), wrapped(input_)])
+
+    # Test recreating directly from config
+    config = model.get_config()
+    loaded = keras.Model.from_config(config)
+    self.assertIs(loaded.layers[1], loaded.layers[2].layer)
+
+    # Test saving and loading to disk
+    save_format = testing_utils.get_save_format()
+    saved_model_dir = self._save_model_dir()
+    keras.models.save_model(model, saved_model_dir, save_format=save_format)
+    loaded = keras.models.load_model(saved_model_dir)
+    self.assertIs(loaded.layers[1], loaded.layers[2].layer)
+

 # Factory functions to create models that will be serialized inside a Network.
 def _make_graph_network(input_size, output_size):
@@ -46,7 +46,6 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
     # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
     # the python config serialization has caught up.
     metadata = dict(
-        class_name=generic_utils.get_registered_name(type(self.obj)),
         name=self.obj.name,
         trainable=self.obj.trainable,
         expects_training_arg=self.obj._expects_training_arg,  # pylint: disable=protected-access
@@ -56,7 +55,7 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
         must_restore_from_config=self.obj._must_restore_from_config,  # pylint: disable=protected-access
     )

-    metadata.update(get_config(self.obj))
+    metadata.update(get_serialized(self.obj))
     if self.obj.input_spec is not None:
       # Layer's input_spec has already been type-checked in the property setter.
       metadata['input_spec'] = nest.map_structure(
@@ -110,16 +109,12 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):

 # TODO(kathywu): Move serialization utils (and related utils from
 # generic_utils.py) to a separate file.
-def get_config(obj):
+def get_serialized(obj):
   with generic_utils.skip_failed_serialization():
     # Store the config dictionary, which may be used when reviving the object.
     # When loading, the program will attempt to revive the object from config,
     # and if that fails, the object will be revived from the SavedModel.
-    config = generic_utils.serialize_keras_object(obj)['config']
-
-  if config is not None:
-    return {'config': config}
-  return {}
+    return generic_utils.serialize_keras_object(obj)


 class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
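The rename records that the helper now returns the full serialized form (class name plus config) rather than a bare `{'config': ...}` dict, which is why the explicit `class_name=` entry above was dropped from the layer metadata. A rough sketch of the shape of the data, using an ordinary layer purely for illustration:

# Sketch only: what get_serialized() now contributes to the saved metadata.
from tensorflow.python import keras
from tensorflow.python.keras.utils import generic_utils

layer = keras.layers.Dense(4)
serialized = generic_utils.serialize_keras_object(layer)
print(sorted(serialized.keys()))  # ['class_name', 'config']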
@@ -492,13 +492,15 @@ class KerasObjectLoader(object):
     # found.
     class_name = metadata.get('class_name')
     config = metadata.get('config')
+    shared_object_id = metadata.get('shared_object_id')
     must_restore_from_config = metadata.get('must_restore_from_config')
     if not generic_utils.validate_config(config):
       return None

     try:
       obj = layers_module.deserialize(
-          generic_utils.serialize_keras_class_and_config(class_name, config))
+          generic_utils.serialize_keras_class_and_config(
+              class_name, config, shared_object_id=shared_object_id))
     except ValueError:
       if must_restore_from_config:
         raise RuntimeError(
@@ -36,7 +36,7 @@ class MetricSavedModelSaver(layer_serialization.LayerSavedModelSaver):
         class_name=generic_utils.get_registered_name(type(self.obj)),
         name=self.obj.name,
         dtype=self.obj.dtype)
-    metadata.update(layer_serialization.get_config(self.obj))
+    metadata.update(layer_serialization.get_serialized(self.obj))
     if self.obj._build_input_shape is not None:  # pylint: disable=protected-access
       metadata['build_input_shape'] = self.obj._build_input_shape  # pylint: disable=protected-access
     return metadata
@@ -24,8 +24,10 @@ import marshal
 import os
 import re
 import sys
+import threading
 import time
 import types as python_types
+import weakref

 import numpy as np
 import six
@@ -110,9 +112,205 @@ def get_custom_objects():
   return _GLOBAL_CUSTOM_OBJECTS


-def serialize_keras_class_and_config(cls_name, cls_config):
+# Store a unique, per-object ID for shared objects.
+#
+# We store a unique ID for each object so that we may, at loading time,
+# re-create the network properly. Without this ID, we would have no way of
+# determining whether a config is a description of a new object that
+# should be created or is merely a reference to an already-created object.
+SHARED_OBJECT_KEY = 'shared_object_id'
+
+
+class NoopLoadingScope(object):
+  """The default shared object loading scope. It does nothing.
+
+  Created to simplify serialization code that doesn't care about shared objects
+  (e.g. when serializing a single object).
+  """
+
+  def get(self, unused_object_id):
+    return None
+
+  def set(self, object_id, obj):
+    pass
+
+
+SHARED_OBJECT_LOADING = threading.local()
+
+
+def _shared_object_loading_scope():
+  """Get the current shared object saving scope in a threadsafe manner.
+
+  Attributes on the threadlocal variable must be set per-thread, thus we
+  cannot initialize these globally.
+
+  Returns:
+    A SharedObjectLoadingScope or NoopLoadingScope object.
+  """
+  return getattr(SHARED_OBJECT_LOADING, 'scope', NoopLoadingScope())
+
+
+class SharedObjectLoadingScope(object):
+  """A context manager for keeping track of loaded objects.
+
+  During the deserialization process, we may come across objects that are
+  shared across multiple layers. In order to accurately restore the network
+  structure to its original state, `SharedObjectLoadingScope` allows us to
+  re-use shared objects rather than cloning them.
+  """
+
+  def __enter__(self):
+    global SHARED_OBJECT_LOADING
+
+    SHARED_OBJECT_LOADING.scope = self
+    self._obj_ids_to_obj = {}
+    return self
+
+  def get(self, object_id):
+    """Given a shared object ID, returns a previously instantiated object.
+
+    Args:
+      object_id: shared object ID to use when attempting to find already-loaded
+        object.
+
+    Returns:
+      The object, if we've seen this ID before. Else, `None`.
+    """
+    # Explicitly check for `None` internally to make external calling code a
+    # bit cleaner.
+    if object_id is None:
+      return
+    return self._obj_ids_to_obj.get(object_id)
+
+  def set(self, object_id, obj):
+    """Stores an instantiated object for future lookup and sharing."""
+    if object_id is None:
+      return
+    self._obj_ids_to_obj[object_id] = obj
+
+  def __exit__(self, *args, **kwargs):
+    global SHARED_OBJECT_LOADING
+    SHARED_OBJECT_LOADING.scope = NoopLoadingScope()
+
+
+SHARED_OBJECT_SAVING = threading.local()
+
+
+def _shared_object_saving_scope():
+  """Get the current shared object saving scope in a threadsafe manner.
+
+  Attributes on the threadlocal variable must be set per-thread, thus we
+  cannot initialize these globally.
+
+  Returns:
+    A SharedObjectSavingScope object or None.
+  """
+  return getattr(SHARED_OBJECT_SAVING, 'scope', None)
+
+
+class SharedObjectConfig(dict):
+  """A configuration container that keeps track of references.
+
+  `SharedObjectConfig` will automatically attach a shared object ID to any
+  configs which are referenced more than once, allowing for proper shared
+  object reconstruction at load time.
+
+  In most cases, it would be more proper to subclass something like
+  `collections.UserDict` or `collections.Mapping` rather than `dict` directly.
+  Unfortunately, python's json encoder does not support `Mapping`s. This is
+  important functionality to retain, since we are dealing with serialization.
+
+  We should be safe to subclass `dict` here, since we aren't actually
+  overriding any core methods, only augmenting with a new one for reference
+  counting.
+  """
+
+  def __init__(self, base_config, object_id, **kwargs):
+    self.ref_count = 1
+    self.object_id = object_id
+    super(SharedObjectConfig, self).__init__(base_config, **kwargs)
+
+  def increment_ref_count(self):
+    # As soon as we've seen the object more than once, we want to attach the
+    # shared object ID. This allows us to only attach the shared object ID when
+    # it's strictly necessary, making backwards compatibility breakage less
+    # likely.
+    if self.ref_count == 1:
+      self[SHARED_OBJECT_KEY] = self.object_id
+    self.ref_count += 1
+
+
+class SharedObjectSavingScope(object):
+  """Keeps track of shared object configs when serializing."""
+
+  def __enter__(self):
+    global SHARED_OBJECT_SAVING
+
+    # Serialization can happen at a number of layers for a number of reasons.
+    # We may end up with a case where we're opening a saving scope within
+    # another saving scope. In that case, we'd like to use the outermost scope
+    # available and ignore inner scopes, since there is not (yet) a reasonable
+    # use case for having these nested and distinct.
+    if _shared_object_saving_scope() is not None:
+      self._passthrough = True
+      return _shared_object_saving_scope()
+    else:
+      self._passthrough = False
+
+    SHARED_OBJECT_SAVING.scope = self
+    self._shared_objects_config = weakref.WeakKeyDictionary()
+    self._next_id = 0
+    return self
+
+  def get_config(self, obj):
+    """Gets a `SharedObjectConfig` if one has already been seen for `obj`.
+
+    Args:
+      obj: The object for which to retrieve the `SharedObjectConfig`.
+
+    Returns:
+      The SharedObjectConfig for a given object, if already seen. Else,
+        `None`.
+    """
+    if obj in self._shared_objects_config:
+      shared_object_config = self._shared_objects_config[obj]
+      shared_object_config.increment_ref_count()
+      return shared_object_config
+
+  def create_config(self, base_config, obj):
+    shared_object_config = SharedObjectConfig(base_config, self._next_id)
+    self._next_id += 1
+    self._shared_objects_config[obj] = shared_object_config
+    return shared_object_config
+
+  def __exit__(self, *args, **kwargs):
+    if not self._passthrough:
+      global SHARED_OBJECT_SAVING
+      SHARED_OBJECT_SAVING.scope = None
+
+
+def serialize_keras_class_and_config(
+    cls_name, cls_config, obj=None, shared_object_id=None):
   """Returns the serialization of the class with the given config."""
-  return {'class_name': cls_name, 'config': cls_config}
+  base_config = {'class_name': cls_name, 'config': cls_config}
+
+  # We call `serialize_keras_class_and_config` for some branches of the load
+  # path. In that case, we may already have a shared object ID we'd like to
+  # retain.
+  if shared_object_id is not None:
+    base_config[SHARED_OBJECT_KEY] = shared_object_id
+
+  # If we have an active `SharedObjectSavingScope`, check whether we've already
+  # serialized this config. If so, just use that config. This will store an
+  # extra ID field in the config, allowing us to re-create the shared object
+  # relationship at load time.
+  if _shared_object_saving_scope() is not None and obj is not None:
+    shared_object_config = _shared_object_saving_scope().get_config(obj)
+    if shared_object_config is None:
+      return _shared_object_saving_scope().create_config(base_config, obj)
+    return shared_object_config
+
+  return base_config


 @keras_export('keras.utils.register_keras_serializable')
@@ -234,7 +432,19 @@ def get_registered_object(name, custom_objects=None, module_objects=None):

 @keras_export('keras.utils.serialize_keras_object')
 def serialize_keras_object(instance):
-  """Serialize a Keras object into a JSON-compatible representation."""
+  """Serialize a Keras object into a JSON-compatible representation.
+
+  Calls to `serialize_keras_object` while underneath the
+  `SharedObjectSavingScope` context manager will cause any objects re-used
+  across multiple layers to be saved with a special shared object ID. This
+  allows the network to be re-created properly during deserialization.
+
+  Args:
+    instance: The object to serialize.
+
+  Returns:
+    A dict-like, JSON-compatible representation of the object's config.
+  """
   _, instance = tf_decorator.unwrap(instance)
   if instance is None:
     return None
@@ -265,7 +475,8 @@ def serialize_keras_object(instance):
       serialization_config[key] = item

     name = get_registered_name(instance.__class__)
-    return serialize_keras_class_and_config(name, serialization_config)
+    return serialize_keras_class_and_config(
+        name, serialization_config, instance)
   if hasattr(instance, '__name__'):
     return get_registered_name(instance)
   raise ValueError('Cannot serialize', instance)
@@ -286,8 +497,9 @@ def class_and_config_for_serialized_keras_object(
     custom_objects=None,
     printable_module_name='object'):
   """Returns the class name and config for a serialized keras object."""
-  if (not isinstance(config, dict) or 'class_name' not in config or
-      'config' not in config):
+  if (not isinstance(config, dict)
+      or 'class_name' not in config
+      or 'config' not in config):
     raise ValueError('Improper config format: ' + str(config))

   class_name = config['class_name']
@@ -341,7 +553,24 @@ def deserialize_keras_object(identifier,
                              module_objects=None,
                              custom_objects=None,
                              printable_module_name='object'):
-  """Turns the serialized form of a Keras object back into an actual object."""
+  """Turns the serialized form of a Keras object back into an actual object.
+
+  Calls to `deserialize_keras_object` while underneath the
+  `SharedObjectLoadingScope` context manager will cause any already-seen shared
+  objects to be returned as-is rather than creating a new object.
+
+  Args:
+    identifier: the serialized form of the object.
+    module_objects: A dictionary of custom objects to look the name up in.
+      Generally, module_objects is provided by midlevel library implementers.
+    custom_objects: A dictionary of custom objects to look the name up in.
+      Generally, custom_objects is provided by the user.
+    printable_module_name: A human-readable string representing the type of the
+      object. Printed in case of exception.
+
+  Returns:
+    The deserialized object.
+  """
   if identifier is None:
     return None

@@ -351,25 +580,39 @@ def deserialize_keras_object(identifier,
     (cls, cls_config) = class_and_config_for_serialized_keras_object(
         config, module_objects, custom_objects, printable_module_name)

+    # If this object has already been loaded (i.e. it's shared between multiple
+    # objects), return the already-loaded object.
+    shared_object_id = config.get(SHARED_OBJECT_KEY)
+    shared_object = _shared_object_loading_scope().get(shared_object_id)  # pylint: disable=assignment-from-none
+    if shared_object is not None:
+      return shared_object
+
     if hasattr(cls, 'from_config'):
       arg_spec = tf_inspect.getfullargspec(cls.from_config)
       custom_objects = custom_objects or {}

       if 'custom_objects' in arg_spec.args:
-        return cls.from_config(
+        deserialized_obj = cls.from_config(
             cls_config,
             custom_objects=dict(
                 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                 list(custom_objects.items())))
-      with CustomObjectScope(custom_objects):
-        return cls.from_config(cls_config)
+      else:
+        with CustomObjectScope(custom_objects):
+          deserialized_obj = cls.from_config(cls_config)
     else:
       # Then `cls` may be a function returning a class.
       # in this case by convention `config` holds
       # the kwargs of the function.
       custom_objects = custom_objects or {}
       with CustomObjectScope(custom_objects):
-        return cls(**cls_config)
+        deserialized_obj = cls(**cls_config)

+    # Add object to shared objects, in case we find it referenced again.
+    _shared_object_loading_scope().set(shared_object_id, deserialized_obj)
+
+    return deserialized_obj

   elif isinstance(identifier, six.string_types):
     object_name = identifier
     if custom_objects and object_name in custom_objects:
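Taken together, the saving and loading scopes above let nested `serialize_keras_object` / `deserialize_keras_object` calls round-trip objects that are reused in several places. A minimal round-trip sketch using the internal import paths that appear elsewhere in this commit (the `Dense` layer is just an arbitrary serializable object):

# Sketch only: shared-object IDs are attached on save and honored on load.
from tensorflow.python import keras
from tensorflow.python.keras.utils import generic_utils

shared = keras.layers.Dense(2, name='shared')

with generic_utils.SharedObjectSavingScope():
  cfg_first = generic_utils.serialize_keras_object(shared)
  # Serializing the same object again attaches SHARED_OBJECT_KEY to its
  # tracked config and returns that same SharedObjectConfig.
  cfg_second = generic_utils.serialize_keras_object(shared)

with generic_utils.SharedObjectLoadingScope():
  layer_a = keras.layers.deserialize(cfg_first)
  layer_b = keras.layers.deserialize(cfg_second)
  # The loading scope caches by shared_object_id, so both configs resolve to
  # the same Python object.
  assert layer_a is layer_b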
@@ -23,6 +23,7 @@ from functools import partial
 import numpy as np

 from tensorflow.python import keras
+from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.platform import test


@@ -384,5 +385,63 @@ class SliceArraysTest(test.TestCase):
                      [None, None, None])


+# object() alone isn't compatible with WeakKeyDictionary, which we use to
+# track shared configs.
+class MaybeSharedObject(object):
+  pass
+
+
+class SharedObjectScopeTest(test.TestCase):
+
+  def test_shared_object_saving_scope_single_object_doesnt_export_id(self):
+    with generic_utils.SharedObjectSavingScope() as scope:
+      single_object = MaybeSharedObject()
+      self.assertIsNone(scope.get_config(single_object))
+      single_object_config = scope.create_config({}, single_object)
+      self.assertIsNotNone(single_object_config)
+      self.assertNotIn(generic_utils.SHARED_OBJECT_KEY,
+                       single_object_config)
+
+  def test_shared_object_saving_scope_shared_object_exports_id(self):
+    with generic_utils.SharedObjectSavingScope() as scope:
+      shared_object = MaybeSharedObject()
+      self.assertIsNone(scope.get_config(shared_object))
+      scope.create_config({}, shared_object)
+      first_object_config = scope.get_config(shared_object)
+      second_object_config = scope.get_config(shared_object)
+      self.assertIn(generic_utils.SHARED_OBJECT_KEY,
+                    first_object_config)
+      self.assertIn(generic_utils.SHARED_OBJECT_KEY,
+                    second_object_config)
+      self.assertIs(first_object_config, second_object_config)
+
+  def test_shared_object_loading_scope_noop(self):
+    # Test that, without a context manager scope, adding configs will do
+    # nothing.
+    obj_id = 1
+    obj = MaybeSharedObject()
+    generic_utils._shared_object_loading_scope().set(obj_id, obj)
+    self.assertIsNone(generic_utils._shared_object_loading_scope().get(obj_id))
+
+  def test_shared_object_loading_scope_returns_shared_obj(self):
+    obj_id = 1
+    obj = MaybeSharedObject()
+    with generic_utils.SharedObjectLoadingScope() as scope:
+      scope.set(obj_id, obj)
+      self.assertIs(scope.get(obj_id), obj)
+
+  def test_nested_shared_object_saving_scopes(self):
+    my_obj = MaybeSharedObject()
+    with generic_utils.SharedObjectSavingScope() as scope_1:
+      scope_1.create_config({}, my_obj)
+      with generic_utils.SharedObjectSavingScope() as scope_2:
+        # Nesting saving scopes should return the original scope and should
+        # not clear any objects we're tracking.
+        self.assertIs(scope_1, scope_2)
+        self.assertIsNotNone(scope_2.get_config(my_obj))
+      self.assertIsNotNone(scope_1.get_config(my_obj))
+    self.assertIsNone(generic_utils._shared_object_saving_scope())
+
+
 if __name__ == '__main__':
   test.main()
@@ -21,7 +21,6 @@ cuda_py_test(
     python_version = "PY3",
     tags = [
        "no_pip",
-       "no_rocm",
    ],
    deps = [
        ":mnist_testing_utils",
@@ -2118,7 +2118,6 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
   # allocations at a lower level.
   @test_util.assert_no_new_pyobjects_executing_eagerly
   def test_functions_cleaned(self):
-    self.skipTest("TODO(b/175152958): The test is leaking function definitions")
     if sys.version_info.major < 3:
       self.skipTest("Not working in Python 2")
     root = module.Module()
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/c/eager/c_api_experimental.h"
 #include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/c/eager/dlpack.h"
+#include "tensorflow/c/eager/tfe_context_internal.h"
 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
 #include "tensorflow/c/tf_status.h"
 #include "tensorflow/c/tf_status_helper.h"
@@ -670,6 +671,10 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
     return output;
   });
+  m.def("TFE_ContextListFunctionNames", [](py::handle& ctx) {
+    return tensorflow::unwrap(tensorflow::InputTFE_Context(ctx))
+        ->ListFunctionNames();
+  });
   m.def("TFE_ContextEnableRunMetadata", [](py::handle& ctx) {
     TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx));
   });
@@ -25,14 +25,18 @@ from __future__ import print_function
 import os
 import threading
 import time
+from typing import Any, List, Optional, Text

 from tensorflow.core.util import event_pb2
+from tensorflow.python.client import session as session_lib
 from tensorflow.python.framework import meta_graph
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import monitored_session
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import session_run_hook
 from tensorflow.python.training import training_util
-from tensorflow.python.training.session_run_hook import SessionRunArgs
 from tensorflow.python.training.summary_io import SummaryWriterCache


@@ -40,13 +44,14 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
   """Saves checkpoints every N steps or seconds."""

   def __init__(self,
-               checkpoint_dir,
-               save_secs=None,
-               save_steps=None,
-               saver=None,
-               checkpoint_basename="model.ckpt",
-               scaffold=None,
-               listeners=None):
+               checkpoint_dir: Text,
+               save_secs: Optional[int] = None,
+               save_steps: Optional[int] = None,
+               saver: Optional[saver_lib.Saver] = None,
+               checkpoint_basename: Text = "model.ckpt",
+               scaffold: Optional[monitored_session.Scaffold] = None,
+               listeners: Optional[List[
+                   basic_session_run_hooks.CheckpointSaverListener]] = None):
     """Initializes a `CheckpointSaverHook`.

     Args:
@@ -98,7 +103,7 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
     for l in self._listeners:
       l.begin()

-  def after_create_session(self, session, coord):
+  def after_create_session(self, session: session_lib.Session, coord: Any):
     global_step = session.run(self._global_step_tensor)

     # We do write graph and saver_def at the first call of before_run.
@@ -122,10 +127,11 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
       self._save(session, global_step)
       self._timer.update_last_triggered_step(global_step)

-  def before_run(self, run_context):  # pylint: disable=unused-argument
-    return SessionRunArgs(self._global_step_tensor)
+  def before_run(self, run_context: Any):  # pylint: disable=unused-argument
+    return session_run_hook.SessionRunArgs(self._global_step_tensor)

-  def after_run(self, run_context, run_values):
+  def after_run(self, run_context: session_run_hook.SessionRunContext,
+                run_values: Any):
     global_step = run_context.session.run(self._global_step_tensor)
     if self._timer.should_trigger_for_step(global_step):
       self._timer.update_last_triggered_step(global_step)
@@ -133,7 +139,7 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
       if self._save(run_context.session, global_step):
        run_context.request_stop()

-  def end(self, session):
+  def end(self, session: session_lib.Session):
     if self._save_thread:
       logging.info("Waiting for any pending checkpoints to finish.")
       self._save_thread.join()
@@ -19,6 +19,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from typing import Generator, Optional, Text
+
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
@@ -70,10 +72,18 @@ def _get_custom_getter():

 @tf_export(v1=['tpu.bfloat16_scope'])
 @tf_contextlib.contextmanager
-def bfloat16_scope(name=None):
+def bfloat16_scope(
+    name: Optional[Text] = None
+) -> Generator[variable_scope.variable_scope, None, None]:
   """Scope class for bfloat16 variables so that the model uses custom getter.

   This enables variables to be read as bfloat16 type when using get_variable.

+  Arguments:
+    name: Name to use for scope.
+
+  Yields:
+    a variable scope.
   """
   if name is None:
     name = ''
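The annotations above only document the existing contract (`name` is optional and the context manager yields a variable scope); runtime behavior is unchanged. A short usage sketch of the TF1-style API named by the `tf_export` decorator, with an illustrative variable:

# Sketch only: reading a variable as bfloat16 under the scope's custom getter.
import tensorflow.compat.v1 as tf

with tf.tpu.bfloat16_scope('bf16'):
  w = tf.get_variable('w', shape=[8], dtype=tf.bfloat16)
  # The custom getter stores `w` in float32 and casts it to bfloat16 on read.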
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from typing import Callable, Optional, Text, Union
+
 from tensorflow.python.data.experimental.ops import interleave_ops
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import iterator_ops
@@ -28,13 +30,13 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import functional_ops


-def _TextLineDataset(filename):
+def _TextLineDataset(filename: Text) -> dataset_ops.Dataset:
   buffer_size = 8 * 1024 * 1024  # 8 MiB per file
   dataset = readers.TextLineDataset(filename, buffer_size=buffer_size)
   return dataset


-def _TFRecordDataset(filename):
+def _TFRecordDataset(filename: Text) -> dataset_ops.Dataset:
   buffer_size = 8 * 1024 * 1024  # 8 MiB per file
   dataset = readers.TFRecordDataset(filename, buffer_size=buffer_size)
   return dataset
@@ -47,15 +49,17 @@ _FILETYPE_MAP = {
 }


-def StreamingFilesDataset(files,
-                          filetype=None,
-                          file_reader_job=None,
-                          worker_job=None,
-                          num_epochs=None,
-                          filename_shuffle_buffer_size=None,
-                          num_parallel_reads=None,
-                          batch_transfer_size=None,
-                          sloppy=None):
+def StreamingFilesDataset(
+    files: Union[Text, dataset_ops.Dataset],
+    filetype: Optional[Union[Text, Callable[[Text],
+                                            dataset_ops.Dataset]]] = None,
+    file_reader_job: Optional[Text] = None,
+    worker_job: Optional[Text] = None,
+    num_epochs: Optional[int] = None,
+    filename_shuffle_buffer_size: Optional[Union[int, bool]] = None,
+    num_parallel_reads: Optional[int] = None,
+    batch_transfer_size: Optional[Union[int, bool]] = None,
+    sloppy: bool = True) -> dataset_ops.Dataset:
   """StreamingFilesDataset constructs a dataset to stream from workers (GCE VM).

   Because Cloud TPUs are allocated over the network, a Cloud TPU cannot read
@@ -126,9 +130,6 @@ def StreamingFilesDataset(files,
   if batch_transfer_size is None:
     batch_transfer_size = 256

-  if sloppy is None:
-    sloppy = True
-
   if file_reader_job == 'coordinator':
     file_reader_device = '/job:coordinator/task:0'
   else:
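The annotated signature is functionally equivalent to the old one; the `sloppy=True` default simply moves from the function body into the signature. A minimal call sketch, assuming this is the `StreamingFilesDataset` defined in `tensorflow.python.tpu.datasets` (the file pattern and job name below are placeholders, not from this commit):

# Sketch only: streaming TFRecord files to a Cloud TPU worker.
from tensorflow.python.tpu.datasets import StreamingFilesDataset

dataset = StreamingFilesDataset(
    files='gs://my-bucket/train-*.tfrecord',  # placeholder pattern
    filetype='tfrecord',
    file_reader_job='coordinator',  # matches the branch in the function body
    num_parallel_reads=8)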
@@ -20,6 +20,7 @@ from __future__ import print_function

 import enum
 import math
+from typing import List, Optional, Text, Tuple

 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
@@ -66,7 +67,7 @@ class DeviceAssignment(object):
   `DeviceAssignment` directly.
   """

-  def __init__(self, topology, core_assignment):
+  def __init__(self, topology: Topology, core_assignment: np.ndarray):
     """Constructs a `DeviceAssignment` object.

     Args:
@@ -104,22 +105,22 @@ class DeviceAssignment(object):
         self._core_assignment, topology)

   @property
-  def topology(self):
+  def topology(self) -> Topology:
     """A `Topology` that describes the TPU topology."""
     return self._topology

   @property
-  def num_cores_per_replica(self):
+  def num_cores_per_replica(self) -> int:
     """The number of cores per replica."""
     return self._num_cores_per_replica

   @property
-  def num_replicas(self):
+  def num_replicas(self) -> int:
     """The number of replicas of the computation."""
     return self._num_replicas

   @property
-  def core_assignment(self):
+  def core_assignment(self) -> np.ndarray:
     """The logical to physical core mapping.

     Returns:
@@ -129,11 +130,11 @@ class DeviceAssignment(object):
     """
     return self._core_assignment

-  def coordinates(self, replica, logical_core):
+  def coordinates(self, replica: int, logical_core: int) -> Tuple:  # pylint:disable=g-bare-generic
     """Returns the physical topology coordinates of a logical core."""
     return tuple(self.core_assignment[replica, logical_core, :])

-  def lookup_replicas(self, task_id, logical_core):
+  def lookup_replicas(self, task_id: int, logical_core: int) -> List[int]:
     """Lookup replica ids by task number and logical core.

     Args:
@@ -153,31 +154,38 @@ class DeviceAssignment(object):
           "Can not find any replica in task: {} contains logical_core: {} ".
           format(task_id, logical_core))

-  def tpu_ordinal(self, replica=0, logical_core=0):
+  def tpu_ordinal(self, replica: int = 0, logical_core: int = 0) -> int:
     """Returns the ordinal of the TPU device assigned to a logical core."""
     coordinates = self.coordinates(replica, logical_core)
     return self._topology.tpu_device_ordinal_at_coordinates(coordinates)

-  def host_device(self, replica=0, logical_core=0, job=None):
+  def host_device(self,
+                  replica: int = 0,
+                  logical_core: int = 0,
+                  job: Optional[Text] = None) -> Text:
     """Returns the CPU device attached to a logical core."""
     coordinates = self.coordinates(replica, logical_core)
     return self._topology.cpu_device_name_at_coordinates(coordinates, job=job)

-  def tpu_device(self, replica=0, logical_core=0, job=None):
+  def tpu_device(self,
+                 replica: int = 0,
+                 logical_core: int = 0,
+                 job: Optional[Text] = None) -> Text:
     """Returns the name of the TPU device assigned to a logical core."""
     coordinates = self.coordinates(replica, logical_core)
     return self._topology.tpu_device_name_at_coordinates(coordinates, job=job)

   @staticmethod
-  def build(topology,
-            computation_shape=None,
-            computation_stride=None,
-            num_replicas=1):
+  def build(topology: Topology,
+            computation_shape: Optional[np.ndarray] = None,
+            computation_stride: Optional[np.ndarray] = None,
+            num_replicas: int = 1) -> "DeviceAssignment":
     return device_assignment(topology, computation_shape, computation_stride,
                              num_replicas)


-def _open_ring_2d(x_size, y_size, z_coord):
+def _open_ring_2d(x_size: int, y_size: int,
+                  z_coord: int) -> List[Tuple[int, int, int]]:
   """Ring-order of a X by Y mesh, with a fixed Z coordinate.

   For example, in a 4x4 mesh, this returns the following order.
@@ -213,7 +221,8 @@ def _open_ring_2d(x_size, y_size, z_coord):
   return ret


-def _ring_3d(x_size, y_size, z_size):
+def _ring_3d(x_size: int, y_size: int,
+             z_size: int) -> List[Tuple[int, int, int]]:
   """Ring-order of a X by Y by Z mesh.

   Constructs the 3d ring from 2d rings that are stacked in the Z dimension and
@@ -325,11 +334,13 @@ class DeviceOrderMode(enum.IntEnum):
   MESH = 2


-def device_assignment(topology,
-                      computation_shape=None,
-                      computation_stride=None,
-                      num_replicas=1,
-                      device_order_mode=DeviceOrderMode.AUTO):
+def device_assignment(
+    topology: Topology,
+    computation_shape: Optional[np.ndarray] = None,
+    computation_stride: Optional[np.ndarray] = None,
+    num_replicas: int = 1,
+    device_order_mode: DeviceOrderMode = DeviceOrderMode.AUTO
+) -> DeviceAssignment:
   """Computes a device_assignment of a computation across a TPU topology.

   Attempts to choose a compact grid of cores for locality.
@@ -341,11 +352,12 @@ def device_assignment(topology,
     optimal packing.

   Args:
-    topology: A `Topology` object that describes the TPU cluster topology.
-      To obtain a TPU topology, evaluate the `Tensor` returned by
+    topology: A `Topology` object that describes the TPU cluster topology. To
+      obtain a TPU topology, evaluate the `Tensor` returned by
       `initialize_system` using `Session.run`. Either a serialized
       `TopologyProto` or a `Topology` object may be passed. Note: you must
-      evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor` here.
+      evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor`
|
here.
|
||||||
computation_shape: A rank 1 int32 numpy array with size equal to the
|
computation_shape: A rank 1 int32 numpy array with size equal to the
|
||||||
topology rank, describing the shape of the computation's block of cores.
|
topology rank, describing the shape of the computation's block of cores.
|
||||||
If None, the `computation_shape` is `[1] * topology_rank`.
|
If None, the `computation_shape` is `[1] * topology_rank`.
|
||||||
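For context on how the annotated `device_assignment` API is typically driven, here is a minimal sketch. It assumes TensorFlow 2.x with a reachable TPU and uses the public `tf.tpu.experimental.DeviceAssignment.build` wrapper shown above; the resolver argument and the `[1, 1, 1, 2]` computation shape are illustrative assumptions, not part of this change.

```python
# Minimal sketch, not a definitive recipe: assumes TF 2.x and an attached TPU.
import numpy as np
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")  # adjust for your setup
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)  # returns a Topology

# One entry per topology dimension; [1, 1, 1, 2] pairs the two cores of a chip
# into a single replica (illustrative values).
assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology,
    computation_shape=np.array([1, 1, 1, 2], dtype=np.int32),
    num_replicas=2)

print(assignment.num_replicas, assignment.num_cores_per_replica)
print(assignment.tpu_device(replica=0, logical_core=0))
print(assignment.host_device(replica=0, logical_core=0))
```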
@@ -20,7 +20,7 @@ from __future__ import print_function
 from __future__ import unicode_literals

 import functools
-from typing import Any, Dict, Callable, List, Optional, Text, Tuple
+from typing import Any, Dict, Callable, Iterable, List, Optional, Text, Tuple, Union

 from absl import logging

@@ -229,7 +229,6 @@ class TPUEmbedding(tracking.AutoTrackable):
   model = model_fn(...)
   embedding = tf.tpu.experimental.embedding.TPUEmbedding(
       feature_config=feature_config,
-      batch_size=1024,
       optimizer=tf.tpu.experimental.embedding.SGD(0.1))
   checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
   checkpoint.restore(...)
@@ -244,7 +243,7 @@ class TPUEmbedding(tracking.AutoTrackable):

   def __init__(
       self,
-      feature_config: Any,
+      feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable],  # pylint:disable=g-bare-generic
       optimizer: Optional[tpu_embedding_v2_utils._Optimizer],  # pylint:disable=protected-access
       pipeline_execution_with_tensor_core: bool = False):
     """Creates the TPUEmbedding mid level API object.
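As a usage illustration of the constructor signature above (and of the docstring example losing its `batch_size` argument), here is a hedged sketch. It assumes a TensorFlow build that already includes this change; the table and feature names and sizes are invented for the example.

```python
# Sketch only: assumes a TF version with this change; names/sizes are illustrative.
import tensorflow as tf

table = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=1024, dim=16, name="video")
feature_config = {
    "watched": tf.tpu.experimental.embedding.FeatureConfig(table=table),
}

# Note there is no batch_size argument, matching the signature shown above.
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
    feature_config=feature_config,
    optimizer=tf.tpu.experimental.embedding.SGD(learning_rate=0.1))

# Outside a TPUStrategy this object is mainly useful for checkpointing, as in
# the class docstring example.
checkpoint = tf.train.Checkpoint(embedding=embedding)
```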
@@ -19,15 +19,23 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from typing import Any, Callable, Iterable, List, Optional, Union
+
 from tensorflow.python.compiler.xla import xla
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.tpu import tensor_tracer
+from tensorflow.python.tpu import tpu_feed
 from tensorflow.python.tpu import tpu_function
+from tensorflow.python.types import core as core_types


-def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
+def while_loop(condition: Callable[..., Any],
+               body: Callable[..., Any],
+               inputs: Optional[List[Any]] = None,
+               infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
+               name: Any = None) -> Any:
   """Builds a training loop for TPUs.

   The set of loop-carried tensors corresponds to `inputs`. Both
@@ -41,10 +49,10 @@ def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
   Args:
     condition: a Python function that builds the loop condition.
     body: a Python function that builds the loop body.
-    inputs: a list of initial values passed into the training loop, or
-      None (equivalent to an empty list).
-    infeed_queue: if not None, the infeed queue from which to append a tuple
-      of arguments as inputs to condition.
+    inputs: a list of initial values passed into the training loop, or None
+      (equivalent to an empty list).
+    infeed_queue: if not None, the infeed queue from which to append a tuple of
+      arguments as inputs to condition.
     name: (Deprecated) Does nothing.

   Returns:
@@ -178,7 +186,12 @@ def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
       condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1)


-def repeat(n, body, inputs=None, infeed_queue=None, name=None):
+def repeat(
+    n: int,
+    body: Callable[..., Union[core_types.TensorLike, Iterable]],  # pylint:disable=g-bare-generic
+    inputs: Optional[List[core_types.TensorLike]] = None,
+    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
+    name: Any = None) -> List[core_types.TensorLike]:
   """Builds a training loop that executes a fixed number of iterations.

   The set of loop-carried tensors correspond to `inputs`.
@@ -188,11 +201,12 @@ def repeat(n, body, inputs=None, infeed_queue=None, name=None):
   Args:
     n: the number of loop iterations
     body: a Python function that builds the loop body.
-    inputs: a list of initial values passed into the training loop or
-      None (equivalent to an empty list).
-    infeed_queue: if not None, the infeed queue from which to append a tuple
-      of arguments as inputs to condition.
+    inputs: a list of initial values passed into the training loop or None
+      (equivalent to an empty list).
+    infeed_queue: if not None, the infeed queue from which to append a tuple of
+      arguments as inputs to condition.
     name: (Deprecated) Does nothing.

   Returns:
     The final values of the loop-carried tensors.
   Raises:
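To make the annotated `repeat` signature concrete, here is a minimal graph-construction sketch. It imports the internal `tensorflow.python.tpu.training_loop` module shown above, so treat it as an assumption-laden illustration rather than supported public API; actually executing the loop would normally happen inside a TPU computation (for example via `tpu.rewrite`).

```python
# Sketch only: graph construction with the internal training_loop module;
# running the resulting loop normally happens inside a TPU computation.
import tensorflow.compat.v1 as tf
from tensorflow.python.tpu import training_loop


def body(step, total):
  # Two loop-carried tensors: an iteration counter and a running sum.
  return step + 1, total + step


with tf.Graph().as_default():
  final_step, final_total = training_loop.repeat(
      5, body, inputs=[tf.constant(0), tf.constant(0)])
```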
@@ -138,7 +138,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'depth\', \'name\', \'separator\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'depth\', \'name\', \'separator\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'_X_\'], "
   }
   member_method {
     name: "adapt"
@@ -138,7 +138,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'depth\', \'name\', \'separator\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'depth\', \'name\', \'separator\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'_X_\'], "
   }
   member_method {
     name: "adapt"
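The two golden-API updates above record a changed default for a `separator` argument (from `None` to `'_X_'`). Assuming these goldens correspond to the Keras experimental `CategoryCrossing` preprocessing layer (an inference from the `depth`/`separator` argspec, since the golden file names are not visible in this view), the default behaves roughly like this sketch:

```python
# Sketch under the assumption that the golden belongs to CategoryCrossing.
import tensorflow as tf

layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing()
crossed = layer([tf.constant([["a"], ["b"]]), tf.constant([["1"], ["2"]])])
# With the '_X_' default separator, crossed tokens look like b'a_X_1'.
print(crossed)
```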
@@ -261,6 +261,7 @@ function install_macos_pip_deps {
   ${PIP_CMD} install $USER_FLAG 'grpcio ~= 1.34.0'
   ${PIP_CMD} install $USER_FLAG 'portpicker ~= 1.3.1'
   ${PIP_CMD} install $USER_FLAG 'scipy ~= 1.5.2'
+  ${PIP_CMD} install $USER_FLAG --upgrade certifi

 # LINT.ThenChange(:linux_pip_installations_orig)
 # LINT.ThenChange(:linux_pip_installations)
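As a quick sanity check after the extra `--upgrade certifi` step, something like the following confirms which certifi build and CA bundle the Python environment actually resolves (a sketch; not part of the script above):

```python
# Sketch only: verify the upgraded certifi is the one Python imports.
import certifi

print(certifi.__version__)  # installed certifi version
print(certifi.where())      # path to the bundled cacert.pem
```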
@@ -46,7 +46,7 @@ py_test(
     tags = [
         "no_oss_py2",
         "no_pip",
-        "no_rocm",
+        "no_rocm",  # No need to rerun this test for ROCm config.
         "no_windows",  # numpy prints differently on windows.
         "noasan",
         "nomsan",