sync to master
This commit is contained in:
parent
ed158956df
commit
81adaff8a6
@ -137,6 +137,10 @@ This release contains contributions from many people at Google, as well as:
|
||||
|
||||
<INSERT>, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
|
||||
|
||||
# Release 2.4.1
|
||||
|
||||
* This release removes the AVX2 requirement from TF 2.4.0.
|
||||
|
||||
# Release 2.3.2
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
@ -185,6 +185,9 @@ class ImmediateExecutionContext : public AbstractContext {
|
||||
|
||||
virtual std::vector<std::string> GetLoggedOpsTestonly() { return {}; }
|
||||
|
||||
// Get a list of the names of functions that have been registered.
|
||||
virtual std::vector<string> ListFunctionNames() = 0;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Distributed runtime related functions.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -127,6 +127,7 @@ add_mlir_library(MhloLhloToLinalg
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MhloDialect
|
||||
MLIRComplex
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
)
|
||||
|
@ -372,3 +372,112 @@ func @testNoDilatedConvWhenGivenInputIsNonFloatType(%arg0: tensor<1x128x128x3xi3
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BatchToSpaceND"
|
||||
// CHECK-NEXT: return [[RESULT]]
|
||||
}
|
||||
|
||||
func @testDilatedConv1DExpandH(%arg0: tensor<1x128x3xf32>, %arg1: tensor<1x5x3x8xf32>) -> tensor<1x128x8xf32> {
|
||||
%cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
|
||||
%cst_0 = "tf.Const"() {value = dense<-3> : tensor<i32>} : () -> tensor<i32>
|
||||
%cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%cst_2 = "tf.Const"() {value = dense<4> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst_1, %cst_2) : (tensor<1x128x3xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<2x68x3xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<2x68x3xf32>, tensor<i32>) -> tensor<2x1x68x3xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<2x1x68x3xf32>, tensor<1x5x3x8xf32>) -> tensor<2x1x64x8xf32>
|
||||
%3 = "tf.Squeeze"(%2) {squeeze_dims = [-3]} : (tensor<2x1x64x8xf32>) -> tensor<2x64x8xf32>
|
||||
%4 = "tf.BatchToSpaceND"(%3, %cst_1, %cst) : (tensor<2x64x8xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x128x8xf32>
|
||||
return %4 : tensor<1x128x8xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedConv1DExpandH
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x3xf32>, [[FILTER:%.*]]: tensor<1x5x3x8xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<-3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x3xf32>, tensor<i32>) -> tensor<1x1x128x3xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 1, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x1x128x3xf32>, tensor<1x5x3x8xf32>) -> tensor<1x1x128x8xf32>
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-3]} : (tensor<1x1x128x8xf32>) -> tensor<1x128x8xf32>
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
|
||||
}
|
||||
|
||||
func @testDilatedConv1DExpandHWithBiasAdd(%arg0: tensor<1x128x3xf32>, %arg1: tensor<1x5x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x8xf32> {
|
||||
%cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
|
||||
%cst_0 = "tf.Const"() {value = dense<-3> : tensor<i32>} : () -> tensor<i32>
|
||||
%cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%cst_2 = "tf.Const"() {value = dense<4> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst_1, %cst_2) : (tensor<1x128x3xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<2x68x3xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<2x68x3xf32>, tensor<i32>) -> tensor<2x1x68x3xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<2x1x68x3xf32>, tensor<1x5x3x8xf32>) -> tensor<2x1x64x8xf32>
|
||||
%3 = "tf.Squeeze"(%2) {squeeze_dims = [-3]} : (tensor<2x1x64x8xf32>) -> tensor<2x64x8xf32>
|
||||
%4 = "tf.BatchToSpaceND"(%3, %cst_1, %cst) : (tensor<2x64x8xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x128x8xf32>
|
||||
%5 = "tf.BiasAdd"(%4, %arg2) : (tensor<1x128x8xf32>, tensor<8xf32>) -> tensor<1x128x8xf32>
|
||||
return %5 : tensor<1x128x8xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedConv1DExpandHWithBiasAdd
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x3xf32>, [[FILTER:%.*]]: tensor<1x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<-3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x3xf32>, tensor<i32>) -> tensor<1x1x128x3xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 1, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x1x128x3xf32>, tensor<1x5x3x8xf32>) -> tensor<1x1x128x8xf32>
|
||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-3]} : (tensor<1x1x128x8xf32>) -> tensor<1x128x8xf32>
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x8xf32>, tensor<8xf32>) -> tensor<1x128x8xf32>
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
|
||||
}
|
||||
|
||||
func @testDilatedConv1DExpandW(%arg0: tensor<1x128x3xf32>, %arg1: tensor<5x1x3x8xf32>) -> tensor<1x128x8xf32> {
|
||||
%cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
|
||||
%cst_0 = "tf.Const"() {value = dense<-2> : tensor<i32>} : () -> tensor<i32>
|
||||
%cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%cst_2 = "tf.Const"() {value = dense<4> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst_1, %cst_2) : (tensor<1x128x3xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<2x68x3xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<2x68x3xf32>, tensor<i32>) -> tensor<2x68x1x3xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<2x68x1x3xf32>, tensor<5x1x3x8xf32>) -> tensor<2x64x1x8xf32>
|
||||
%3 = "tf.Squeeze"(%2) {squeeze_dims = [-2]} : (tensor<2x64x1x8xf32>) -> tensor<2x64x8xf32>
|
||||
%4 = "tf.BatchToSpaceND"(%3, %cst_1, %cst) : (tensor<2x64x8xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x128x8xf32>
|
||||
return %4 : tensor<1x128x8xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedConv1DExpandW
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x3xf32>, [[FILTER:%.*]]: tensor<5x1x3x8xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<-2> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x3xf32>, tensor<i32>) -> tensor<1x128x1x3xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x1x3xf32>, tensor<5x1x3x8xf32>) -> tensor<1x128x1x8xf32>
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-2]} : (tensor<1x128x1x8xf32>) -> tensor<1x128x8xf32>
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
|
||||
}
|
||||
|
||||
func @testDilatedConv1DExpandWWithBiasAdd(%arg0: tensor<1x128x3xf32>, %arg1: tensor<5x1x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x8xf32> {
|
||||
%cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
|
||||
%cst_0 = "tf.Const"() {value = dense<-2> : tensor<i32>} : () -> tensor<i32>
|
||||
%cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%cst_2 = "tf.Const"() {value = dense<4> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst_1, %cst_2) : (tensor<1x128x3xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<2x68x3xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<2x68x3xf32>, tensor<i32>) -> tensor<2x68x1x3xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<2x68x1x3xf32>, tensor<5x1x3x8xf32>) -> tensor<2x64x1x8xf32>
|
||||
%3 = "tf.Squeeze"(%2) {squeeze_dims = [-2]} : (tensor<2x64x1x8xf32>) -> tensor<2x64x8xf32>
|
||||
%4 = "tf.BatchToSpaceND"(%3, %cst_1, %cst) : (tensor<2x64x8xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x128x8xf32>
|
||||
%5 = "tf.BiasAdd"(%4, %arg2) : (tensor<1x128x8xf32>, tensor<8xf32>) -> tensor<1x128x8xf32>
|
||||
return %5 : tensor<1x128x8xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedConv1DExpandWWithBiasAdd
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x3xf32>, [[FILTER:%.*]]: tensor<5x1x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<-2> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x3xf32>, tensor<i32>) -> tensor<1x128x1x3xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x1x3xf32>, tensor<5x1x3x8xf32>) -> tensor<1x128x1x8xf32>
|
||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-2]} : (tensor<1x128x1x8xf32>) -> tensor<1x128x8xf32>
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x8xf32>, tensor<8xf32>) -> tensor<1x128x8xf32>
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
|
||||
}
|
||||
|
||||
func @testDilatedConv1DWithMixedPostiveAndNegativeAxis(%arg0: tensor<1x128x3xf32>, %arg1: tensor<1x5x3x8xf32>) -> tensor<1x128x8xf32> {
|
||||
%cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
|
||||
%cst_0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
%cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%cst_2 = "tf.Const"() {value = dense<4> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst_1, %cst_2) : (tensor<1x128x3xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<2x68x3xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<2x68x3xf32>, tensor<i32>) -> tensor<2x1x68x3xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<2x1x68x3xf32>, tensor<1x5x3x8xf32>) -> tensor<2x1x64x8xf32>
|
||||
%3 = "tf.Squeeze"(%2) {squeeze_dims = [-3]} : (tensor<2x1x64x8xf32>) -> tensor<2x64x8xf32>
|
||||
%4 = "tf.BatchToSpaceND"(%3, %cst_1, %cst) : (tensor<2x64x8xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x128x8xf32>
|
||||
return %4 : tensor<1x128x8xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedConv1DWithMixedPostiveAndNegativeAxis
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x3xf32>, [[FILTER:%.*]]: tensor<1x5x3x8xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x3xf32>, tensor<i32>) -> tensor<1x1x128x3xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 1, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x1x128x3xf32>, tensor<1x5x3x8xf32>) -> tensor<1x1x128x8xf32>
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-3]} : (tensor<1x1x128x8xf32>) -> tensor<1x128x8xf32>
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
|
||||
}
|
||||
|
@ -70,7 +70,7 @@ class ConvertTFDilatedConvOp : public OpRewritePattern<Conv2dOpTy> {
|
||||
|
||||
// Extract the dilation factor from `block_shape` and pack it in an ArrayAttr.
|
||||
llvm::Optional<ArrayAttr> ExtractDilationsAttrFromBlockShape(
|
||||
Value stb_block_shape, Value bts_block_shape,
|
||||
Value stb_block_shape, Value bts_block_shape, int64_t expand_axis,
|
||||
PatternRewriter& rewriter) const;
|
||||
|
||||
public:
|
||||
@ -111,7 +111,7 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
|
||||
TF::ExpandDimsOp expand_op;
|
||||
TF::SqueezeOp squeeze_op;
|
||||
int64_t expand_axis;
|
||||
int64_t expand_axis = -1;
|
||||
// Expand + Squeeze op.
|
||||
if (llvm::isa<TF::ExpandDimsOp>(prev_op)) {
|
||||
if (!llvm::isa<TF::SqueezeOp>(next_op)) {
|
||||
@ -127,13 +127,26 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
expand_axis =
|
||||
(*const_op.value().cast<DenseElementsAttr>().getIntValues().begin())
|
||||
.getSExtValue();
|
||||
// Canonicalize axis. Some TF python functions, such as
|
||||
// `tf.nn.convolution`, use negative axis.
|
||||
if (expand_axis < 0) {
|
||||
// Always expand 3D input to 4D input.
|
||||
expand_axis += 4;
|
||||
}
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
// Make sure that the `squeeze_dims` is equal to `expand_axis`.
|
||||
auto squeeze_dims = squeeze_op.squeeze_dims();
|
||||
if (squeeze_dims.size() != 1 ||
|
||||
squeeze_dims[0].cast<IntegerAttr>().getInt() != expand_axis) {
|
||||
if (squeeze_dims.size() != 1) {
|
||||
return failure();
|
||||
}
|
||||
int64_t squeeze_axis = squeeze_dims[0].cast<IntegerAttr>().getInt();
|
||||
if (squeeze_axis < 0) {
|
||||
// Always squeeze 4D input to 3D input.
|
||||
squeeze_axis += 4;
|
||||
}
|
||||
if (squeeze_axis != expand_axis) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
@ -183,7 +196,7 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
}
|
||||
|
||||
llvm::Optional<ArrayAttr> dilations_attr = ExtractDilationsAttrFromBlockShape(
|
||||
stb_op.block_shape(), bts_op.block_shape(), rewriter);
|
||||
stb_op.block_shape(), bts_op.block_shape(), expand_axis, rewriter);
|
||||
if (!dilations_attr.hasValue()) return failure();
|
||||
|
||||
if (expand_op) {
|
||||
@ -259,13 +272,24 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
auto expand_result_type = RankedTensorType::get(
|
||||
expand_shape, getElementTypeOrSelf(stb_op.input()));
|
||||
expand_op.getResult().setType(expand_result_type);
|
||||
op.getResult().setType(expand_result_type);
|
||||
|
||||
// Update the conv op's output shape.
|
||||
auto bts_output_shape =
|
||||
bts_op.output().getType().cast<ShapedType>().getShape();
|
||||
SmallVector<int64_t, 4> conv_result_shape(bts_output_shape.begin(),
|
||||
bts_output_shape.end());
|
||||
conv_result_shape.insert(conv_result_shape.begin() + expand_axis, 1);
|
||||
auto conv_result_type = RankedTensorType::get(
|
||||
conv_result_shape, getElementTypeOrSelf(stb_op.input()));
|
||||
op.getResult().setType(conv_result_type);
|
||||
|
||||
squeeze_op.getResult().setType(bts_op.output().getType());
|
||||
|
||||
// Connect `biasadd_op` with the output of `squeeze_op`.
|
||||
biasadd_op.setOperand(0, squeeze_op.output());
|
||||
biasadd_op.output().setType(squeeze_op.output().getType());
|
||||
if (biasadd_op) {
|
||||
biasadd_op.setOperand(0, squeeze_op.output());
|
||||
biasadd_op.output().setType(squeeze_op.output().getType());
|
||||
}
|
||||
} else {
|
||||
if (biasadd_op) biasadd_op.setOperand(0, op.output());
|
||||
op.setOperand(0, stb_op.input());
|
||||
@ -283,7 +307,7 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
template <typename Conv2dOpTy>
|
||||
llvm::Optional<ArrayAttr>
|
||||
ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
|
||||
Value stb_block_shape, Value bts_block_shape,
|
||||
Value stb_block_shape, Value bts_block_shape, int64_t expand_axis,
|
||||
PatternRewriter& rewriter) const {
|
||||
ElementsAttr stb_bs_attr, bts_bs_attr;
|
||||
if (!matchPattern(stb_block_shape, m_Constant(&stb_bs_attr)) ||
|
||||
@ -297,12 +321,31 @@ ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
|
||||
if (stb_bs_attr.getValue({i}) != bts_bs_attr.getValue({i})) return {};
|
||||
}
|
||||
|
||||
int dilation_h_factor = -1, dilation_w_factor = -1;
|
||||
// Set dilation factor.
|
||||
if (stb_bs_attr.getNumElements() < 2) return {};
|
||||
int dilation_h_factor =
|
||||
stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
|
||||
int dilation_w_factor =
|
||||
stb_bs_attr.getValue({1}).cast<IntegerAttr>().getInt();
|
||||
if (stb_bs_attr.getNumElements() >= 2) {
|
||||
dilation_h_factor = stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
|
||||
dilation_w_factor = stb_bs_attr.getValue({1}).cast<IntegerAttr>().getInt();
|
||||
} else if (stb_bs_attr.getNumElements() == 1) {
|
||||
// For 1d conv, `tf.nn.convolution` expands NWC to NHWC format after
|
||||
// `SpaceToBatchND`. Therefore, `block_shape` of `stb_op` only has one
|
||||
// dilation factor of W dim, and dilation factor of H dim is set to 1.
|
||||
if (expand_axis == 1) {
|
||||
// NWC -> NHWC
|
||||
dilation_h_factor = 1;
|
||||
dilation_w_factor =
|
||||
stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
|
||||
} else if (expand_axis == 2) {
|
||||
// NHC -> NHWC
|
||||
dilation_h_factor =
|
||||
stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
|
||||
dilation_w_factor = 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (dilation_h_factor == -1 || dilation_w_factor == -1) {
|
||||
return {};
|
||||
}
|
||||
|
||||
return rewriter.getI64ArrayAttr({1, dilation_h_factor, dilation_w_factor, 1});
|
||||
}
|
||||
|
@ -1049,7 +1049,6 @@ tf_xla_py_test(
|
||||
shard_count = 5,
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
"no_rocm",
|
||||
"optonly",
|
||||
],
|
||||
deps = [
|
||||
|
@ -609,7 +609,6 @@ xla_test(
|
||||
name = "logdet_test",
|
||||
srcs = ["logdet_test.cc"],
|
||||
tags = [
|
||||
"no_rocm",
|
||||
"optonly",
|
||||
],
|
||||
deps = [
|
||||
|
@ -1787,7 +1787,7 @@ cc_library(
|
||||
tf_cc_test(
|
||||
name = "buffer_comparator_test",
|
||||
srcs = if_cuda_is_configured(["buffer_comparator_test.cc"]),
|
||||
tags = ["no_rocm"] + tf_cuda_tests_tags(),
|
||||
tags = tf_cuda_tests_tags(),
|
||||
deps = [
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
|
@ -65,6 +65,12 @@ class ReductionDegenerateDimRemoverVisitor : public DfsHloRewriteVisitor {
|
||||
}
|
||||
}
|
||||
|
||||
if (updated_reduced_dimensions.empty()) {
|
||||
std::unique_ptr<HloInstruction> reshape =
|
||||
HloInstruction::CreateBitcast(reduce_shape, reduced_op);
|
||||
return ReplaceWithNewInstruction(instr, std::move(reshape));
|
||||
}
|
||||
|
||||
HloInstruction *input_reshape = instr->parent()->AddInstruction(
|
||||
HloInstruction::CreateBitcast(canonical_input_shape, reduced_op));
|
||||
|
||||
|
@ -177,7 +177,7 @@ tf_cc_test(
|
||||
srcs = [
|
||||
"tree_reduction_rewriter_test.cc",
|
||||
],
|
||||
tags = tf_cuda_tests_tags() + ["no_rocm"],
|
||||
tags = tf_cuda_tests_tags(),
|
||||
deps = [
|
||||
":gpu_codegen_test",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
@ -258,7 +258,7 @@ tf_cc_test(
|
||||
srcs = [
|
||||
"parallel_reduction_test.cc",
|
||||
],
|
||||
tags = tf_cuda_tests_tags() + ["no_rocm"],
|
||||
tags = tf_cuda_tests_tags(),
|
||||
deps = [
|
||||
":gpu_codegen_test",
|
||||
"//tensorflow/compiler/xla/service:gpu_plugin",
|
||||
@ -297,7 +297,7 @@ tf_cc_test(
|
||||
srcs = [
|
||||
"gpu_copy_alone_test.cc",
|
||||
],
|
||||
tags = tf_cuda_tests_tags() + ["no_rocm"],
|
||||
tags = tf_cuda_tests_tags(),
|
||||
deps = [
|
||||
":gpu_codegen_test",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
@ -521,9 +521,7 @@ tf_cc_test(
|
||||
srcs = [
|
||||
"sorting_test.cc",
|
||||
],
|
||||
tags = tf_cuda_tests_tags() + [
|
||||
"no_rocm",
|
||||
],
|
||||
tags = tf_cuda_tests_tags(),
|
||||
deps = [
|
||||
":gpu_codegen_test",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
|
@ -69,6 +69,38 @@ ENTRY main {
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(ReductionDegenerateDimRemoverTest, DegenerateWithEmptyDimension) {
|
||||
const char* hlo_text = R"(
|
||||
HloModule ReduceWithDegenerateDimensions
|
||||
|
||||
add {
|
||||
accum = f32[] parameter(0)
|
||||
op = f32[] parameter(1)
|
||||
ROOT out = f32[] add(accum, op)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
input = f32[1,3,1,4,1,5,1] parameter(0)
|
||||
zero = f32[] constant(0)
|
||||
|
||||
ROOT out = f32[3,4,5,1] reduce(input, zero), dimensions={0,2,4}, to_apply=add
|
||||
}
|
||||
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
|
||||
// Copy instruction is added after bitcast because of copy-insertion pass,
|
||||
// so we check the entire hlo module to verify there is no reduce instruction
|
||||
// in this case.
|
||||
MatchOptimizedHloWithShapes(hlo_text,
|
||||
R"(
|
||||
// CHECK: ENTRY %main (input: f32[1,3,1,4,1,5,1]) -> f32[3,4,5,1] {
|
||||
// CHECK: %input = f32[1,3,1,4,1,5,1]{6,5,4,3,2,1,0} parameter(0)
|
||||
// CHECK: %bitcast{{.+}} = f32[3,4,5,1]{3,2,1,0} bitcast(f32[1,3,1,4,1,5,1]{6,5,4,3,2,1,0} %input)
|
||||
// CHECK: ROOT %copy{{.+}} = f32[3,4,5,1]{3,2,1,0} copy(f32[3,4,5,1]{3,2,1,0} %bitcast{{.+}})
|
||||
)");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
@ -1159,7 +1159,6 @@ xla_test(
|
||||
],
|
||||
shard_count = 50,
|
||||
tags = [
|
||||
"no_rocm",
|
||||
"optonly",
|
||||
],
|
||||
deps = CONVOLUTION_TEST_DEPS + [
|
||||
@ -1212,9 +1211,6 @@ xla_test(
|
||||
backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]},
|
||||
backends = ["gpu"],
|
||||
shard_count = 25,
|
||||
tags = [
|
||||
"no_rocm",
|
||||
],
|
||||
deps = CONVOLUTION_TEST_DEPS + [
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
@ -1228,9 +1224,6 @@ xla_test(
|
||||
backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]},
|
||||
backends = ["gpu"],
|
||||
shard_count = 25,
|
||||
tags = [
|
||||
"no_rocm",
|
||||
],
|
||||
deps = CONVOLUTION_TEST_DEPS + [
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
@ -760,6 +760,10 @@ const FunctionDef* EagerContext::GetFunctionDef(const string& function_name) {
|
||||
return func_lib_def_.Find(function_name);
|
||||
}
|
||||
|
||||
std::vector<string> EagerContext::ListFunctionNames() {
|
||||
return func_lib_def_.ListFunctionNames();
|
||||
}
|
||||
|
||||
Status EagerContext::RemoveFunction(const string& func) {
|
||||
bool is_last_ref = false;
|
||||
{
|
||||
|
@ -226,6 +226,8 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
||||
|
||||
const FunctionDef* GetFunctionDef(const string& function_name);
|
||||
|
||||
std::vector<string> ListFunctionNames() override;
|
||||
|
||||
Status RemoveFunction(const string& func) override;
|
||||
|
||||
// Wait for pending nodes to be finished in local executors (including context
|
||||
|
@ -1867,13 +1867,10 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
continue;
|
||||
}
|
||||
|
||||
// NOTE: We can only fuse BatchNorm into Conv2D nodes. In theory we can do
|
||||
// it for MatMul as well, but in practice this pattern does not appear in
|
||||
// real Tensorflow graphs.
|
||||
// NOTE: We can only fuse BatchNorm into Conv2D nodes. In theory we can do
|
||||
// it for MatMul as well, but in practice this pattern does not appear in
|
||||
// real Tensorflow graphs.
|
||||
|
||||
// TODO(penporn):
|
||||
// Remove this once TF-MKL supports _FusedConv2D with these operations.
|
||||
#ifndef INTEL_MKL
|
||||
// Remap Conv2D+Squeeze+BiasAdd into the _FusedConv2D+Squeeze.
|
||||
ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias;
|
||||
if (allow_non_differentiable_rewrites &&
|
||||
@ -1884,6 +1881,9 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO(intel-tf):
|
||||
// Remove this once TF-MKL supports _FusedConv2D with these operations.
|
||||
#ifndef INTEL_MKL
|
||||
// Remap Conv2D+FusedBatchNorm into the _FusedConv2D;
|
||||
ContractionWithBatchNorm contract_with_batch_norm;
|
||||
if (allow_non_differentiable_rewrites &&
|
||||
|
@ -932,6 +932,7 @@ TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) {
|
||||
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
|
||||
}
|
||||
}
|
||||
#endif // !INTEL_MKL
|
||||
|
||||
TEST_F(RemapperTest, FuseConv2DWithSqueezeAndBias) {
|
||||
using ops::Placeholder;
|
||||
@ -1003,7 +1004,6 @@ TEST_F(RemapperTest, FuseConv2DWithSqueezeAndBias) {
|
||||
ASSERT_EQ(tensors.size(), 1);
|
||||
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
|
||||
}
|
||||
#endif // !INTEL_MKL
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
@ -211,7 +210,6 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
tf_shared_lock l(mu_);
|
||||
return input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
|
||||
out_tensors, end_of_sequence);
|
||||
}
|
||||
@ -225,7 +223,6 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
DCHECK(input_impl_ != nullptr);
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
return Status::OK();
|
||||
@ -233,7 +230,6 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
@ -249,8 +245,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
|
||||
return params;
|
||||
}
|
||||
|
||||
mutex mu_;
|
||||
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
|
||||
std::unique_ptr<IteratorBase> input_impl_;
|
||||
};
|
||||
|
||||
const DatasetBase* const input_;
|
||||
@ -351,7 +346,6 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
|
||||
auto max_parallelism = dataset()->max_intra_op_parallelism_;
|
||||
params.runner =
|
||||
RunnerWithMaxParallelism(*ctx->runner(), max_parallelism);
|
||||
tf_shared_lock l(mu_);
|
||||
return input_impl_->GetNext(IteratorContext{std::move(params)},
|
||||
out_tensors, end_of_sequence);
|
||||
}
|
||||
@ -365,7 +359,6 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
DCHECK(input_impl_ != nullptr);
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
return Status::OK();
|
||||
@ -373,14 +366,12 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
|
||||
std::unique_ptr<IteratorBase> input_impl_;
|
||||
};
|
||||
|
||||
const DatasetBase* const input_;
|
||||
@ -481,7 +472,6 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
|
||||
pool->Schedule(std::move(c));
|
||||
};
|
||||
params.runner_threadpool_size = dataset()->num_threads_;
|
||||
tf_shared_lock l(mu_);
|
||||
return input_impl_->GetNext(IteratorContext{std::move(params)},
|
||||
out_tensors, end_of_sequence);
|
||||
}
|
||||
@ -495,7 +485,6 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
DCHECK(input_impl_ != nullptr);
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
return Status::OK();
|
||||
@ -503,14 +492,12 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
|
||||
std::unique_ptr<IteratorBase> input_impl_;
|
||||
};
|
||||
|
||||
const DatasetBase* const input_;
|
||||
|
@ -10,148 +10,52 @@ just a few minutes. All you need is a TensorFlow model [converted to TensorFlow
|
||||
Lite](../convert/). (If you don't have a model converted yet, you can experiment
|
||||
using the model provided with the example linked below.)
|
||||
|
||||
## Install just the TensorFlow Lite interpreter
|
||||
## About the TensorFlow Lite runtime package
|
||||
|
||||
To quickly run TensorFlow Lite models with Python, you can install just the
|
||||
TensorFlow Lite interpreter, instead of all TensorFlow packages.
|
||||
To quickly start executing TensorFlow Lite models with Python, you can install
|
||||
just the TensorFlow Lite interpreter, instead of all TensorFlow packages. We
|
||||
call this simplified Python package `tflite_runtime`.
|
||||
|
||||
This interpreter-only package is a fraction the size of the full TensorFlow
|
||||
The `tflite_runtime` package is a fraction the size of the full `tensorflow`
|
||||
package and includes the bare minimum code required to run inferences with
|
||||
TensorFlow Lite—it includes only the
|
||||
[`tf.lite.Interpreter`](https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter)
|
||||
TensorFlow Lite—primarily the
|
||||
[`Interpreter`](https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter)
|
||||
Python class. This small package is ideal when all you want to do is execute
|
||||
`.tflite` models and avoid wasting disk space with the large TensorFlow library.
|
||||
|
||||
Note: If you need access to other Python APIs, such as the [TensorFlow Lite
|
||||
Converter](../convert/python_api.md), you must install the [full TensorFlow
|
||||
package](https://www.tensorflow.org/install/).
|
||||
Note: If you need access to other Python APIs, such as the
|
||||
[TensorFlow Lite Converter](../convert/), you must install the
|
||||
[full TensorFlow package](https://www.tensorflow.org/install/).
|
||||
|
||||
To install, run `pip3 install` and pass it the appropriate Python wheel URL from
|
||||
the following table.
|
||||
## Install TensorFlow Lite for Python
|
||||
|
||||
For example, if you have a Raspberry Pi that's running Raspberry Pi OS 10 (which
|
||||
has Python 3.7), install the Python wheel as follows:
|
||||
To install the TensorFlow Lite runtime package, run this command:
|
||||
|
||||
<pre class="devsite-terminal devsite-click-to-copy">
|
||||
pip3 install --extra-index-url https://google-coral.github.io/py-repo/ tflite_runtime
|
||||
</pre>
|
||||
|
||||
If you're on a Raspberry Pi, this command might fail due to a known issue with
|
||||
the `extra-index-url` option
|
||||
([#4011](https://github.com/raspberrypi/linux/issues/4011)). So we suggest you
|
||||
specify one of the
|
||||
[`tflite_runtime` wheels](https://github.com/google-coral/pycoral/releases/)
|
||||
that matches your system. For example, if you're running Raspberry Pi OS 10
|
||||
(which has Python 3.7), instead use this command:
|
||||
|
||||
<pre class="devsite-terminal devsite-click-to-copy">
|
||||
pip3 install https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp37-cp37m-linux_armv7l.whl
|
||||
</pre>
|
||||
|
||||
<table>
|
||||
<tr><th>Platform</th><th>Python</th><th>URL</th></tr>
|
||||
<tr>
|
||||
<td style="white-space:nowrap" rowspan="4">Linux (ARM 32)</td>
|
||||
<td style="white-space:nowrap">3.5</td>
|
||||
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp35-cp35m-linux_armv7l.whl</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<!-- ARM 32 -->
|
||||
<td style="white-space:nowrap">3.6</td>
|
||||
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp36-cp36m-linux_armv7l.whl</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<!-- ARM 32 -->
|
||||
<td style="white-space:nowrap">3.7</td>
|
||||
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp37-cp37m-linux_armv7l.whl</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<!-- ARM 32 -->
|
||||
<td style="white-space:nowrap">3.8</td>
|
||||
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp38-cp38-linux_armv7l.whl</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td style="white-space:nowrap" rowspan="4">Linux (ARM 64)</td>
|
||||
<td style="white-space:nowrap">3.5</td>
|
||||
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp35-cp35m-linux_aarch64.whl</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<!-- ARM 64 -->
|
||||
<td style="white-space:nowrap">3.6</td>
|
||||
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp36-cp36m-linux_aarch64.whl</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<!-- ARM 64 -->
|
||||
<td style="white-space:nowrap">3.7</td>
|
||||
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp37-cp37m-linux_aarch64.whl</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<!-- ARM 64 -->
|
||||
<td style="white-space:nowrap">3.8</td>
|
||||
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp38-cp38-linux_aarch64.whl</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td style="white-space:nowrap" rowspan="4">Linux (x86-64)</td>
|
||||
<td style="white-space:nowrap">3.5</td>
|
||||
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp35-cp35m-linux_x86_64.whl</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<!-- x86-64 -->
|
||||
<td style="white-space:nowrap">3.6</td>
|
||||
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp36-cp36m-linux_x86_64.whl</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<!-- x86-64 -->
|
||||
<td style="white-space:nowrap">3.7</td>
|
||||
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp37-cp37m-linux_x86_64.whl</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<!-- x86-64 -->
|
||||
<td style="white-space:nowrap">3.8</td>
|
||||
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp38-cp38-linux_x86_64.whl</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td style="white-space:nowrap" rowspan="4">macOS 10.15</td>
|
||||
<td style="white-space:nowrap">3.5</td>
|
||||
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp35-cp35m-macosx_10_15_x86_64.whl</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<!-- Mac -->
|
||||
<td style="white-space:nowrap">3.6</td>
|
||||
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp36-cp36m-macosx_10_15_x86_64.whl</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<!-- Mac -->
|
||||
<td style="white-space:nowrap">3.7</td>
|
||||
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp37-cp37m-macosx_10_15_x86_64.whl</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<!-- Mac -->
|
||||
<td style="white-space:nowrap">3.8</td>
|
||||
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp38-cp38-macosx_10_15_x86_64.whl</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td style="white-space:nowrap" rowspan="4">Windows 10</td>
|
||||
<td style="white-space:nowrap">3.5</td>
|
||||
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp35-cp35m-win_amd64.whl</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<!-- Win -->
|
||||
<td style="white-space:nowrap">3.6</td>
|
||||
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp36-cp36m-win_amd64.whl</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<!-- Win -->
|
||||
<td style="white-space:nowrap">3.7</td>
|
||||
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp37-cp37m-win_amd64.whl</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<!-- Win -->
|
||||
<td style="white-space:nowrap">3.8</td>
|
||||
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp38-cp38-win_amd64.whl</td>
|
||||
</tr>
|
||||
|
||||
</table>
|
||||
Note: If you're on Debian Linux and using TensorFlow Lite with a Coral ML
|
||||
accelerator, using pip to install `tflite_runtime` may not be compatible with
|
||||
other Coral libraries. To ensure all your libraries are compatible, instead
|
||||
install `tflite_runtime` as a
|
||||
[Debian package from Coral](https://coral.ai/software/#debian-packages).
|
||||
|
||||
## Run an inference using tflite_runtime
|
||||
|
||||
To distinguish this interpreter-only package from the full TensorFlow package
|
||||
(allowing both to be installed, if you choose), the Python module provided in
|
||||
the above wheel is named `tflite_runtime`.
|
||||
|
||||
So instead of importing `Interpreter` from the `tensorflow` module, you need to
|
||||
Instead of importing `Interpreter` from the `tensorflow` module, you now need to
|
||||
import it from `tflite_runtime`.
|
||||
|
||||
For example, after you install the package above, copy and run the
|
||||
|
@ -3103,7 +3103,6 @@ cuda_py_test(
|
||||
tags = [
|
||||
"guitar",
|
||||
"multi_gpu",
|
||||
"no_rocm",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
|
@ -1078,6 +1078,7 @@ cuda_py_test(
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
"no_cuda_asan", # times out
|
||||
"no_rocm",
|
||||
"notsan", # b/173031470
|
||||
],
|
||||
deps = [
|
||||
@ -1741,6 +1742,7 @@ distribute_py_test(
|
||||
shard_count = 2,
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
"no_rocm",
|
||||
"notsan", # TODO(b/160006974)
|
||||
],
|
||||
xla_enable_strict_auto_jit = True,
|
||||
@ -1773,6 +1775,7 @@ distribute_py_test(
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
"no_cuda_asan", # times out
|
||||
"no_rocm",
|
||||
"notsan", # TODO(b/160006974)
|
||||
],
|
||||
xla_enable_strict_auto_jit = True,
|
||||
@ -1846,6 +1849,7 @@ distribute_py_test(
|
||||
disable_mlir_bridge = False,
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
"no_rocm",
|
||||
],
|
||||
deps = [
|
||||
":combinations",
|
||||
|
@ -1186,6 +1186,15 @@ class Context(object):
|
||||
self.ensure_initialized()
|
||||
return pywrap_tfe.TFE_Py_PackEagerTensors(self._handle, tensors)
|
||||
|
||||
def list_function_names(self):
|
||||
"""Get a list of names of registered functions.
|
||||
|
||||
Returns:
|
||||
A set of names of all registered functions for the context.
|
||||
"""
|
||||
self.ensure_initialized()
|
||||
return set(pywrap_tfe.TFE_ContextListFunctionNames(self._handle))
|
||||
|
||||
def remove_function(self, name):
|
||||
"""Remove a function from the context.
|
||||
|
||||
|
@ -151,6 +151,16 @@ class ContextTest(test.TestCase):
|
||||
with self.assertRaisesRegex(ValueError, 'Multiple devices'):
|
||||
context.context().get_total_memory_usage('GPU')
|
||||
|
||||
def testListFunctionNames(self):
|
||||
|
||||
@def_function.function
|
||||
def f():
|
||||
return constant_op.constant(1.)
|
||||
|
||||
concrete = f.get_concrete_function()
|
||||
self.assertIn(concrete.name.decode(),
|
||||
context.context().list_function_names())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
@ -498,9 +498,17 @@ class _EagerDefinedFunction(object):
|
||||
function_callback(self)
|
||||
|
||||
def add_to_graph(self, g=None):
|
||||
"""Add the function to the current context or a graph, if supplied.
|
||||
|
||||
Args:
|
||||
g: the graph to add the function to. If not supplied, the function will
|
||||
be added to the current context.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
if not g and context.executing_eagerly():
|
||||
context.context().add_function_def(self.definition)
|
||||
ctx = context.context()
|
||||
if not ctx.has_function(self.name):
|
||||
ctx.add_function_def(self.definition)
|
||||
else:
|
||||
if not g._is_function(self.name):
|
||||
g._add_function(self)
|
||||
|
@ -4334,6 +4334,7 @@ EagerContextThreadLocalData* GetEagerContextThreadLocalData(
|
||||
}
|
||||
|
||||
if (eager_context_thread_local_data_map == nullptr) {
|
||||
absl::LeakCheckDisabler disabler;
|
||||
eager_context_thread_local_data_map = new EagerContextThreadLocalDataMap();
|
||||
}
|
||||
auto& thread_local_data =
|
||||
|
@ -660,7 +660,7 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
|
||||
# versions of python2.7.x.
|
||||
for _ in range(warmup_iters):
|
||||
f(self, *args, **kwargs)
|
||||
# Since we aren't in the normal test lifecylce, we need to manually run
|
||||
# Since we aren't in the normal test lifecycle, we need to manually run
|
||||
# cleanups to clear out their object references.
|
||||
self.doCleanups()
|
||||
|
||||
@ -668,6 +668,10 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
|
||||
# create and save as a dummy variable to include it as a baseline.
|
||||
obj_count_by_type = _get_object_count_by_type()
|
||||
gc.collect()
|
||||
|
||||
# Make sure any registered functions are cleaned up in the C++ runtime.
|
||||
registered_function_names = context.context().list_function_names()
|
||||
|
||||
# unittest.doCleanups adds to self._outcome with each unwound call.
|
||||
# These objects are retained across gc collections so we exclude them
|
||||
# from the object count calculation.
|
||||
@ -682,7 +686,7 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
|
||||
}
|
||||
for _ in range(3):
|
||||
f(self, *args, **kwargs)
|
||||
# Since we aren't in the normal test lifecylce, we need to manually run
|
||||
# Since we aren't in the normal test lifecycle, we need to manually run
|
||||
# cleanups to clear out their object references.
|
||||
self.doCleanups()
|
||||
# Note that gc.get_objects misses anything that isn't subject to garbage
|
||||
@ -711,6 +715,14 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
|
||||
exclude=gc.get_referents(self._outcome.errors,
|
||||
self._outcome.skipped)) -
|
||||
obj_count_by_type)
|
||||
|
||||
# There should be no newly registered functions hanging around.
|
||||
leftover_functions = (
|
||||
context.context().list_function_names() - registered_function_names)
|
||||
assert not leftover_functions, (
|
||||
"The following functions were newly created: %s" %
|
||||
leftover_functions)
|
||||
|
||||
# In some cases (specifically on MacOS), new_count is somehow
|
||||
# smaller than previous_count.
|
||||
# Using plain assert because not all classes using this decorator
|
||||
|
@ -249,6 +249,7 @@ distribute_py_test(
|
||||
main = "custom_training_loop_metrics_test.py",
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
"no_rocm",
|
||||
],
|
||||
deps = [
|
||||
":strategy_combinations",
|
||||
@ -270,6 +271,7 @@ distribute_py_test(
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
"no_cuda_asan", # times out
|
||||
"no_rocm",
|
||||
"notsan", # TODO(b/170954243)
|
||||
],
|
||||
tpu_tags = [
|
||||
@ -543,6 +545,7 @@ distribute_py_test(
|
||||
shard_count = 31,
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
"no_rocm",
|
||||
"no_windows_gpu",
|
||||
"noasan", # TODO(b/337374867) fails with -fsanitize=null
|
||||
"notpu", # TODO(b/153672562)
|
||||
@ -562,6 +565,7 @@ distribute_py_test(
|
||||
shard_count = 7,
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
"no_rocm",
|
||||
],
|
||||
xla_tags = [
|
||||
"no_cuda_asan", # times out
|
||||
|
@ -671,12 +671,13 @@ class Functional(training_lib.Model):
|
||||
Raises:
|
||||
ValueError: In case of improperly formatted config dict.
|
||||
"""
|
||||
input_tensors, output_tensors, created_layers = reconstruct_from_config(
|
||||
config, custom_objects)
|
||||
model = cls(inputs=input_tensors, outputs=output_tensors,
|
||||
name=config.get('name'))
|
||||
connect_ancillary_layers(model, created_layers)
|
||||
return model
|
||||
with generic_utils.SharedObjectLoadingScope():
|
||||
input_tensors, output_tensors, created_layers = reconstruct_from_config(
|
||||
config, custom_objects)
|
||||
model = cls(inputs=input_tensors, outputs=output_tensors,
|
||||
name=config.get('name'))
|
||||
connect_ancillary_layers(model, created_layers)
|
||||
return model
|
||||
|
||||
def _validate_graph_inputs_and_outputs(self):
|
||||
"""Validates the inputs and outputs of a Graph Network."""
|
||||
@ -1346,21 +1347,23 @@ def get_network_config(network, serialize_layer_fn=None):
|
||||
node_conversion_map[node_key] = kept_nodes
|
||||
kept_nodes += 1
|
||||
layer_configs = []
|
||||
for layer in network.layers: # From the earliest layers on.
|
||||
filtered_inbound_nodes = []
|
||||
for original_node_index, node in enumerate(layer._inbound_nodes):
|
||||
node_key = _make_node_key(layer.name, original_node_index)
|
||||
if node_key in network._network_nodes and not node.is_input:
|
||||
# The node is relevant to the model:
|
||||
# add to filtered_inbound_nodes.
|
||||
node_data = node.serialize(_make_node_key, node_conversion_map)
|
||||
filtered_inbound_nodes.append(node_data)
|
||||
|
||||
layer_config = serialize_layer_fn(layer)
|
||||
layer_config['name'] = layer.name
|
||||
layer_config['inbound_nodes'] = filtered_inbound_nodes
|
||||
layer_configs.append(layer_config)
|
||||
config['layers'] = layer_configs
|
||||
with generic_utils.SharedObjectSavingScope():
|
||||
for layer in network.layers: # From the earliest layers on.
|
||||
filtered_inbound_nodes = []
|
||||
for original_node_index, node in enumerate(layer._inbound_nodes):
|
||||
node_key = _make_node_key(layer.name, original_node_index)
|
||||
if node_key in network._network_nodes and not node.is_input:
|
||||
# The node is relevant to the model:
|
||||
# add to filtered_inbound_nodes.
|
||||
node_data = node.serialize(_make_node_key, node_conversion_map)
|
||||
filtered_inbound_nodes.append(node_data)
|
||||
|
||||
layer_config = serialize_layer_fn(layer)
|
||||
layer_config['name'] = layer.name
|
||||
layer_config['inbound_nodes'] = filtered_inbound_nodes
|
||||
layer_configs.append(layer_config)
|
||||
config['layers'] = layer_configs
|
||||
|
||||
# Gather info about inputs and outputs.
|
||||
model_inputs = []
|
||||
|
@ -80,7 +80,6 @@ cuda_py_test(
|
||||
name = "gradient_checkpoint_test",
|
||||
srcs = ["gradient_checkpoint_test.py"],
|
||||
python_version = "PY3",
|
||||
tags = ["no_rocm"],
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
],
|
||||
|
@ -12,6 +12,7 @@ package(
|
||||
"//tensorflow/python/keras:__subpackages__",
|
||||
"//tensorflow/python/training/tracking:__pkg__",
|
||||
"//tensorflow/tools/pip_package:__pkg__",
|
||||
"//tensorflow_models/official/vision/beta/projects/residual_mobilenet/modeling/backbones:__pkg__",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
@ -853,6 +854,7 @@ cuda_py_test(
|
||||
srcs = ["gru_v2_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 12,
|
||||
tags = ["no_rocm"],
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/keras",
|
||||
|
@ -114,13 +114,11 @@ class CategoryCrossing(base_preprocessing_layer.PreprocessingLayer):
|
||||
`[[b'1_X_2_X_3'], [b'4_X_5_X_6']]`
|
||||
"""
|
||||
|
||||
def __init__(self, depth=None, name=None, separator=None, **kwargs):
|
||||
def __init__(self, depth=None, name=None, separator='_X_', **kwargs):
|
||||
super(CategoryCrossing, self).__init__(name=name, **kwargs)
|
||||
base_preprocessing_layer.keras_kpl_gauge.get_cell(
|
||||
'CategoryCrossing').set(True)
|
||||
self.depth = depth
|
||||
if separator is None:
|
||||
separator = '_X_'
|
||||
self.separator = separator
|
||||
if isinstance(depth, (tuple, list)):
|
||||
self._depth_tuple = depth
|
||||
|
@ -393,6 +393,10 @@ def clone_model(model, input_tensors=None, clone_function=None):
|
||||
except that it creates new layers (and thus new weights) instead
|
||||
of sharing the weights of the existing layers.
|
||||
|
||||
`clone_model` will not preserve the uniqueness of shared objects within the
|
||||
model (e.g. a single variable attached to two distinct layers will be
|
||||
restored as two separate variables).
|
||||
|
||||
Args:
|
||||
model: Instance of `Model`
|
||||
(could be a functional model or a Sequential model).
|
||||
|
@ -158,7 +158,6 @@ cuda_py_test(
|
||||
size = "medium",
|
||||
srcs = ["adadelta_test.py"],
|
||||
shard_count = 4,
|
||||
tags = ["no_rocm"],
|
||||
# TODO(b/168527439): invalid resource variable reference on GPU for TFRT.
|
||||
deps = [
|
||||
":optimizer_v2",
|
||||
@ -239,7 +238,6 @@ cuda_py_test(
|
||||
srcs = ["optimizer_v2_test.py"],
|
||||
shard_count = 8,
|
||||
tags = [
|
||||
"no_rocm",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
@ -297,7 +295,6 @@ cuda_py_test(
|
||||
size = "medium",
|
||||
srcs = ["rmsprop_test.py"],
|
||||
shard_count = 2,
|
||||
tags = ["no_rocm"],
|
||||
xla_tags = [
|
||||
"no_cuda_asan", # times out
|
||||
],
|
||||
|
@ -148,8 +148,9 @@ def save_model(model,
|
||||
hdf5_format.save_model_to_hdf5(
|
||||
model, filepath, overwrite, include_optimizer)
|
||||
else:
|
||||
saved_model_save.save(model, filepath, overwrite, include_optimizer,
|
||||
signatures, options, save_traces)
|
||||
with generic_utils.SharedObjectSavingScope():
|
||||
saved_model_save.save(model, filepath, overwrite, include_optimizer,
|
||||
signatures, options, save_traces)
|
||||
|
||||
|
||||
@keras_export('keras.models.load_model')
|
||||
@ -194,17 +195,18 @@ def load_model(filepath, custom_objects=None, compile=True, options=None): # py
|
||||
ImportError: if loading from an hdf5 file and h5py is not available.
|
||||
IOError: In case of an invalid savefile.
|
||||
"""
|
||||
with generic_utils.CustomObjectScope(custom_objects or {}):
|
||||
with load_context.load_context(options):
|
||||
if (h5py is not None and
|
||||
(isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
|
||||
return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
|
||||
compile)
|
||||
with generic_utils.SharedObjectLoadingScope():
|
||||
with generic_utils.CustomObjectScope(custom_objects or {}):
|
||||
with load_context.load_context(options):
|
||||
if (h5py is not None and
|
||||
(isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
|
||||
return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
|
||||
compile)
|
||||
|
||||
filepath = path_to_string(filepath)
|
||||
if isinstance(filepath, six.string_types):
|
||||
loader_impl.parse_saved_model(filepath)
|
||||
return saved_model_load.load(filepath, compile, options)
|
||||
filepath = path_to_string(filepath)
|
||||
if isinstance(filepath, six.string_types):
|
||||
loader_impl.parse_saved_model(filepath)
|
||||
return saved_model_load.load(filepath, compile, options)
|
||||
|
||||
raise IOError(
|
||||
'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
@ -25,12 +26,14 @@ import tempfile
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
from six import string_types
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.feature_column import feature_column_lib
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.keras import combinations
|
||||
@ -859,6 +862,125 @@ class TestWholeModelSaving(keras_parameterized.TestCase):
|
||||
self.assertAllEqual(loaded_model.predict(args, batch_size=batch_size),
|
||||
expected)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['eager']))
|
||||
def test_shared_objects(self):
|
||||
class OuterLayer(keras.layers.Layer):
|
||||
|
||||
def __init__(self, inner_layer):
|
||||
super(OuterLayer, self).__init__()
|
||||
self.inner_layer = inner_layer
|
||||
|
||||
def call(self, inputs):
|
||||
return self.inner_layer(inputs)
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
'inner_layer': generic_utils.serialize_keras_object(
|
||||
self.inner_layer)
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls(generic_utils.deserialize_keras_object(
|
||||
config['inner_layer']))
|
||||
|
||||
class InnerLayer(keras.layers.Layer):
|
||||
|
||||
def __init__(self):
|
||||
super(InnerLayer, self).__init__()
|
||||
self.v = self.add_weight(name='v', shape=[], dtype=dtypes.float32)
|
||||
|
||||
def call(self, inputs):
|
||||
return self.v + inputs
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls()
|
||||
|
||||
# Create a model with 2 output layers that share the same inner layer.
|
||||
inner_layer = InnerLayer()
|
||||
outer_layer_1 = OuterLayer(inner_layer)
|
||||
outer_layer_2 = OuterLayer(inner_layer)
|
||||
input_ = keras.Input(shape=(1,))
|
||||
model = keras.Model(
|
||||
inputs=input_, outputs=[outer_layer_1(input_), outer_layer_2(input_)])
|
||||
|
||||
# Changes to the shared layer should affect both outputs.
|
||||
model.layers[1].inner_layer.v.assign(5)
|
||||
self.assertAllEqual(model(1), [6.0, 6.0])
|
||||
model.layers[1].inner_layer.v.assign(3)
|
||||
self.assertAllEqual(model(1), [4.0, 4.0])
|
||||
|
||||
# After loading, changes to the shared layer should still affect both
|
||||
# outputs.
|
||||
def _do_assertions(loaded):
|
||||
loaded.layers[1].inner_layer.v.assign(5)
|
||||
self.assertAllEqual(loaded(1), [6.0, 6.0])
|
||||
loaded.layers[1].inner_layer.v.assign(3)
|
||||
self.assertAllEqual(loaded(1), [4.0, 4.0])
|
||||
loaded.layers[2].inner_layer.v.assign(5)
|
||||
self.assertAllEqual(loaded(1), [6.0, 6.0])
|
||||
loaded.layers[2].inner_layer.v.assign(3)
|
||||
self.assertAllEqual(loaded(1), [4.0, 4.0])
|
||||
|
||||
# We'd like to make sure we only attach shared object IDs when strictly
|
||||
# necessary, so we'll recursively traverse the generated config to count
|
||||
# whether we have the exact number we expect.
|
||||
def _get_all_keys_recursive(dict_or_iterable):
|
||||
if isinstance(dict_or_iterable, dict):
|
||||
for key in dict_or_iterable.keys():
|
||||
yield key
|
||||
for key in _get_all_keys_recursive(dict_or_iterable.values()):
|
||||
yield key
|
||||
elif isinstance(dict_or_iterable, string_types):
|
||||
return
|
||||
else:
|
||||
try:
|
||||
for item in dict_or_iterable:
|
||||
for key in _get_all_keys_recursive(item):
|
||||
yield key
|
||||
# Not an iterable or dictionary
|
||||
except TypeError:
|
||||
return
|
||||
|
||||
with generic_utils.CustomObjectScope({
|
||||
'OuterLayer': OuterLayer, 'InnerLayer': InnerLayer}):
|
||||
|
||||
# Test saving and loading to disk
|
||||
save_format = testing_utils.get_save_format()
|
||||
saved_model_dir = self._save_model_dir()
|
||||
keras.models.save_model(model, saved_model_dir, save_format=save_format)
|
||||
loaded = keras.models.load_model(saved_model_dir)
|
||||
_do_assertions(loaded)
|
||||
|
||||
# Test recreating directly from config
|
||||
config = model.get_config()
|
||||
key_count = collections.Counter(_get_all_keys_recursive(config))
|
||||
self.assertEqual(key_count[generic_utils.SHARED_OBJECT_KEY], 2)
|
||||
loaded = keras.Model.from_config(config)
|
||||
_do_assertions(loaded)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['eager']))
|
||||
def test_shared_objects_wrapper(self):
|
||||
"""Tests that shared layers wrapped with `Wrapper` restore correctly."""
|
||||
input_ = keras.Input(shape=(1,))
|
||||
unwrapped = keras.layers.Layer(name='unwrapped')
|
||||
wrapped = keras.layers.Wrapper(unwrapped, name='wrapped')
|
||||
model = keras.Model(inputs=input_,
|
||||
outputs=[unwrapped(input_), wrapped(input_)])
|
||||
|
||||
# Test recreating directly from config
|
||||
config = model.get_config()
|
||||
loaded = keras.Model.from_config(config)
|
||||
self.assertIs(loaded.layers[1], loaded.layers[2].layer)
|
||||
|
||||
# Test saving and loading to disk
|
||||
save_format = testing_utils.get_save_format()
|
||||
saved_model_dir = self._save_model_dir()
|
||||
keras.models.save_model(model, saved_model_dir, save_format=save_format)
|
||||
loaded = keras.models.load_model(saved_model_dir)
|
||||
self.assertIs(loaded.layers[1], loaded.layers[2].layer)
|
||||
|
||||
|
||||
# Factory functions to create models that will be serialized inside a Network.
|
||||
def _make_graph_network(input_size, output_size):
|
||||
|
@ -46,7 +46,6 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
|
||||
# TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
|
||||
# the python config serialization has caught up.
|
||||
metadata = dict(
|
||||
class_name=generic_utils.get_registered_name(type(self.obj)),
|
||||
name=self.obj.name,
|
||||
trainable=self.obj.trainable,
|
||||
expects_training_arg=self.obj._expects_training_arg, # pylint: disable=protected-access
|
||||
@ -56,7 +55,7 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
|
||||
must_restore_from_config=self.obj._must_restore_from_config, # pylint: disable=protected-access
|
||||
)
|
||||
|
||||
metadata.update(get_config(self.obj))
|
||||
metadata.update(get_serialized(self.obj))
|
||||
if self.obj.input_spec is not None:
|
||||
# Layer's input_spec has already been type-checked in the property setter.
|
||||
metadata['input_spec'] = nest.map_structure(
|
||||
@ -110,16 +109,12 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
|
||||
|
||||
# TODO(kathywu): Move serialization utils (and related utils from
|
||||
# generic_utils.py) to a separate file.
|
||||
def get_config(obj):
|
||||
def get_serialized(obj):
|
||||
with generic_utils.skip_failed_serialization():
|
||||
# Store the config dictionary, which may be used when reviving the object.
|
||||
# When loading, the program will attempt to revive the object from config,
|
||||
# and if that fails, the object will be revived from the SavedModel.
|
||||
config = generic_utils.serialize_keras_object(obj)['config']
|
||||
|
||||
if config is not None:
|
||||
return {'config': config}
|
||||
return {}
|
||||
return generic_utils.serialize_keras_object(obj)
|
||||
|
||||
|
||||
class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
|
||||
|
@ -492,13 +492,15 @@ class KerasObjectLoader(object):
|
||||
# found.
|
||||
class_name = metadata.get('class_name')
|
||||
config = metadata.get('config')
|
||||
shared_object_id = metadata.get('shared_object_id')
|
||||
must_restore_from_config = metadata.get('must_restore_from_config')
|
||||
if not generic_utils.validate_config(config):
|
||||
return None
|
||||
|
||||
try:
|
||||
obj = layers_module.deserialize(
|
||||
generic_utils.serialize_keras_class_and_config(class_name, config))
|
||||
generic_utils.serialize_keras_class_and_config(
|
||||
class_name, config, shared_object_id=shared_object_id))
|
||||
except ValueError:
|
||||
if must_restore_from_config:
|
||||
raise RuntimeError(
|
||||
|
@ -36,7 +36,7 @@ class MetricSavedModelSaver(layer_serialization.LayerSavedModelSaver):
|
||||
class_name=generic_utils.get_registered_name(type(self.obj)),
|
||||
name=self.obj.name,
|
||||
dtype=self.obj.dtype)
|
||||
metadata.update(layer_serialization.get_config(self.obj))
|
||||
metadata.update(layer_serialization.get_serialized(self.obj))
|
||||
if self.obj._build_input_shape is not None: # pylint: disable=protected-access
|
||||
metadata['build_input_shape'] = self.obj._build_input_shape # pylint: disable=protected-access
|
||||
return metadata
|
||||
|
@ -24,8 +24,10 @@ import marshal
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import types as python_types
|
||||
import weakref
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
@ -110,9 +112,205 @@ def get_custom_objects():
|
||||
return _GLOBAL_CUSTOM_OBJECTS
|
||||
|
||||
|
||||
def serialize_keras_class_and_config(cls_name, cls_config):
|
||||
# Store a unique, per-object ID for shared objects.
|
||||
#
|
||||
# We store a unique ID for each object so that we may, at loading time,
|
||||
# re-create the network properly. Without this ID, we would have no way of
|
||||
# determining whether a config is a description of a new object that
|
||||
# should be created or is merely a reference to an already-created object.
|
||||
SHARED_OBJECT_KEY = 'shared_object_id'
|
||||
|
||||
|
||||
class NoopLoadingScope(object):
|
||||
"""The default shared object loading scope. It does nothing.
|
||||
|
||||
Created to simplify serialization code that doesn't care about shared objects
|
||||
(e.g. when serializing a single object).
|
||||
"""
|
||||
|
||||
def get(self, unused_object_id):
|
||||
return None
|
||||
|
||||
def set(self, object_id, obj):
|
||||
pass
|
||||
|
||||
|
||||
SHARED_OBJECT_LOADING = threading.local()
|
||||
|
||||
|
||||
def _shared_object_loading_scope():
|
||||
"""Get the current shared object saving scope in a threadsafe manner.
|
||||
|
||||
Attributes on the threadlocal variable must be set per-thread, thus we
|
||||
cannot initialize these globally.
|
||||
|
||||
Returns:
|
||||
A SharedObjectLoadingScope or NoopLoadingScope object.
|
||||
"""
|
||||
return getattr(SHARED_OBJECT_LOADING, 'scope', NoopLoadingScope())
|
||||
|
||||
|
||||
class SharedObjectLoadingScope(object):
|
||||
"""A context manager for keeping track of loaded objects.
|
||||
|
||||
During the deserialization process, we may come across objects that are
|
||||
shared across multiple layers. In order to accurately restore the network
|
||||
structure to its original state, `SharedObjectLoadingScope` allows us to
|
||||
re-use shared objects rather than cloning them.
|
||||
"""
|
||||
|
||||
def __enter__(self):
|
||||
global SHARED_OBJECT_LOADING
|
||||
|
||||
SHARED_OBJECT_LOADING.scope = self
|
||||
self._obj_ids_to_obj = {}
|
||||
return self
|
||||
|
||||
def get(self, object_id):
|
||||
"""Given a shared object ID, returns a previously instantiated object.
|
||||
|
||||
Args:
|
||||
object_id: shared object ID to use when attempting to find already-loaded
|
||||
object.
|
||||
|
||||
Returns:
|
||||
The object, if we've seen this ID before. Else, `None`.
|
||||
"""
|
||||
# Explicitly check for `None` internally to make external calling code a
|
||||
# bit cleaner.
|
||||
if object_id is None:
|
||||
return
|
||||
return self._obj_ids_to_obj.get(object_id)
|
||||
|
||||
def set(self, object_id, obj):
|
||||
"""Stores an instantiated object for future lookup and sharing."""
|
||||
if object_id is None:
|
||||
return
|
||||
self._obj_ids_to_obj[object_id] = obj
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
global SHARED_OBJECT_LOADING
|
||||
SHARED_OBJECT_LOADING.scope = NoopLoadingScope()
|
||||
|
||||
|
||||
SHARED_OBJECT_SAVING = threading.local()
|
||||
|
||||
|
||||
def _shared_object_saving_scope():
|
||||
"""Get the current shared object saving scope in a threadsafe manner.
|
||||
|
||||
Attributes on the threadlocal variable must be set per-thread, thus we
|
||||
cannot initialize these globally.
|
||||
|
||||
Returns:
|
||||
A SharedObjectSavingScope object or None.
|
||||
"""
|
||||
return getattr(SHARED_OBJECT_SAVING, 'scope', None)
|
||||
|
||||
|
||||
class SharedObjectConfig(dict):
|
||||
"""A configuration container that keeps track of references.
|
||||
|
||||
`SharedObjectConfig` will automatically attach a shared object ID to any
|
||||
configs which are referenced more than once, allowing for proper shared
|
||||
object reconstruction at load time.
|
||||
|
||||
In most cases, it would be more proper to subclass something like
|
||||
`collections.UserDict` or `collections.Mapping` rather than `dict` directly.
|
||||
Unfortunately, python's json encoder does not support `Mapping`s. This is
|
||||
important functionality to retain, since we are dealing with serialization.
|
||||
|
||||
We should be safe to subclass `dict` here, since we aren't actually
|
||||
overriding any core methods, only augmenting with a new one for reference
|
||||
counting.
|
||||
"""
|
||||
|
||||
def __init__(self, base_config, object_id, **kwargs):
|
||||
self.ref_count = 1
|
||||
self.object_id = object_id
|
||||
super(SharedObjectConfig, self).__init__(base_config, **kwargs)
|
||||
|
||||
def increment_ref_count(self):
|
||||
# As soon as we've seen the object more than once, we want to attach the
|
||||
# shared object ID. This allows us to only attach the shared object ID when
|
||||
# it's strictly necessary, making backwards compatibility breakage less
|
||||
# likely.
|
||||
if self.ref_count == 1:
|
||||
self[SHARED_OBJECT_KEY] = self.object_id
|
||||
self.ref_count += 1
|
||||
|
||||
|
||||
class SharedObjectSavingScope(object):
|
||||
"""Keeps track of shared object configs when serializing."""
|
||||
|
||||
def __enter__(self):
|
||||
global SHARED_OBJECT_SAVING
|
||||
|
||||
# Serialization can happen at a number of layers for a number of reasons.
|
||||
# We may end up with a case where we're opening a saving scope within
|
||||
# another saving scope. In that case, we'd like to use the outermost scope
|
||||
# available and ignore inner scopes, since there is not (yet) a reasonable
|
||||
# use case for having these nested and distinct.
|
||||
if _shared_object_saving_scope() is not None:
|
||||
self._passthrough = True
|
||||
return _shared_object_saving_scope()
|
||||
else:
|
||||
self._passthrough = False
|
||||
|
||||
SHARED_OBJECT_SAVING.scope = self
|
||||
self._shared_objects_config = weakref.WeakKeyDictionary()
|
||||
self._next_id = 0
|
||||
return self
|
||||
|
||||
def get_config(self, obj):
|
||||
"""Gets a `SharedObjectConfig` if one has already been seen for `obj`.
|
||||
|
||||
Args:
|
||||
obj: The object for which to retrieve the `SharedObjectConfig`.
|
||||
|
||||
Returns:
|
||||
The SharedObjectConfig for a given object, if already seen. Else,
|
||||
`None`.
|
||||
"""
|
||||
if obj in self._shared_objects_config:
|
||||
shared_object_config = self._shared_objects_config[obj]
|
||||
shared_object_config.increment_ref_count()
|
||||
return shared_object_config
|
||||
|
||||
def create_config(self, base_config, obj):
|
||||
shared_object_config = SharedObjectConfig(base_config, self._next_id)
|
||||
self._next_id += 1
|
||||
self._shared_objects_config[obj] = shared_object_config
|
||||
return shared_object_config
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
if not self._passthrough:
|
||||
global SHARED_OBJECT_SAVING
|
||||
SHARED_OBJECT_SAVING.scope = None
|
||||
|
||||
|
||||
def serialize_keras_class_and_config(
|
||||
cls_name, cls_config, obj=None, shared_object_id=None):
|
||||
"""Returns the serialization of the class with the given config."""
|
||||
return {'class_name': cls_name, 'config': cls_config}
|
||||
base_config = {'class_name': cls_name, 'config': cls_config}
|
||||
|
||||
# We call `serialize_keras_class_and_config` for some branches of the load
|
||||
# path. In that case, we may already have a shared object ID we'd like to
|
||||
# retain.
|
||||
if shared_object_id is not None:
|
||||
base_config[SHARED_OBJECT_KEY] = shared_object_id
|
||||
|
||||
# If we have an active `SharedObjectSavingScope`, check whether we've already
|
||||
# serialized this config. If so, just use that config. This will store an
|
||||
# extra ID field in the config, allowing us to re-create the shared object
|
||||
# relationship at load time.
|
||||
if _shared_object_saving_scope() is not None and obj is not None:
|
||||
shared_object_config = _shared_object_saving_scope().get_config(obj)
|
||||
if shared_object_config is None:
|
||||
return _shared_object_saving_scope().create_config(base_config, obj)
|
||||
return shared_object_config
|
||||
|
||||
return base_config
|
||||
|
||||
|
||||
@keras_export('keras.utils.register_keras_serializable')
|
||||
@ -234,7 +432,19 @@ def get_registered_object(name, custom_objects=None, module_objects=None):
|
||||
|
||||
@keras_export('keras.utils.serialize_keras_object')
|
||||
def serialize_keras_object(instance):
|
||||
"""Serialize a Keras object into a JSON-compatible representation."""
|
||||
"""Serialize a Keras object into a JSON-compatible representation.
|
||||
|
||||
Calls to `serialize_keras_object` while underneath the
|
||||
`SharedObjectSavingScope` context manager will cause any objects re-used
|
||||
across multiple layers to be saved with a special shared object ID. This
|
||||
allows the network to be re-created properly during deserialization.
|
||||
|
||||
Args:
|
||||
instance: The object to serialize.
|
||||
|
||||
Returns:
|
||||
A dict-like, JSON-compatible representation of the object's config.
|
||||
"""
|
||||
_, instance = tf_decorator.unwrap(instance)
|
||||
if instance is None:
|
||||
return None
|
||||
@ -265,7 +475,8 @@ def serialize_keras_object(instance):
|
||||
serialization_config[key] = item
|
||||
|
||||
name = get_registered_name(instance.__class__)
|
||||
return serialize_keras_class_and_config(name, serialization_config)
|
||||
return serialize_keras_class_and_config(
|
||||
name, serialization_config, instance)
|
||||
if hasattr(instance, '__name__'):
|
||||
return get_registered_name(instance)
|
||||
raise ValueError('Cannot serialize', instance)
|
||||
@ -286,8 +497,9 @@ def class_and_config_for_serialized_keras_object(
|
||||
custom_objects=None,
|
||||
printable_module_name='object'):
|
||||
"""Returns the class name and config for a serialized keras object."""
|
||||
if (not isinstance(config, dict) or 'class_name' not in config or
|
||||
'config' not in config):
|
||||
if (not isinstance(config, dict)
|
||||
or 'class_name' not in config
|
||||
or 'config' not in config):
|
||||
raise ValueError('Improper config format: ' + str(config))
|
||||
|
||||
class_name = config['class_name']
|
||||
@ -341,7 +553,24 @@ def deserialize_keras_object(identifier,
|
||||
module_objects=None,
|
||||
custom_objects=None,
|
||||
printable_module_name='object'):
|
||||
"""Turns the serialized form of a Keras object back into an actual object."""
|
||||
"""Turns the serialized form of a Keras object back into an actual object.
|
||||
|
||||
Calls to `deserialize_keras_object` while underneath the
|
||||
`SharedObjectLoadingScope` context manager will cause any already-seen shared
|
||||
objects to be returned as-is rather than creating a new object.
|
||||
|
||||
Args:
|
||||
identifier: the serialized form of the object.
|
||||
module_objects: A dictionary of custom objects to look the name up in.
|
||||
Generally, module_objects is provided by midlevel library implementers.
|
||||
custom_objects: A dictionary of custom objects to look the name up in.
|
||||
Generally, custom_objects is provided by the user.
|
||||
printable_module_name: A human-readable string representing the type of the
|
||||
object. Printed in case of exception.
|
||||
|
||||
Returns:
|
||||
The deserialized object.
|
||||
"""
|
||||
if identifier is None:
|
||||
return None
|
||||
|
||||
@ -351,25 +580,39 @@ def deserialize_keras_object(identifier,
|
||||
(cls, cls_config) = class_and_config_for_serialized_keras_object(
|
||||
config, module_objects, custom_objects, printable_module_name)
|
||||
|
||||
# If this object has already been loaded (i.e. it's shared between multiple
|
||||
# objects), return the already-loaded object.
|
||||
shared_object_id = config.get(SHARED_OBJECT_KEY)
|
||||
shared_object = _shared_object_loading_scope().get(shared_object_id) # pylint: disable=assignment-from-none
|
||||
if shared_object is not None:
|
||||
return shared_object
|
||||
|
||||
if hasattr(cls, 'from_config'):
|
||||
arg_spec = tf_inspect.getfullargspec(cls.from_config)
|
||||
custom_objects = custom_objects or {}
|
||||
|
||||
if 'custom_objects' in arg_spec.args:
|
||||
return cls.from_config(
|
||||
deserialized_obj = cls.from_config(
|
||||
cls_config,
|
||||
custom_objects=dict(
|
||||
list(_GLOBAL_CUSTOM_OBJECTS.items()) +
|
||||
list(custom_objects.items())))
|
||||
with CustomObjectScope(custom_objects):
|
||||
return cls.from_config(cls_config)
|
||||
else:
|
||||
with CustomObjectScope(custom_objects):
|
||||
deserialized_obj = cls.from_config(cls_config)
|
||||
else:
|
||||
# Then `cls` may be a function returning a class.
|
||||
# in this case by convention `config` holds
|
||||
# the kwargs of the function.
|
||||
custom_objects = custom_objects or {}
|
||||
with CustomObjectScope(custom_objects):
|
||||
return cls(**cls_config)
|
||||
deserialized_obj = cls(**cls_config)
|
||||
|
||||
# Add object to shared objects, in case we find it referenced again.
|
||||
_shared_object_loading_scope().set(shared_object_id, deserialized_obj)
|
||||
|
||||
return deserialized_obj
|
||||
|
||||
elif isinstance(identifier, six.string_types):
|
||||
object_name = identifier
|
||||
if custom_objects and object_name in custom_objects:
|
||||
|
@ -23,6 +23,7 @@ from functools import partial
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -384,5 +385,63 @@ class SliceArraysTest(test.TestCase):
|
||||
[None, None, None])
|
||||
|
||||
|
||||
# object() alone isn't compatible with WeakKeyDictionary, which we use to
|
||||
# track shared configs.
|
||||
class MaybeSharedObject(object):
|
||||
pass
|
||||
|
||||
|
||||
class SharedObjectScopeTest(test.TestCase):
|
||||
|
||||
def test_shared_object_saving_scope_single_object_doesnt_export_id(self):
|
||||
with generic_utils.SharedObjectSavingScope() as scope:
|
||||
single_object = MaybeSharedObject()
|
||||
self.assertIsNone(scope.get_config(single_object))
|
||||
single_object_config = scope.create_config({}, single_object)
|
||||
self.assertIsNotNone(single_object_config)
|
||||
self.assertNotIn(generic_utils.SHARED_OBJECT_KEY,
|
||||
single_object_config)
|
||||
|
||||
def test_shared_object_saving_scope_shared_object_exports_id(self):
|
||||
with generic_utils.SharedObjectSavingScope() as scope:
|
||||
shared_object = MaybeSharedObject()
|
||||
self.assertIsNone(scope.get_config(shared_object))
|
||||
scope.create_config({}, shared_object)
|
||||
first_object_config = scope.get_config(shared_object)
|
||||
second_object_config = scope.get_config(shared_object)
|
||||
self.assertIn(generic_utils.SHARED_OBJECT_KEY,
|
||||
first_object_config)
|
||||
self.assertIn(generic_utils.SHARED_OBJECT_KEY,
|
||||
second_object_config)
|
||||
self.assertIs(first_object_config, second_object_config)
|
||||
|
||||
def test_shared_object_loading_scope_noop(self):
|
||||
# Test that, without a context manager scope, adding configs will do
|
||||
# nothing.
|
||||
obj_id = 1
|
||||
obj = MaybeSharedObject()
|
||||
generic_utils._shared_object_loading_scope().set(obj_id, obj)
|
||||
self.assertIsNone(generic_utils._shared_object_loading_scope().get(obj_id))
|
||||
|
||||
def test_shared_object_loading_scope_returns_shared_obj(self):
|
||||
obj_id = 1
|
||||
obj = MaybeSharedObject()
|
||||
with generic_utils.SharedObjectLoadingScope() as scope:
|
||||
scope.set(obj_id, obj)
|
||||
self.assertIs(scope.get(obj_id), obj)
|
||||
|
||||
def test_nested_shared_object_saving_scopes(self):
|
||||
my_obj = MaybeSharedObject()
|
||||
with generic_utils.SharedObjectSavingScope() as scope_1:
|
||||
scope_1.create_config({}, my_obj)
|
||||
with generic_utils.SharedObjectSavingScope() as scope_2:
|
||||
# Nesting saving scopes should return the original scope and should
|
||||
# not clear any objects we're tracking.
|
||||
self.assertIs(scope_1, scope_2)
|
||||
self.assertIsNotNone(scope_2.get_config(my_obj))
|
||||
self.assertIsNotNone(scope_1.get_config(my_obj))
|
||||
self.assertIsNone(generic_utils._shared_object_saving_scope())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -21,7 +21,6 @@ cuda_py_test(
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip",
|
||||
"no_rocm",
|
||||
],
|
||||
deps = [
|
||||
":mnist_testing_utils",
|
||||
|
@ -2118,7 +2118,6 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
|
||||
# allocations at a lower level.
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def test_functions_cleaned(self):
|
||||
self.skipTest("TODO(b/175152958): The test is leaking function definitions")
|
||||
if sys.version_info.major < 3:
|
||||
self.skipTest("Not working in Python 2")
|
||||
root = module.Module()
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/dlpack.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
@ -670,6 +671,10 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
return output;
|
||||
});
|
||||
m.def("TFE_ContextListFunctionNames", [](py::handle& ctx) {
|
||||
return tensorflow::unwrap(tensorflow::InputTFE_Context(ctx))
|
||||
->ListFunctionNames();
|
||||
});
|
||||
m.def("TFE_ContextEnableRunMetadata", [](py::handle& ctx) {
|
||||
TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx));
|
||||
});
|
||||
|
@ -25,14 +25,18 @@ from __future__ import print_function
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, List, Optional, Text
|
||||
|
||||
from tensorflow.core.util import event_pb2
|
||||
from tensorflow.python.client import session as session_lib
|
||||
from tensorflow.python.framework import meta_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import basic_session_run_hooks
|
||||
from tensorflow.python.training import monitored_session
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
from tensorflow.python.training import session_run_hook
|
||||
from tensorflow.python.training import training_util
|
||||
from tensorflow.python.training.session_run_hook import SessionRunArgs
|
||||
from tensorflow.python.training.summary_io import SummaryWriterCache
|
||||
|
||||
|
||||
@ -40,13 +44,14 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
|
||||
"""Saves checkpoints every N steps or seconds."""
|
||||
|
||||
def __init__(self,
|
||||
checkpoint_dir,
|
||||
save_secs=None,
|
||||
save_steps=None,
|
||||
saver=None,
|
||||
checkpoint_basename="model.ckpt",
|
||||
scaffold=None,
|
||||
listeners=None):
|
||||
checkpoint_dir: Text,
|
||||
save_secs: Optional[int] = None,
|
||||
save_steps: Optional[int] = None,
|
||||
saver: Optional[saver_lib.Saver] = None,
|
||||
checkpoint_basename: Text = "model.ckpt",
|
||||
scaffold: Optional[monitored_session.Scaffold] = None,
|
||||
listeners: Optional[List[
|
||||
basic_session_run_hooks.CheckpointSaverListener]] = None):
|
||||
"""Initializes a `CheckpointSaverHook`.
|
||||
|
||||
Args:
|
||||
@ -98,7 +103,7 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
|
||||
for l in self._listeners:
|
||||
l.begin()
|
||||
|
||||
def after_create_session(self, session, coord):
|
||||
def after_create_session(self, session: session_lib.Session, coord: Any):
|
||||
global_step = session.run(self._global_step_tensor)
|
||||
|
||||
# We do write graph and saver_def at the first call of before_run.
|
||||
@ -122,10 +127,11 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
|
||||
self._save(session, global_step)
|
||||
self._timer.update_last_triggered_step(global_step)
|
||||
|
||||
def before_run(self, run_context): # pylint: disable=unused-argument
|
||||
return SessionRunArgs(self._global_step_tensor)
|
||||
def before_run(self, run_context: Any): # pylint: disable=unused-argument
|
||||
return session_run_hook.SessionRunArgs(self._global_step_tensor)
|
||||
|
||||
def after_run(self, run_context, run_values):
|
||||
def after_run(self, run_context: session_run_hook.SessionRunContext,
|
||||
run_values: Any):
|
||||
global_step = run_context.session.run(self._global_step_tensor)
|
||||
if self._timer.should_trigger_for_step(global_step):
|
||||
self._timer.update_last_triggered_step(global_step)
|
||||
@ -133,7 +139,7 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
|
||||
if self._save(run_context.session, global_step):
|
||||
run_context.request_stop()
|
||||
|
||||
def end(self, session):
|
||||
def end(self, session: session_lib.Session):
|
||||
if self._save_thread:
|
||||
logging.info("Waiting for any pending checkpoints to finish.")
|
||||
self._save_thread.join()
|
||||
|
@ -19,6 +19,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from typing import Generator, Optional, Text
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
@ -70,10 +72,18 @@ def _get_custom_getter():
|
||||
|
||||
@tf_export(v1=['tpu.bfloat16_scope'])
|
||||
@tf_contextlib.contextmanager
|
||||
def bfloat16_scope(name=None):
|
||||
def bfloat16_scope(
|
||||
name: Optional[Text] = None
|
||||
) -> Generator[variable_scope.variable_scope, None, None]:
|
||||
"""Scope class for bfloat16 variables so that the model uses custom getter.
|
||||
|
||||
This enables variables to be read as bfloat16 type when using get_variable.
|
||||
|
||||
Arguments:
|
||||
name: Name to use for scope.
|
||||
|
||||
Yields:
|
||||
a variable scope.
|
||||
"""
|
||||
if name is None:
|
||||
name = ''
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from typing import Callable, Optional, Text, Union
|
||||
|
||||
from tensorflow.python.data.experimental.ops import interleave_ops
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
@ -28,13 +30,13 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import functional_ops
|
||||
|
||||
|
||||
def _TextLineDataset(filename):
|
||||
def _TextLineDataset(filename: Text) -> dataset_ops.Dataset:
|
||||
buffer_size = 8 * 1024 * 1024 # 8 MiB per file
|
||||
dataset = readers.TextLineDataset(filename, buffer_size=buffer_size)
|
||||
return dataset
|
||||
|
||||
|
||||
def _TFRecordDataset(filename):
|
||||
def _TFRecordDataset(filename: Text) -> dataset_ops.Dataset:
|
||||
buffer_size = 8 * 1024 * 1024 # 8 MiB per file
|
||||
dataset = readers.TFRecordDataset(filename, buffer_size=buffer_size)
|
||||
return dataset
|
||||
@ -47,15 +49,17 @@ _FILETYPE_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def StreamingFilesDataset(files,
|
||||
filetype=None,
|
||||
file_reader_job=None,
|
||||
worker_job=None,
|
||||
num_epochs=None,
|
||||
filename_shuffle_buffer_size=None,
|
||||
num_parallel_reads=None,
|
||||
batch_transfer_size=None,
|
||||
sloppy=None):
|
||||
def StreamingFilesDataset(
|
||||
files: Union[Text, dataset_ops.Dataset],
|
||||
filetype: Optional[Union[Text, Callable[[Text],
|
||||
dataset_ops.Dataset]]] = None,
|
||||
file_reader_job: Optional[Text] = None,
|
||||
worker_job: Optional[Text] = None,
|
||||
num_epochs: Optional[int] = None,
|
||||
filename_shuffle_buffer_size: Optional[Union[int, bool]] = None,
|
||||
num_parallel_reads: Optional[int] = None,
|
||||
batch_transfer_size: Optional[Union[int, bool]] = None,
|
||||
sloppy: bool = True) -> dataset_ops.Dataset:
|
||||
"""StreamingFilesDataset constructs a dataset to stream from workers (GCE VM).
|
||||
|
||||
Because Cloud TPUs are allocated over the network, a Cloud TPU cannot read
|
||||
@ -126,9 +130,6 @@ def StreamingFilesDataset(files,
|
||||
if batch_transfer_size is None:
|
||||
batch_transfer_size = 256
|
||||
|
||||
if sloppy is None:
|
||||
sloppy = True
|
||||
|
||||
if file_reader_job == 'coordinator':
|
||||
file_reader_device = '/job:coordinator/task:0'
|
||||
else:
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import enum
|
||||
import math
|
||||
from typing import List, Optional, Text, Tuple
|
||||
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
@ -66,7 +67,7 @@ class DeviceAssignment(object):
|
||||
`DeviceAssignment` directly.
|
||||
"""
|
||||
|
||||
def __init__(self, topology, core_assignment):
|
||||
def __init__(self, topology: Topology, core_assignment: np.ndarray):
|
||||
"""Constructs a `DeviceAssignment` object.
|
||||
|
||||
Args:
|
||||
@ -104,22 +105,22 @@ class DeviceAssignment(object):
|
||||
self._core_assignment, topology)
|
||||
|
||||
@property
|
||||
def topology(self):
|
||||
def topology(self) -> Topology:
|
||||
"""A `Topology` that describes the TPU topology."""
|
||||
return self._topology
|
||||
|
||||
@property
|
||||
def num_cores_per_replica(self):
|
||||
def num_cores_per_replica(self) -> int:
|
||||
"""The number of cores per replica."""
|
||||
return self._num_cores_per_replica
|
||||
|
||||
@property
|
||||
def num_replicas(self):
|
||||
def num_replicas(self) -> int:
|
||||
"""The number of replicas of the computation."""
|
||||
return self._num_replicas
|
||||
|
||||
@property
|
||||
def core_assignment(self):
|
||||
def core_assignment(self) -> np.ndarray:
|
||||
"""The logical to physical core mapping.
|
||||
|
||||
Returns:
|
||||
@ -129,11 +130,11 @@ class DeviceAssignment(object):
|
||||
"""
|
||||
return self._core_assignment
|
||||
|
||||
def coordinates(self, replica, logical_core):
|
||||
def coordinates(self, replica: int, logical_core: int) -> Tuple: # pylint:disable=g-bare-generic
|
||||
"""Returns the physical topology coordinates of a logical core."""
|
||||
return tuple(self.core_assignment[replica, logical_core, :])
|
||||
|
||||
def lookup_replicas(self, task_id, logical_core):
|
||||
def lookup_replicas(self, task_id: int, logical_core: int) -> List[int]:
|
||||
"""Lookup replica ids by task number and logical core.
|
||||
|
||||
Args:
|
||||
@ -153,31 +154,38 @@ class DeviceAssignment(object):
|
||||
"Can not find any replica in task: {} contains logical_core: {} ".
|
||||
format(task_id, logical_core))
|
||||
|
||||
def tpu_ordinal(self, replica=0, logical_core=0):
|
||||
def tpu_ordinal(self, replica: int = 0, logical_core: int = 0) -> int:
|
||||
"""Returns the ordinal of the TPU device assigned to a logical core."""
|
||||
coordinates = self.coordinates(replica, logical_core)
|
||||
return self._topology.tpu_device_ordinal_at_coordinates(coordinates)
|
||||
|
||||
def host_device(self, replica=0, logical_core=0, job=None):
|
||||
def host_device(self,
|
||||
replica: int = 0,
|
||||
logical_core: int = 0,
|
||||
job: Optional[Text] = None) -> Text:
|
||||
"""Returns the CPU device attached to a logical core."""
|
||||
coordinates = self.coordinates(replica, logical_core)
|
||||
return self._topology.cpu_device_name_at_coordinates(coordinates, job=job)
|
||||
|
||||
def tpu_device(self, replica=0, logical_core=0, job=None):
|
||||
def tpu_device(self,
|
||||
replica: int = 0,
|
||||
logical_core: int = 0,
|
||||
job: Optional[Text] = None) -> Text:
|
||||
"""Returns the name of the TPU device assigned to a logical core."""
|
||||
coordinates = self.coordinates(replica, logical_core)
|
||||
return self._topology.tpu_device_name_at_coordinates(coordinates, job=job)
|
||||
|
||||
@staticmethod
|
||||
def build(topology,
|
||||
computation_shape=None,
|
||||
computation_stride=None,
|
||||
num_replicas=1):
|
||||
def build(topology: Topology,
|
||||
computation_shape: Optional[np.ndarray] = None,
|
||||
computation_stride: Optional[np.ndarray] = None,
|
||||
num_replicas: int = 1) -> "DeviceAssignment":
|
||||
return device_assignment(topology, computation_shape, computation_stride,
|
||||
num_replicas)
|
||||
|
||||
|
||||
def _open_ring_2d(x_size, y_size, z_coord):
|
||||
def _open_ring_2d(x_size: int, y_size: int,
|
||||
z_coord: int) -> List[Tuple[int, int, int]]:
|
||||
"""Ring-order of a X by Y mesh, with a fixed Z coordinate.
|
||||
|
||||
For example, in a 4x4 mesh, this returns the following order.
|
||||
@ -213,7 +221,8 @@ def _open_ring_2d(x_size, y_size, z_coord):
|
||||
return ret
|
||||
|
||||
|
||||
def _ring_3d(x_size, y_size, z_size):
|
||||
def _ring_3d(x_size: int, y_size: int,
|
||||
z_size: int) -> List[Tuple[int, int, int]]:
|
||||
"""Ring-order of a X by Y by Z mesh.
|
||||
|
||||
Constructs the 3d ring from 2d rings that are stacked in the Z dimension and
|
||||
@ -325,11 +334,13 @@ class DeviceOrderMode(enum.IntEnum):
|
||||
MESH = 2
|
||||
|
||||
|
||||
def device_assignment(topology,
|
||||
computation_shape=None,
|
||||
computation_stride=None,
|
||||
num_replicas=1,
|
||||
device_order_mode=DeviceOrderMode.AUTO):
|
||||
def device_assignment(
|
||||
topology: Topology,
|
||||
computation_shape: Optional[np.ndarray] = None,
|
||||
computation_stride: Optional[np.ndarray] = None,
|
||||
num_replicas: int = 1,
|
||||
device_order_mode: DeviceOrderMode = DeviceOrderMode.AUTO
|
||||
) -> DeviceAssignment:
|
||||
"""Computes a device_assignment of a computation across a TPU topology.
|
||||
|
||||
Attempts to choose a compact grid of cores for locality.
|
||||
@ -341,11 +352,12 @@ def device_assignment(topology,
|
||||
optimal packing.
|
||||
|
||||
Args:
|
||||
topology: A `Topology` object that describes the TPU cluster topology.
|
||||
To obtain a TPU topology, evaluate the `Tensor` returned by
|
||||
topology: A `Topology` object that describes the TPU cluster topology. To
|
||||
obtain a TPU topology, evaluate the `Tensor` returned by
|
||||
`initialize_system` using `Session.run`. Either a serialized
|
||||
`TopologyProto` or a `Topology` object may be passed. Note: you must
|
||||
evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor` here.
|
||||
evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor`
|
||||
here.
|
||||
computation_shape: A rank 1 int32 numpy array with size equal to the
|
||||
topology rank, describing the shape of the computation's block of cores.
|
||||
If None, the `computation_shape` is `[1] * topology_rank`.
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import functools
|
||||
from typing import Any, Dict, Callable, List, Optional, Text, Tuple
|
||||
from typing import Any, Dict, Callable, Iterable, List, Optional, Text, Tuple, Union
|
||||
|
||||
from absl import logging
|
||||
|
||||
@ -229,7 +229,6 @@ class TPUEmbedding(tracking.AutoTrackable):
|
||||
model = model_fn(...)
|
||||
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
|
||||
feature_config=feature_config,
|
||||
batch_size=1024,
|
||||
optimizer=tf.tpu.experimental.embedding.SGD(0.1))
|
||||
checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
|
||||
checkpoint.restore(...)
|
||||
@ -244,7 +243,7 @@ class TPUEmbedding(tracking.AutoTrackable):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_config: Any,
|
||||
feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable], # pylint:disable=g-bare-generic
|
||||
optimizer: Optional[tpu_embedding_v2_utils._Optimizer], # pylint:disable=protected-access
|
||||
pipeline_execution_with_tensor_core: bool = False):
|
||||
"""Creates the TPUEmbedding mid level API object.
|
||||
|
@ -19,15 +19,23 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from typing import Any, Callable, Iterable, List, Optional, Union
|
||||
|
||||
from tensorflow.python.compiler.xla import xla
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.tpu import tensor_tracer
|
||||
from tensorflow.python.tpu import tpu_feed
|
||||
from tensorflow.python.tpu import tpu_function
|
||||
from tensorflow.python.types import core as core_types
|
||||
|
||||
|
||||
def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
|
||||
def while_loop(condition: Callable[..., Any],
|
||||
body: Callable[..., Any],
|
||||
inputs: Optional[List[Any]] = None,
|
||||
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
|
||||
name: Any = None) -> Any:
|
||||
"""Builds a training loop for TPUs.
|
||||
|
||||
The set of loop-carried tensors corresponds to `inputs`. Both
|
||||
@ -41,10 +49,10 @@ def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
|
||||
Args:
|
||||
condition: a Python function that builds the loop condition.
|
||||
body: a Python function that builds the loop body.
|
||||
inputs: a list of initial values passed into the training loop, or
|
||||
None (equivalent to an empty list).
|
||||
infeed_queue: if not None, the infeed queue from which to append a tuple
|
||||
of arguments as inputs to condition.
|
||||
inputs: a list of initial values passed into the training loop, or None
|
||||
(equivalent to an empty list).
|
||||
infeed_queue: if not None, the infeed queue from which to append a tuple of
|
||||
arguments as inputs to condition.
|
||||
name: (Deprecated) Does nothing.
|
||||
|
||||
Returns:
|
||||
@ -178,7 +186,12 @@ def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
|
||||
condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1)
|
||||
|
||||
|
||||
def repeat(n, body, inputs=None, infeed_queue=None, name=None):
|
||||
def repeat(
|
||||
n: int,
|
||||
body: Callable[..., Union[core_types.TensorLike, Iterable]], # pylint:disable=g-bare-generic
|
||||
inputs: Optional[List[core_types.TensorLike]] = None,
|
||||
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
|
||||
name: Any = None) -> List[core_types.TensorLike]:
|
||||
"""Builds a training loop that executes a fixed number of iterations.
|
||||
|
||||
The set of loop-carried tensors correspond to `inputs`.
|
||||
@ -188,11 +201,12 @@ def repeat(n, body, inputs=None, infeed_queue=None, name=None):
|
||||
Args:
|
||||
n: the number of loop iterations
|
||||
body: a Python function that builds the loop body.
|
||||
inputs: a list of initial values passed into the training loop or
|
||||
None (equivalent to an empty list).
|
||||
infeed_queue: if not None, the infeed queue from which to append a tuple
|
||||
of arguments as inputs to condition.
|
||||
inputs: a list of initial values passed into the training loop or None
|
||||
(equivalent to an empty list).
|
||||
infeed_queue: if not None, the infeed queue from which to append a tuple of
|
||||
arguments as inputs to condition.
|
||||
name: (Deprecated) Does nothing.
|
||||
|
||||
Returns:
|
||||
The final values of the loop-carried tensors.
|
||||
Raises:
|
||||
|
@ -138,7 +138,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'depth\', \'name\', \'separator\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'depth\', \'name\', \'separator\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'_X_\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "adapt"
|
||||
|
@ -138,7 +138,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'depth\', \'name\', \'separator\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'depth\', \'name\', \'separator\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'_X_\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "adapt"
|
||||
|
@ -261,6 +261,7 @@ function install_macos_pip_deps {
|
||||
${PIP_CMD} install $USER_FLAG 'grpcio ~= 1.34.0'
|
||||
${PIP_CMD} install $USER_FLAG 'portpicker ~= 1.3.1'
|
||||
${PIP_CMD} install $USER_FLAG 'scipy ~= 1.5.2'
|
||||
${PIP_CMD} install $USER_FLAG --upgrade certifi
|
||||
|
||||
# LINT.ThenChange(:linux_pip_installations_orig)
|
||||
# LINT.ThenChange(:linux_pip_installations)
|
||||
|
@ -46,7 +46,7 @@ py_test(
|
||||
tags = [
|
||||
"no_oss_py2",
|
||||
"no_pip",
|
||||
"no_rocm",
|
||||
"no_rocm", # No need to rerun this test for ROCm config.
|
||||
"no_windows", # numpy prints differently on windows.
|
||||
"noasan",
|
||||
"nomsan",
|
||||
|
Loading…
Reference in New Issue
Block a user