sync to master

Mihai Maruseac 2021-01-21 15:40:33 -08:00
parent ed158956df
commit 81adaff8a6
52 changed files with 916 additions and 313 deletions

View File

@ -137,6 +137,10 @@ This release contains contributions from many people at Google, as well as:
<INSERT>, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
# Release 2.4.1
* This release removes the AVX2 requirement from TF 2.4.0.
# Release 2.3.2
## Bug Fixes and Other Changes

View File

@ -185,6 +185,9 @@ class ImmediateExecutionContext : public AbstractContext {
virtual std::vector<std::string> GetLoggedOpsTestonly() { return {}; }
// Get a list of the names of functions that have been registered.
virtual std::vector<string> ListFunctionNames() = 0;
//===--------------------------------------------------------------------===//
// Distributed runtime related functions.
//===--------------------------------------------------------------------===//

View File

@ -127,6 +127,7 @@ add_mlir_library(MhloLhloToLinalg
LINK_LIBS PUBLIC
MhloDialect
MLIRComplex
MLIRIR
MLIRPass
)

View File

@ -372,3 +372,112 @@ func @testNoDilatedConvWhenGivenInputIsNonFloatType(%arg0: tensor<1x128x128x3xi3
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BatchToSpaceND"
// CHECK-NEXT: return [[RESULT]]
}
func @testDilatedConv1DExpandH(%arg0: tensor<1x128x3xf32>, %arg1: tensor<1x5x3x8xf32>) -> tensor<1x128x8xf32> {
%cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
%cst_0 = "tf.Const"() {value = dense<-3> : tensor<i32>} : () -> tensor<i32>
%cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
%cst_2 = "tf.Const"() {value = dense<4> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst_1, %cst_2) : (tensor<1x128x3xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<2x68x3xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<2x68x3xf32>, tensor<i32>) -> tensor<2x1x68x3xf32>
%2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<2x1x68x3xf32>, tensor<1x5x3x8xf32>) -> tensor<2x1x64x8xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [-3]} : (tensor<2x1x64x8xf32>) -> tensor<2x64x8xf32>
%4 = "tf.BatchToSpaceND"(%3, %cst_1, %cst) : (tensor<2x64x8xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x128x8xf32>
return %4 : tensor<1x128x8xf32>
// CHECK-LABEL: testDilatedConv1DExpandH
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x3xf32>, [[FILTER:%.*]]: tensor<1x5x3x8xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<-3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x3xf32>, tensor<i32>) -> tensor<1x1x128x3xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 1, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x1x128x3xf32>, tensor<1x5x3x8xf32>) -> tensor<1x1x128x8xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-3]} : (tensor<1x1x128x8xf32>) -> tensor<1x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
}
func @testDilatedConv1DExpandHWithBiasAdd(%arg0: tensor<1x128x3xf32>, %arg1: tensor<1x5x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x8xf32> {
%cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
%cst_0 = "tf.Const"() {value = dense<-3> : tensor<i32>} : () -> tensor<i32>
%cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
%cst_2 = "tf.Const"() {value = dense<4> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst_1, %cst_2) : (tensor<1x128x3xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<2x68x3xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<2x68x3xf32>, tensor<i32>) -> tensor<2x1x68x3xf32>
%2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<2x1x68x3xf32>, tensor<1x5x3x8xf32>) -> tensor<2x1x64x8xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [-3]} : (tensor<2x1x64x8xf32>) -> tensor<2x64x8xf32>
%4 = "tf.BatchToSpaceND"(%3, %cst_1, %cst) : (tensor<2x64x8xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x128x8xf32>
%5 = "tf.BiasAdd"(%4, %arg2) : (tensor<1x128x8xf32>, tensor<8xf32>) -> tensor<1x128x8xf32>
return %5 : tensor<1x128x8xf32>
// CHECK-LABEL: testDilatedConv1DExpandHWithBiasAdd
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x3xf32>, [[FILTER:%.*]]: tensor<1x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<-3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x3xf32>, tensor<i32>) -> tensor<1x1x128x3xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 1, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x1x128x3xf32>, tensor<1x5x3x8xf32>) -> tensor<1x1x128x8xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-3]} : (tensor<1x1x128x8xf32>) -> tensor<1x128x8xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x8xf32>, tensor<8xf32>) -> tensor<1x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
}
func @testDilatedConv1DExpandW(%arg0: tensor<1x128x3xf32>, %arg1: tensor<5x1x3x8xf32>) -> tensor<1x128x8xf32> {
%cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
%cst_0 = "tf.Const"() {value = dense<-2> : tensor<i32>} : () -> tensor<i32>
%cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
%cst_2 = "tf.Const"() {value = dense<4> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst_1, %cst_2) : (tensor<1x128x3xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<2x68x3xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<2x68x3xf32>, tensor<i32>) -> tensor<2x68x1x3xf32>
%2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<2x68x1x3xf32>, tensor<5x1x3x8xf32>) -> tensor<2x64x1x8xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [-2]} : (tensor<2x64x1x8xf32>) -> tensor<2x64x8xf32>
%4 = "tf.BatchToSpaceND"(%3, %cst_1, %cst) : (tensor<2x64x8xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x128x8xf32>
return %4 : tensor<1x128x8xf32>
// CHECK-LABEL: testDilatedConv1DExpandW
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x3xf32>, [[FILTER:%.*]]: tensor<5x1x3x8xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<-2> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x3xf32>, tensor<i32>) -> tensor<1x128x1x3xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x1x3xf32>, tensor<5x1x3x8xf32>) -> tensor<1x128x1x8xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-2]} : (tensor<1x128x1x8xf32>) -> tensor<1x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
}
func @testDilatedConv1DExpandWWithBiasAdd(%arg0: tensor<1x128x3xf32>, %arg1: tensor<5x1x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x8xf32> {
%cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
%cst_0 = "tf.Const"() {value = dense<-2> : tensor<i32>} : () -> tensor<i32>
%cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
%cst_2 = "tf.Const"() {value = dense<4> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst_1, %cst_2) : (tensor<1x128x3xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<2x68x3xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<2x68x3xf32>, tensor<i32>) -> tensor<2x68x1x3xf32>
%2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<2x68x1x3xf32>, tensor<5x1x3x8xf32>) -> tensor<2x64x1x8xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [-2]} : (tensor<2x64x1x8xf32>) -> tensor<2x64x8xf32>
%4 = "tf.BatchToSpaceND"(%3, %cst_1, %cst) : (tensor<2x64x8xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x128x8xf32>
%5 = "tf.BiasAdd"(%4, %arg2) : (tensor<1x128x8xf32>, tensor<8xf32>) -> tensor<1x128x8xf32>
return %5 : tensor<1x128x8xf32>
// CHECK-LABEL: testDilatedConv1DExpandWWithBiasAdd
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x3xf32>, [[FILTER:%.*]]: tensor<5x1x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<-2> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x3xf32>, tensor<i32>) -> tensor<1x128x1x3xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x1x3xf32>, tensor<5x1x3x8xf32>) -> tensor<1x128x1x8xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-2]} : (tensor<1x128x1x8xf32>) -> tensor<1x128x8xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x8xf32>, tensor<8xf32>) -> tensor<1x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
}
func @testDilatedConv1DWithMixedPostiveAndNegativeAxis(%arg0: tensor<1x128x3xf32>, %arg1: tensor<1x5x3x8xf32>) -> tensor<1x128x8xf32> {
%cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
%cst_0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
%cst_2 = "tf.Const"() {value = dense<4> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst_1, %cst_2) : (tensor<1x128x3xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<2x68x3xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<2x68x3xf32>, tensor<i32>) -> tensor<2x1x68x3xf32>
%2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<2x1x68x3xf32>, tensor<1x5x3x8xf32>) -> tensor<2x1x64x8xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [-3]} : (tensor<2x1x64x8xf32>) -> tensor<2x64x8xf32>
%4 = "tf.BatchToSpaceND"(%3, %cst_1, %cst) : (tensor<2x64x8xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x128x8xf32>
return %4 : tensor<1x128x8xf32>
// CHECK-LABEL: testDilatedConv1DWithMixedPostiveAndNegativeAxis
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x3xf32>, [[FILTER:%.*]]: tensor<1x5x3x8xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x3xf32>, tensor<i32>) -> tensor<1x1x128x3xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 1, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x1x128x3xf32>, tensor<1x5x3x8xf32>) -> tensor<1x1x128x8xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [-3]} : (tensor<1x1x128x8xf32>) -> tensor<1x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x8xf32>
}

View File

@ -70,7 +70,7 @@ class ConvertTFDilatedConvOp : public OpRewritePattern<Conv2dOpTy> {
// Extract the dilation factor from `block_shape` and pack it in an ArrayAttr.
llvm::Optional<ArrayAttr> ExtractDilationsAttrFromBlockShape(
Value stb_block_shape, Value bts_block_shape,
Value stb_block_shape, Value bts_block_shape, int64_t expand_axis,
PatternRewriter& rewriter) const;
public:
@ -111,7 +111,7 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
TF::ExpandDimsOp expand_op;
TF::SqueezeOp squeeze_op;
int64_t expand_axis;
int64_t expand_axis = -1;
// Expand + Squeeze op.
if (llvm::isa<TF::ExpandDimsOp>(prev_op)) {
if (!llvm::isa<TF::SqueezeOp>(next_op)) {
@ -127,13 +127,26 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
expand_axis =
(*const_op.value().cast<DenseElementsAttr>().getIntValues().begin())
.getSExtValue();
// Canonicalize axis. Some TF python functions, such as
// `tf.nn.convolution`, use negative axis.
if (expand_axis < 0) {
// Always expand 3D input to 4D input.
expand_axis += 4;
}
} else {
return failure();
}
// Make sure that the `squeeze_dims` is equal to `expand_axis`.
auto squeeze_dims = squeeze_op.squeeze_dims();
if (squeeze_dims.size() != 1 ||
squeeze_dims[0].cast<IntegerAttr>().getInt() != expand_axis) {
if (squeeze_dims.size() != 1) {
return failure();
}
int64_t squeeze_axis = squeeze_dims[0].cast<IntegerAttr>().getInt();
if (squeeze_axis < 0) {
// Always squeeze 4D input to 3D input.
squeeze_axis += 4;
}
if (squeeze_axis != expand_axis) {
return failure();
}
@ -183,7 +196,7 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
}
llvm::Optional<ArrayAttr> dilations_attr = ExtractDilationsAttrFromBlockShape(
stb_op.block_shape(), bts_op.block_shape(), rewriter);
stb_op.block_shape(), bts_op.block_shape(), expand_axis, rewriter);
if (!dilations_attr.hasValue()) return failure();
if (expand_op) {
@ -259,13 +272,24 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
auto expand_result_type = RankedTensorType::get(
expand_shape, getElementTypeOrSelf(stb_op.input()));
expand_op.getResult().setType(expand_result_type);
op.getResult().setType(expand_result_type);
// Update the conv op's output shape.
auto bts_output_shape =
bts_op.output().getType().cast<ShapedType>().getShape();
SmallVector<int64_t, 4> conv_result_shape(bts_output_shape.begin(),
bts_output_shape.end());
conv_result_shape.insert(conv_result_shape.begin() + expand_axis, 1);
auto conv_result_type = RankedTensorType::get(
conv_result_shape, getElementTypeOrSelf(stb_op.input()));
op.getResult().setType(conv_result_type);
squeeze_op.getResult().setType(bts_op.output().getType());
// Connect `biasadd_op` with the output of `squeeze_op`.
biasadd_op.setOperand(0, squeeze_op.output());
biasadd_op.output().setType(squeeze_op.output().getType());
if (biasadd_op) {
biasadd_op.setOperand(0, squeeze_op.output());
biasadd_op.output().setType(squeeze_op.output().getType());
}
} else {
if (biasadd_op) biasadd_op.setOperand(0, op.output());
op.setOperand(0, stb_op.input());
@ -283,7 +307,7 @@ LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
template <typename Conv2dOpTy>
llvm::Optional<ArrayAttr>
ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
Value stb_block_shape, Value bts_block_shape,
Value stb_block_shape, Value bts_block_shape, int64_t expand_axis,
PatternRewriter& rewriter) const {
ElementsAttr stb_bs_attr, bts_bs_attr;
if (!matchPattern(stb_block_shape, m_Constant(&stb_bs_attr)) ||
@ -297,12 +321,31 @@ ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
if (stb_bs_attr.getValue({i}) != bts_bs_attr.getValue({i})) return {};
}
int dilation_h_factor = -1, dilation_w_factor = -1;
// Set dilation factor.
if (stb_bs_attr.getNumElements() < 2) return {};
int dilation_h_factor =
stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
int dilation_w_factor =
stb_bs_attr.getValue({1}).cast<IntegerAttr>().getInt();
if (stb_bs_attr.getNumElements() >= 2) {
dilation_h_factor = stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
dilation_w_factor = stb_bs_attr.getValue({1}).cast<IntegerAttr>().getInt();
} else if (stb_bs_attr.getNumElements() == 1) {
// For 1D convolution, `tf.nn.convolution` expands the input to NHWC format
// after `SpaceToBatchND`. Therefore, the `block_shape` of `stb_op` holds a
// single dilation factor for the one real spatial dim, and the dilation
// factor of the inserted dim is set to 1.
if (expand_axis == 1) {
// NWC -> NHWC
dilation_h_factor = 1;
dilation_w_factor =
stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
} else if (expand_axis == 2) {
// NHC -> NHWC
dilation_h_factor =
stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
dilation_w_factor = 1;
}
}
if (dilation_h_factor == -1 || dilation_w_factor == -1) {
return {};
}
return rewriter.getI64ArrayAttr({1, dilation_h_factor, dilation_w_factor, 1});
}
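For context, a minimal Python sketch of the graph this converter matches (assuming standard `tf.compat.v1` APIs; names and shapes mirror the tests above, not this pass's API): with a dilation rate greater than one, `tf.nn.convolution` lowers a 1-D convolution through `SpaceToBatchND`/`ExpandDims`/`Conv2D`/`Squeeze`/`BatchToSpaceND`, and this pass folds that pattern back into a single `Conv2D` carrying a `dilations` attribute.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
# NWC input and a width-5 filter, as in testDilatedConv1DExpandH above.
x = tf.placeholder(tf.float32, [1, 128, 3])
w = tf.constant(0.1, shape=[5, 3, 8])
# dilation_rate > 1 triggers the SpaceToBatchND/BatchToSpaceND lowering
# that ConvertTFDilatedConvOp later collapses into a dilated Conv2D.
y = tf.nn.convolution(x, w, padding="SAME", dilation_rate=[2])
print(sorted({op.type for op in tf.get_default_graph().get_operations()}))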

View File

@ -1049,7 +1049,6 @@ tf_xla_py_test(
shard_count = 5,
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"no_rocm",
"optonly",
],
deps = [

View File

@ -609,7 +609,6 @@ xla_test(
name = "logdet_test",
srcs = ["logdet_test.cc"],
tags = [
"no_rocm",
"optonly",
],
deps = [

View File

@ -1787,7 +1787,7 @@ cc_library(
tf_cc_test(
name = "buffer_comparator_test",
srcs = if_cuda_is_configured(["buffer_comparator_test.cc"]),
tags = ["no_rocm"] + tf_cuda_tests_tags(),
tags = tf_cuda_tests_tags(),
deps = [
"//tensorflow/core:test_main",
"//tensorflow/compiler/xla:shape_util",

View File

@ -65,6 +65,12 @@ class ReductionDegenerateDimRemoverVisitor : public DfsHloRewriteVisitor {
}
}
if (updated_reduced_dimensions.empty()) {
std::unique_ptr<HloInstruction> reshape =
HloInstruction::CreateBitcast(reduce_shape, reduced_op);
return ReplaceWithNewInstruction(instr, std::move(reshape));
}
HloInstruction *input_reshape = instr->parent()->AddInstruction(
HloInstruction::CreateBitcast(canonical_input_shape, reduced_op));
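The new early exit above is valid because a reduction over only degenerate (size-1) dimensions removes axes without combining any elements; it is a pure relayout of the buffer, so a bitcast suffices. A quick NumPy sketch of the shape identity (illustrative only, not XLA code):
import numpy as np
x = np.random.rand(1, 3, 1, 4, 1, 5, 1).astype(np.float32)
# Reducing over the size-1 dims {0, 2, 4} sums single elements only,
# so the result is exactly a reshape of the input buffer.
reduced = x.sum(axis=(0, 2, 4))
assert reduced.shape == (3, 4, 5, 1)
assert np.array_equal(reduced, x.reshape(3, 4, 5, 1))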

View File

@ -177,7 +177,7 @@ tf_cc_test(
srcs = [
"tree_reduction_rewriter_test.cc",
],
tags = tf_cuda_tests_tags() + ["no_rocm"],
tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:debug_options_flags",
@ -258,7 +258,7 @@ tf_cc_test(
srcs = [
"parallel_reduction_test.cc",
],
tags = tf_cuda_tests_tags() + ["no_rocm"],
tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:gpu_plugin",
@ -297,7 +297,7 @@ tf_cc_test(
srcs = [
"gpu_copy_alone_test.cc",
],
tags = tf_cuda_tests_tags() + ["no_rocm"],
tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:hlo",
@ -521,9 +521,7 @@ tf_cc_test(
srcs = [
"sorting_test.cc",
],
tags = tf_cuda_tests_tags() + [
"no_rocm",
],
tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:debug_options_flags",

View File

@ -69,6 +69,38 @@ ENTRY main {
)");
}
TEST_F(ReductionDegenerateDimRemoverTest, DegenerateWithEmptyDimension) {
const char* hlo_text = R"(
HloModule ReduceWithDegenerateDimensions
add {
accum = f32[] parameter(0)
op = f32[] parameter(1)
ROOT out = f32[] add(accum, op)
}
ENTRY main {
input = f32[1,3,1,4,1,5,1] parameter(0)
zero = f32[] constant(0)
ROOT out = f32[3,4,5,1] reduce(input, zero), dimensions={0,2,4}, to_apply=add
}
)";
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
// A copy instruction is added after the bitcast by the copy-insertion pass,
// so we check the entire HLO module to verify there is no reduce instruction
// in this case.
MatchOptimizedHloWithShapes(hlo_text,
R"(
// CHECK: ENTRY %main (input: f32[1,3,1,4,1,5,1]) -> f32[3,4,5,1] {
// CHECK: %input = f32[1,3,1,4,1,5,1]{6,5,4,3,2,1,0} parameter(0)
// CHECK: %bitcast{{.+}} = f32[3,4,5,1]{3,2,1,0} bitcast(f32[1,3,1,4,1,5,1]{6,5,4,3,2,1,0} %input)
// CHECK: ROOT %copy{{.+}} = f32[3,4,5,1]{3,2,1,0} copy(f32[3,4,5,1]{3,2,1,0} %bitcast{{.+}})
)");
}
} // namespace
} // namespace gpu
} // namespace xla

View File

@ -1159,7 +1159,6 @@ xla_test(
],
shard_count = 50,
tags = [
"no_rocm",
"optonly",
],
deps = CONVOLUTION_TEST_DEPS + [
@ -1212,9 +1211,6 @@ xla_test(
backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]},
backends = ["gpu"],
shard_count = 25,
tags = [
"no_rocm",
],
deps = CONVOLUTION_TEST_DEPS + [
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@ -1228,9 +1224,6 @@ xla_test(
backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]},
backends = ["gpu"],
shard_count = 25,
tags = [
"no_rocm",
],
deps = CONVOLUTION_TEST_DEPS + [
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",

View File

@ -760,6 +760,10 @@ const FunctionDef* EagerContext::GetFunctionDef(const string& function_name) {
return func_lib_def_.Find(function_name);
}
std::vector<string> EagerContext::ListFunctionNames() {
return func_lib_def_.ListFunctionNames();
}
Status EagerContext::RemoveFunction(const string& func) {
bool is_last_ref = false;
{

View File

@ -226,6 +226,8 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
const FunctionDef* GetFunctionDef(const string& function_name);
std::vector<string> ListFunctionNames() override;
Status RemoveFunction(const string& func) override;
// Wait for pending nodes to be finished in local executors (including context

View File

@ -1867,13 +1867,10 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
continue;
}
// NOTE: We can only fuse BatchNorm into Conv2D nodes. In theory we can do
// it for MatMul as well, but in practice this pattern does not appear in
// real TensorFlow graphs.
// TODO(penporn):
// Remove this once TF-MKL supports _FusedConv2D with these operations.
#ifndef INTEL_MKL
// Remap Conv2D+Squeeze+BiasAdd into the _FusedConv2D+Squeeze.
ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias;
if (allow_non_differentiable_rewrites &&
@ -1884,6 +1881,9 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
continue;
}
// TODO(intel-tf):
// Remove this once TF-MKL supports _FusedConv2D with these operations.
#ifndef INTEL_MKL
// Remap Conv2D+FusedBatchNorm into the _FusedConv2D;
ContractionWithBatchNorm contract_with_batch_norm;
if (allow_non_differentiable_rewrites &&

View File

@ -932,6 +932,7 @@ TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) {
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
}
#endif // !INTEL_MKL
TEST_F(RemapperTest, FuseConv2DWithSqueezeAndBias) {
using ops::Placeholder;
@ -1003,7 +1004,6 @@ TEST_F(RemapperTest, FuseConv2DWithSqueezeAndBias) {
ASSERT_EQ(tensors.size(), 1);
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
#endif // !INTEL_MKL
} // namespace grappler
} // namespace tensorflow

View File

@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/work_sharder.h"
@ -211,7 +210,6 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
tf_shared_lock l(mu_);
return input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
out_tensors, end_of_sequence);
}
@ -225,7 +223,6 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
mutex_lock l(mu_);
DCHECK(input_impl_ != nullptr);
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
return Status::OK();
@ -233,7 +230,6 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
return Status::OK();
}
@ -249,8 +245,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
return params;
}
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_;
};
const DatasetBase* const input_;
@ -351,7 +346,6 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
auto max_parallelism = dataset()->max_intra_op_parallelism_;
params.runner =
RunnerWithMaxParallelism(*ctx->runner(), max_parallelism);
tf_shared_lock l(mu_);
return input_impl_->GetNext(IteratorContext{std::move(params)},
out_tensors, end_of_sequence);
}
@ -365,7 +359,6 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
mutex_lock l(mu_);
DCHECK(input_impl_ != nullptr);
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
return Status::OK();
@ -373,14 +366,12 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
return Status::OK();
}
private:
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_;
};
const DatasetBase* const input_;
@ -481,7 +472,6 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
pool->Schedule(std::move(c));
};
params.runner_threadpool_size = dataset()->num_threads_;
tf_shared_lock l(mu_);
return input_impl_->GetNext(IteratorContext{std::move(params)},
out_tensors, end_of_sequence);
}
@ -495,7 +485,6 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
mutex_lock l(mu_);
DCHECK(input_impl_ != nullptr);
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
return Status::OK();
@ -503,14 +492,12 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
return Status::OK();
}
private:
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_;
};
const DatasetBase* const input_;

View File

@ -10,148 +10,52 @@ just a few minutes. All you need is a TensorFlow model [converted to TensorFlow
Lite](../convert/). (If you don't have a model converted yet, you can experiment
using the model provided with the example linked below.)
## Install just the TensorFlow Lite interpreter
## About the TensorFlow Lite runtime package
To quickly run TensorFlow Lite models with Python, you can install just the
TensorFlow Lite interpreter, instead of all TensorFlow packages.
To quickly start executing TensorFlow Lite models with Python, you can install
just the TensorFlow Lite interpreter, instead of all TensorFlow packages. We
call this simplified Python package `tflite_runtime`.
This interpreter-only package is a fraction the size of the full TensorFlow
The `tflite_runtime` package is a fraction the size of the full `tensorflow`
package and includes the bare minimum code required to run inferences with
TensorFlow Lite—it includes only the
[`tf.lite.Interpreter`](https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter)
TensorFlow Lite—primarily the
[`Interpreter`](https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter)
Python class. This small package is ideal when all you want to do is execute
`.tflite` models and avoid wasting disk space with the large TensorFlow library.
Note: If you need access to other Python APIs, such as the [TensorFlow Lite
Converter](../convert/python_api.md), you must install the [full TensorFlow
package](https://www.tensorflow.org/install/).
Note: If you need access to other Python APIs, such as the
[TensorFlow Lite Converter](../convert/), you must install the
[full TensorFlow package](https://www.tensorflow.org/install/).
To install, run `pip3 install` and pass it the appropriate Python wheel URL from
the following table.
## Install TensorFlow Lite for Python
For example, if you have a Raspberry Pi that's running Raspberry Pi OS 10 (which
has Python 3.7), install the Python wheel as follows:
To install the TensorFlow Lite runtime package, run this command:
<pre class="devsite-terminal devsite-click-to-copy">
pip3 install --extra-index-url https://google-coral.github.io/py-repo/ tflite_runtime
</pre>
If you're on a Raspberry Pi, this command might fail due to a known issue with
the `extra-index-url` option
([#4011](https://github.com/raspberrypi/linux/issues/4011)). So we suggest you
specify one of the
[`tflite_runtime` wheels](https://github.com/google-coral/pycoral/releases/)
that matches your system. For example, if you're running Raspberry Pi OS 10
(which has Python 3.7), instead use this command:
<pre class="devsite-terminal devsite-click-to-copy">
pip3 install https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp37-cp37m-linux_armv7l.whl
</pre>
<table>
<tr><th>Platform</th><th>Python</th><th>URL</th></tr>
<tr>
<td style="white-space:nowrap" rowspan="4">Linux (ARM 32)</td>
<td style="white-space:nowrap">3.5</td>
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp35-cp35m-linux_armv7l.whl</td>
</tr>
<tr>
<!-- ARM 32 -->
<td style="white-space:nowrap">3.6</td>
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp36-cp36m-linux_armv7l.whl</td>
</tr>
<tr>
<!-- ARM 32 -->
<td style="white-space:nowrap">3.7</td>
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp37-cp37m-linux_armv7l.whl</td>
</tr>
<tr>
<!-- ARM 32 -->
<td style="white-space:nowrap">3.8</td>
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp38-cp38-linux_armv7l.whl</td>
</tr>
<tr>
<td style="white-space:nowrap" rowspan="4">Linux (ARM 64)</td>
<td style="white-space:nowrap">3.5</td>
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp35-cp35m-linux_aarch64.whl</td>
</tr>
<tr>
<!-- ARM 64 -->
<td style="white-space:nowrap">3.6</td>
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp36-cp36m-linux_aarch64.whl</td>
</tr>
<tr>
<!-- ARM 64 -->
<td style="white-space:nowrap">3.7</td>
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp37-cp37m-linux_aarch64.whl</td>
</tr>
<tr>
<!-- ARM 64 -->
<td style="white-space:nowrap">3.8</td>
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp38-cp38-linux_aarch64.whl</td>
</tr>
<tr>
<td style="white-space:nowrap" rowspan="4">Linux (x86-64)</td>
<td style="white-space:nowrap">3.5</td>
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp35-cp35m-linux_x86_64.whl</td>
</tr>
<tr>
<!-- x86-64 -->
<td style="white-space:nowrap">3.6</td>
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp36-cp36m-linux_x86_64.whl</td>
</tr>
<tr>
<!-- x86-64 -->
<td style="white-space:nowrap">3.7</td>
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp37-cp37m-linux_x86_64.whl</td>
</tr>
<tr>
<!-- x86-64 -->
<td style="white-space:nowrap">3.8</td>
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp38-cp38-linux_x86_64.whl</td>
</tr>
<tr>
<td style="white-space:nowrap" rowspan="4">macOS 10.15</td>
<td style="white-space:nowrap">3.5</td>
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp35-cp35m-macosx_10_15_x86_64.whl</td>
</tr>
<tr>
<!-- Mac -->
<td style="white-space:nowrap">3.6</td>
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp36-cp36m-macosx_10_15_x86_64.whl</td>
</tr>
<tr>
<!-- Mac -->
<td style="white-space:nowrap">3.7</td>
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp37-cp37m-macosx_10_15_x86_64.whl</td>
</tr>
<tr>
<!-- Mac -->
<td style="white-space:nowrap">3.8</td>
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp38-cp38-macosx_10_15_x86_64.whl</td>
</tr>
<tr>
<td style="white-space:nowrap" rowspan="4">Windows 10</td>
<td style="white-space:nowrap">3.5</td>
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp35-cp35m-win_amd64.whl</td>
</tr>
<tr>
<!-- Win -->
<td style="white-space:nowrap">3.6</td>
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp36-cp36m-win_amd64.whl</td>
</tr>
<tr>
<!-- Win -->
<td style="white-space:nowrap">3.7</td>
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp37-cp37m-win_amd64.whl</td>
</tr>
<tr>
<!-- Win -->
<td style="white-space:nowrap">3.8</td>
<td>https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.5.0-cp38-cp38-win_amd64.whl</td>
</tr>
</table>
Note: If you're on Debian Linux and using TensorFlow Lite with a Coral ML
accelerator, using pip to install `tflite_runtime` may not be compatible with
other Coral libraries. To ensure all your libraries are compatible, instead
install `tflite_runtime` as a
[Debian package from Coral](https://coral.ai/software/#debian-packages).
## Run an inference using tflite_runtime
To distinguish this interpreter-only package from the full TensorFlow package
(allowing both to be installed, if you choose), the Python module provided in
the above wheel is named `tflite_runtime`.
So instead of importing `Interpreter` from the `tensorflow` module, you need to
Instead of importing `Interpreter` from the `tensorflow` module, you now need to
import it from `tflite_runtime`.
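For instance, here is a minimal inference sketch. It assumes a float32 model stored at a placeholder path `model.tflite`; swap in your own model and preprocessing.

```python
import numpy as np
# With the full tensorflow package installed, use tf.lite.Interpreter instead.
from tflite_runtime.interpreter import Interpreter

# 'model.tflite' is a placeholder path for your converted model.
interpreter = Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Run inference on random data of the expected input shape and dtype.
input_data = np.random.rand(*input_details[0]['shape']).astype(np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
print(interpreter.get_tensor(output_details[0]['index']))
```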
For example, after you install the package above, copy and run the

View File

@ -3103,7 +3103,6 @@ cuda_py_test(
tags = [
"guitar",
"multi_gpu",
"no_rocm",
"no_windows",
],
deps = [

View File

@ -1078,6 +1078,7 @@ cuda_py_test(
tags = [
"multi_and_single_gpu",
"no_cuda_asan", # times out
"no_rocm",
"notsan", # b/173031470
],
deps = [
@ -1741,6 +1742,7 @@ distribute_py_test(
shard_count = 2,
tags = [
"multi_and_single_gpu",
"no_rocm",
"notsan", # TODO(b/160006974)
],
xla_enable_strict_auto_jit = True,
@ -1773,6 +1775,7 @@ distribute_py_test(
tags = [
"multi_and_single_gpu",
"no_cuda_asan", # times out
"no_rocm",
"notsan", # TODO(b/160006974)
],
xla_enable_strict_auto_jit = True,
@ -1846,6 +1849,7 @@ distribute_py_test(
disable_mlir_bridge = False,
tags = [
"multi_and_single_gpu",
"no_rocm",
],
deps = [
":combinations",

View File

@ -1186,6 +1186,15 @@ class Context(object):
self.ensure_initialized()
return pywrap_tfe.TFE_Py_PackEagerTensors(self._handle, tensors)
def list_function_names(self):
"""Get a list of names of registered functions.
Returns:
A set of names of all registered functions for the context.
"""
self.ensure_initialized()
return set(pywrap_tfe.TFE_ContextListFunctionNames(self._handle))
def remove_function(self, name):
"""Remove a function from the context.

View File

@ -151,6 +151,16 @@ class ContextTest(test.TestCase):
with self.assertRaisesRegex(ValueError, 'Multiple devices'):
context.context().get_total_memory_usage('GPU')
def testListFunctionNames(self):
@def_function.function
def f():
return constant_op.constant(1.)
concrete = f.get_concrete_function()
self.assertIn(concrete.name.decode(),
context.context().list_function_names())
if __name__ == '__main__':
ops.enable_eager_execution()

View File

@ -498,9 +498,17 @@ class _EagerDefinedFunction(object):
function_callback(self)
def add_to_graph(self, g=None):
"""Add the function to the current context or a graph, if supplied.
Args:
g: the graph to add the function to. If not supplied, the function will
be added to the current context.
"""
# pylint: disable=protected-access
if not g and context.executing_eagerly():
context.context().add_function_def(self.definition)
ctx = context.context()
if not ctx.has_function(self.name):
ctx.add_function_def(self.definition)
else:
if not g._is_function(self.name):
g._add_function(self)

View File

@ -4334,6 +4334,7 @@ EagerContextThreadLocalData* GetEagerContextThreadLocalData(
}
if (eager_context_thread_local_data_map == nullptr) {
absl::LeakCheckDisabler disabler;
eager_context_thread_local_data_map = new EagerContextThreadLocalDataMap();
}
auto& thread_local_data =

View File

@ -660,7 +660,7 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
# versions of python2.7.x.
for _ in range(warmup_iters):
f(self, *args, **kwargs)
# Since we aren't in the normal test lifecycle, we need to manually run
# cleanups to clear out their object references.
self.doCleanups()
@ -668,6 +668,10 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
# create and save as a dummy variable to include it as a baseline.
obj_count_by_type = _get_object_count_by_type()
gc.collect()
# Make sure any registered functions are cleaned up in the C++ runtime.
registered_function_names = context.context().list_function_names()
# unittest.doCleanups adds to self._outcome with each unwound call.
# These objects are retained across gc collections so we exclude them
# from the object count calculation.
@ -682,7 +686,7 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
}
for _ in range(3):
f(self, *args, **kwargs)
# Since we aren't in the normal test lifecycle, we need to manually run
# cleanups to clear out their object references.
self.doCleanups()
# Note that gc.get_objects misses anything that isn't subject to garbage
@ -711,6 +715,14 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
exclude=gc.get_referents(self._outcome.errors,
self._outcome.skipped)) -
obj_count_by_type)
# There should be no newly registered functions hanging around.
leftover_functions = (
context.context().list_function_names() - registered_function_names)
assert not leftover_functions, (
"The following functions were newly created: %s" %
leftover_functions)
# In some cases (specifically on MacOS), new_count is somehow
# smaller than previous_count.
# Using plain assert because not all classes using this decorator

View File

@ -249,6 +249,7 @@ distribute_py_test(
main = "custom_training_loop_metrics_test.py",
tags = [
"multi_and_single_gpu",
"no_rocm",
],
deps = [
":strategy_combinations",
@ -270,6 +271,7 @@ distribute_py_test(
tags = [
"multi_and_single_gpu",
"no_cuda_asan", # times out
"no_rocm",
"notsan", # TODO(b/170954243)
],
tpu_tags = [
@ -543,6 +545,7 @@ distribute_py_test(
shard_count = 31,
tags = [
"multi_and_single_gpu",
"no_rocm",
"no_windows_gpu",
"noasan", # TODO(b/337374867) fails with -fsanitize=null
"notpu", # TODO(b/153672562)
@ -562,6 +565,7 @@ distribute_py_test(
shard_count = 7,
tags = [
"multi_and_single_gpu",
"no_rocm",
],
xla_tags = [
"no_cuda_asan", # times out

View File

@ -671,12 +671,13 @@ class Functional(training_lib.Model):
Raises:
ValueError: In case of improperly formatted config dict.
"""
input_tensors, output_tensors, created_layers = reconstruct_from_config(
config, custom_objects)
model = cls(inputs=input_tensors, outputs=output_tensors,
name=config.get('name'))
connect_ancillary_layers(model, created_layers)
return model
with generic_utils.SharedObjectLoadingScope():
input_tensors, output_tensors, created_layers = reconstruct_from_config(
config, custom_objects)
model = cls(inputs=input_tensors, outputs=output_tensors,
name=config.get('name'))
connect_ancillary_layers(model, created_layers)
return model
def _validate_graph_inputs_and_outputs(self):
"""Validates the inputs and outputs of a Graph Network."""
@ -1346,21 +1347,23 @@ def get_network_config(network, serialize_layer_fn=None):
node_conversion_map[node_key] = kept_nodes
kept_nodes += 1
layer_configs = []
for layer in network.layers: # From the earliest layers on.
filtered_inbound_nodes = []
for original_node_index, node in enumerate(layer._inbound_nodes):
node_key = _make_node_key(layer.name, original_node_index)
if node_key in network._network_nodes and not node.is_input:
# The node is relevant to the model:
# add to filtered_inbound_nodes.
node_data = node.serialize(_make_node_key, node_conversion_map)
filtered_inbound_nodes.append(node_data)
layer_config = serialize_layer_fn(layer)
layer_config['name'] = layer.name
layer_config['inbound_nodes'] = filtered_inbound_nodes
layer_configs.append(layer_config)
config['layers'] = layer_configs
with generic_utils.SharedObjectSavingScope():
for layer in network.layers: # From the earliest layers on.
filtered_inbound_nodes = []
for original_node_index, node in enumerate(layer._inbound_nodes):
node_key = _make_node_key(layer.name, original_node_index)
if node_key in network._network_nodes and not node.is_input:
# The node is relevant to the model:
# add to filtered_inbound_nodes.
node_data = node.serialize(_make_node_key, node_conversion_map)
filtered_inbound_nodes.append(node_data)
layer_config = serialize_layer_fn(layer)
layer_config['name'] = layer.name
layer_config['inbound_nodes'] = filtered_inbound_nodes
layer_configs.append(layer_config)
config['layers'] = layer_configs
# Gather info about inputs and outputs.
model_inputs = []

View File

@ -80,7 +80,6 @@ cuda_py_test(
name = "gradient_checkpoint_test",
srcs = ["gradient_checkpoint_test.py"],
python_version = "PY3",
tags = ["no_rocm"],
deps = [
"//tensorflow:tensorflow_py_no_contrib",
],

View File

@ -12,6 +12,7 @@ package(
"//tensorflow/python/keras:__subpackages__",
"//tensorflow/python/training/tracking:__pkg__",
"//tensorflow/tools/pip_package:__pkg__",
"//tensorflow_models/official/vision/beta/projects/residual_mobilenet/modeling/backbones:__pkg__",
],
licenses = ["notice"], # Apache 2.0
)
@ -853,6 +854,7 @@ cuda_py_test(
srcs = ["gru_v2_test.py"],
python_version = "PY3",
shard_count = 12,
tags = ["no_rocm"],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python/keras",

View File

@ -114,13 +114,11 @@ class CategoryCrossing(base_preprocessing_layer.PreprocessingLayer):
`[[b'1_X_2_X_3'], [b'4_X_5_X_6']]`
"""
def __init__(self, depth=None, name=None, separator=None, **kwargs):
def __init__(self, depth=None, name=None, separator='_X_', **kwargs):
super(CategoryCrossing, self).__init__(name=name, **kwargs)
base_preprocessing_layer.keras_kpl_gauge.get_cell(
'CategoryCrossing').set(True)
self.depth = depth
if separator is None:
separator = '_X_'
self.separator = separator
if isinstance(depth, (tuple, list)):
self._depth_tuple = depth
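A brief usage sketch of the new default (inputs are illustrative): pass `separator` explicitly to override `'_X_'`.
import tensorflow as tf
layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing(
    separator='-')
out = layer([tf.constant([['1'], ['4']]), tf.constant([['2'], ['5']])])
# => [[b'1-2'], [b'4-5']]; with the default separator this would be
# [[b'1_X_2'], [b'4_X_5']].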

View File

@ -393,6 +393,10 @@ def clone_model(model, input_tensors=None, clone_function=None):
except that it creates new layers (and thus new weights) instead
of sharing the weights of the existing layers.
`clone_model` will not preserve the uniqueness of shared objects within the
model (e.g. a single variable attached to two distinct layers will be
restored as two separate variables).
Args:
model: Instance of `Model`
(could be a functional model or a Sequential model).

View File

@ -158,7 +158,6 @@ cuda_py_test(
size = "medium",
srcs = ["adadelta_test.py"],
shard_count = 4,
tags = ["no_rocm"],
# TODO(b/168527439): invalid resource variable reference on GPU for TFRT.
deps = [
":optimizer_v2",
@ -239,7 +238,6 @@ cuda_py_test(
srcs = ["optimizer_v2_test.py"],
shard_count = 8,
tags = [
"no_rocm",
"no_windows",
],
deps = [
@ -297,7 +295,6 @@ cuda_py_test(
size = "medium",
srcs = ["rmsprop_test.py"],
shard_count = 2,
tags = ["no_rocm"],
xla_tags = [
"no_cuda_asan", # times out
],

View File

@ -148,8 +148,9 @@ def save_model(model,
hdf5_format.save_model_to_hdf5(
model, filepath, overwrite, include_optimizer)
else:
saved_model_save.save(model, filepath, overwrite, include_optimizer,
signatures, options, save_traces)
with generic_utils.SharedObjectSavingScope():
saved_model_save.save(model, filepath, overwrite, include_optimizer,
signatures, options, save_traces)
@keras_export('keras.models.load_model')
@ -194,17 +195,18 @@ def load_model(filepath, custom_objects=None, compile=True, options=None): # py
ImportError: if loading from an hdf5 file and h5py is not available.
IOError: In case of an invalid savefile.
"""
with generic_utils.CustomObjectScope(custom_objects or {}):
with load_context.load_context(options):
if (h5py is not None and
(isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
compile)
with generic_utils.SharedObjectLoadingScope():
with generic_utils.CustomObjectScope(custom_objects or {}):
with load_context.load_context(options):
if (h5py is not None and
(isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
compile)
filepath = path_to_string(filepath)
if isinstance(filepath, six.string_types):
loader_impl.parse_saved_model(filepath)
return saved_model_load.load(filepath, compile, options)
filepath = path_to_string(filepath)
if isinstance(filepath, six.string_types):
loader_impl.parse_saved_model(filepath)
return saved_model_load.load(filepath, compile, options)
raise IOError(
'Unable to load model. Filepath is not an hdf5 file (or h5py is not '

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import shutil
import sys
@ -25,12 +26,14 @@ import tempfile
from absl.testing import parameterized
import numpy as np
from six import string_types
from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.feature_column import feature_column_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import combinations
@ -859,6 +862,125 @@ class TestWholeModelSaving(keras_parameterized.TestCase):
self.assertAllEqual(loaded_model.predict(args, batch_size=batch_size),
expected)
@combinations.generate(combinations.combine(mode=['eager']))
def test_shared_objects(self):
class OuterLayer(keras.layers.Layer):
def __init__(self, inner_layer):
super(OuterLayer, self).__init__()
self.inner_layer = inner_layer
def call(self, inputs):
return self.inner_layer(inputs)
def get_config(self):
return {
'inner_layer': generic_utils.serialize_keras_object(
self.inner_layer)
}
@classmethod
def from_config(cls, config):
return cls(generic_utils.deserialize_keras_object(
config['inner_layer']))
class InnerLayer(keras.layers.Layer):
def __init__(self):
super(InnerLayer, self).__init__()
self.v = self.add_weight(name='v', shape=[], dtype=dtypes.float32)
def call(self, inputs):
return self.v + inputs
@classmethod
def from_config(cls, config):
return cls()
# Create a model with 2 output layers that share the same inner layer.
inner_layer = InnerLayer()
outer_layer_1 = OuterLayer(inner_layer)
outer_layer_2 = OuterLayer(inner_layer)
input_ = keras.Input(shape=(1,))
model = keras.Model(
inputs=input_, outputs=[outer_layer_1(input_), outer_layer_2(input_)])
# Changes to the shared layer should affect both outputs.
model.layers[1].inner_layer.v.assign(5)
self.assertAllEqual(model(1), [6.0, 6.0])
model.layers[1].inner_layer.v.assign(3)
self.assertAllEqual(model(1), [4.0, 4.0])
# After loading, changes to the shared layer should still affect both
# outputs.
def _do_assertions(loaded):
loaded.layers[1].inner_layer.v.assign(5)
self.assertAllEqual(loaded(1), [6.0, 6.0])
loaded.layers[1].inner_layer.v.assign(3)
self.assertAllEqual(loaded(1), [4.0, 4.0])
loaded.layers[2].inner_layer.v.assign(5)
self.assertAllEqual(loaded(1), [6.0, 6.0])
loaded.layers[2].inner_layer.v.assign(3)
self.assertAllEqual(loaded(1), [4.0, 4.0])
# We'd like to make sure we only attach shared object IDs when strictly
# necessary, so we'll recursively traverse the generated config to count
# whether we have the exact number we expect.
def _get_all_keys_recursive(dict_or_iterable):
if isinstance(dict_or_iterable, dict):
for key in dict_or_iterable.keys():
yield key
for key in _get_all_keys_recursive(dict_or_iterable.values()):
yield key
elif isinstance(dict_or_iterable, string_types):
return
else:
try:
for item in dict_or_iterable:
for key in _get_all_keys_recursive(item):
yield key
# Not an iterable or dictionary
except TypeError:
return
with generic_utils.CustomObjectScope({
'OuterLayer': OuterLayer, 'InnerLayer': InnerLayer}):
# Test saving and loading to disk
save_format = testing_utils.get_save_format()
saved_model_dir = self._save_model_dir()
keras.models.save_model(model, saved_model_dir, save_format=save_format)
loaded = keras.models.load_model(saved_model_dir)
_do_assertions(loaded)
# Test recreating directly from config
config = model.get_config()
key_count = collections.Counter(_get_all_keys_recursive(config))
self.assertEqual(key_count[generic_utils.SHARED_OBJECT_KEY], 2)
loaded = keras.Model.from_config(config)
_do_assertions(loaded)
@combinations.generate(combinations.combine(mode=['eager']))
def test_shared_objects_wrapper(self):
"""Tests that shared layers wrapped with `Wrapper` restore correctly."""
input_ = keras.Input(shape=(1,))
unwrapped = keras.layers.Layer(name='unwrapped')
wrapped = keras.layers.Wrapper(unwrapped, name='wrapped')
model = keras.Model(inputs=input_,
outputs=[unwrapped(input_), wrapped(input_)])
# Test recreating directly from config
config = model.get_config()
loaded = keras.Model.from_config(config)
self.assertIs(loaded.layers[1], loaded.layers[2].layer)
# Test saving and loading to disk
save_format = testing_utils.get_save_format()
saved_model_dir = self._save_model_dir()
keras.models.save_model(model, saved_model_dir, save_format=save_format)
loaded = keras.models.load_model(saved_model_dir)
self.assertIs(loaded.layers[1], loaded.layers[2].layer)
# Factory functions to create models that will be serialized inside a Network.
def _make_graph_network(input_size, output_size):

View File

@ -46,7 +46,6 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
# TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
# the python config serialization has caught up.
metadata = dict(
class_name=generic_utils.get_registered_name(type(self.obj)),
name=self.obj.name,
trainable=self.obj.trainable,
expects_training_arg=self.obj._expects_training_arg, # pylint: disable=protected-access
@ -56,7 +55,7 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
must_restore_from_config=self.obj._must_restore_from_config, # pylint: disable=protected-access
)
metadata.update(get_config(self.obj))
metadata.update(get_serialized(self.obj))
if self.obj.input_spec is not None:
# Layer's input_spec has already been type-checked in the property setter.
metadata['input_spec'] = nest.map_structure(
@ -110,16 +109,12 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
# TODO(kathywu): Move serialization utils (and related utils from
# generic_utils.py) to a separate file.
def get_config(obj):
def get_serialized(obj):
with generic_utils.skip_failed_serialization():
# Store the config dictionary, which may be used when reviving the object.
# When loading, the program will attempt to revive the object from config,
# and if that fails, the object will be revived from the SavedModel.
config = generic_utils.serialize_keras_object(obj)['config']
if config is not None:
return {'config': config}
return {}
return generic_utils.serialize_keras_object(obj)
class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):

View File

@ -492,13 +492,15 @@ class KerasObjectLoader(object):
# found.
class_name = metadata.get('class_name')
config = metadata.get('config')
shared_object_id = metadata.get('shared_object_id')
must_restore_from_config = metadata.get('must_restore_from_config')
if not generic_utils.validate_config(config):
return None
try:
obj = layers_module.deserialize(
generic_utils.serialize_keras_class_and_config(class_name, config))
generic_utils.serialize_keras_class_and_config(
class_name, config, shared_object_id=shared_object_id))
except ValueError:
if must_restore_from_config:
raise RuntimeError(

View File

@ -36,7 +36,7 @@ class MetricSavedModelSaver(layer_serialization.LayerSavedModelSaver):
class_name=generic_utils.get_registered_name(type(self.obj)),
name=self.obj.name,
dtype=self.obj.dtype)
metadata.update(layer_serialization.get_config(self.obj))
metadata.update(layer_serialization.get_serialized(self.obj))
if self.obj._build_input_shape is not None: # pylint: disable=protected-access
metadata['build_input_shape'] = self.obj._build_input_shape # pylint: disable=protected-access
return metadata

View File

@ -24,8 +24,10 @@ import marshal
import os
import re
import sys
import threading
import time
import types as python_types
import weakref
import numpy as np
import six
@ -110,9 +112,205 @@ def get_custom_objects():
return _GLOBAL_CUSTOM_OBJECTS
def serialize_keras_class_and_config(cls_name, cls_config):
# Store a unique, per-object ID for shared objects.
#
# We store a unique ID for each object so that we may, at loading time,
# re-create the network properly. Without this ID, we would have no way of
# determining whether a config is a description of a new object that
# should be created or is merely a reference to an already-created object.
SHARED_OBJECT_KEY = 'shared_object_id'
class NoopLoadingScope(object):
"""The default shared object loading scope. It does nothing.
Created to simplify serialization code that doesn't care about shared objects
(e.g. when serializing a single object).
"""
def get(self, unused_object_id):
return None
def set(self, object_id, obj):
pass
SHARED_OBJECT_LOADING = threading.local()
def _shared_object_loading_scope():
"""Get the current shared object saving scope in a threadsafe manner.
Attributes on the threadlocal variable must be set per-thread, thus we
cannot initialize these globally.
Returns:
A SharedObjectLoadingScope or NoopLoadingScope object.
"""
return getattr(SHARED_OBJECT_LOADING, 'scope', NoopLoadingScope())
class SharedObjectLoadingScope(object):
"""A context manager for keeping track of loaded objects.
During the deserialization process, we may come across objects that are
shared across multiple layers. In order to accurately restore the network
structure to its original state, `SharedObjectLoadingScope` allows us to
re-use shared objects rather than cloning them.
"""
def __enter__(self):
global SHARED_OBJECT_LOADING
SHARED_OBJECT_LOADING.scope = self
self._obj_ids_to_obj = {}
return self
def get(self, object_id):
"""Given a shared object ID, returns a previously instantiated object.
Args:
object_id: shared object ID to use when attempting to find already-loaded
object.
Returns:
The object, if we've seen this ID before. Else, `None`.
"""
# Explicitly check for `None` internally to make external calling code a
# bit cleaner.
if object_id is None:
return
return self._obj_ids_to_obj.get(object_id)
def set(self, object_id, obj):
"""Stores an instantiated object for future lookup and sharing."""
if object_id is None:
return
self._obj_ids_to_obj[object_id] = obj
def __exit__(self, *args, **kwargs):
global SHARED_OBJECT_LOADING
SHARED_OBJECT_LOADING.scope = NoopLoadingScope()
SHARED_OBJECT_SAVING = threading.local()
def _shared_object_saving_scope():
"""Get the current shared object saving scope in a threadsafe manner.
Attributes on the threadlocal variable must be set per-thread, thus we
cannot initialize these globally.
Returns:
A SharedObjectSavingScope object or None.
"""
return getattr(SHARED_OBJECT_SAVING, 'scope', None)
class SharedObjectConfig(dict):
"""A configuration container that keeps track of references.
`SharedObjectConfig` will automatically attach a shared object ID to any
configs which are referenced more than once, allowing for proper shared
object reconstruction at load time.
In most cases, it would be more proper to subclass something like
`collections.UserDict` or `collections.Mapping` rather than `dict` directly.
Unfortunately, Python's json encoder does not support `Mapping`s. This is
important functionality to retain, since we are dealing with serialization.
We should be safe to subclass `dict` here, since we aren't actually
overriding any core methods, only augmenting with a new one for reference
counting.
"""
def __init__(self, base_config, object_id, **kwargs):
self.ref_count = 1
self.object_id = object_id
super(SharedObjectConfig, self).__init__(base_config, **kwargs)
def increment_ref_count(self):
# As soon as we've seen the object more than once, we want to attach the
# shared object ID. This allows us to only attach the shared object ID when
# it's strictly necessary, making backwards compatibility breakage less
# likely.
if self.ref_count == 1:
self[SHARED_OBJECT_KEY] = self.object_id
self.ref_count += 1
class SharedObjectSavingScope(object):
"""Keeps track of shared object configs when serializing."""
def __enter__(self):
global SHARED_OBJECT_SAVING
# Serialization can happen at a number of layers for a number of reasons.
# We may end up with a case where we're opening a saving scope within
# another saving scope. In that case, we'd like to use the outermost scope
# available and ignore inner scopes, since there is not (yet) a reasonable
# use case for having these nested and distinct.
if _shared_object_saving_scope() is not None:
self._passthrough = True
return _shared_object_saving_scope()
else:
self._passthrough = False
SHARED_OBJECT_SAVING.scope = self
self._shared_objects_config = weakref.WeakKeyDictionary()
self._next_id = 0
return self
def get_config(self, obj):
"""Gets a `SharedObjectConfig` if one has already been seen for `obj`.
Args:
obj: The object for which to retrieve the `SharedObjectConfig`.
Returns:
The SharedObjectConfig for a given object, if already seen. Else,
`None`.
"""
if obj in self._shared_objects_config:
shared_object_config = self._shared_objects_config[obj]
shared_object_config.increment_ref_count()
return shared_object_config
def create_config(self, base_config, obj):
shared_object_config = SharedObjectConfig(base_config, self._next_id)
self._next_id += 1
self._shared_objects_config[obj] = shared_object_config
return shared_object_config
def __exit__(self, *args, **kwargs):
if not self._passthrough:
global SHARED_OBJECT_SAVING
SHARED_OBJECT_SAVING.scope = None
def serialize_keras_class_and_config(
cls_name, cls_config, obj=None, shared_object_id=None):
"""Returns the serialization of the class with the given config."""
return {'class_name': cls_name, 'config': cls_config}
base_config = {'class_name': cls_name, 'config': cls_config}
# We call `serialize_keras_class_and_config` for some branches of the load
# path. In that case, we may already have a shared object ID we'd like to
# retain.
if shared_object_id is not None:
base_config[SHARED_OBJECT_KEY] = shared_object_id
# If we have an active `SharedObjectSavingScope`, check whether we've already
# serialized this config. If so, just use that config. This will store an
# extra ID field in the config, allowing us to re-create the shared object
# relationship at load time.
if _shared_object_saving_scope() is not None and obj is not None:
shared_object_config = _shared_object_saving_scope().get_config(obj)
if shared_object_config is None:
return _shared_object_saving_scope().create_config(base_config, obj)
return shared_object_config
return base_config
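Putting the pieces together, a hedged sketch of the saving path; `_Stub` is a hypothetical, weakref-compatible stand-in for a layer. Serializing the same object twice inside a saving scope yields one shared config, tagged with an ID on the second reference:
class _Stub(object):  # hypothetical stand-in; must support weak references
  pass

stub = _Stub()
with SharedObjectSavingScope():
  cfg_a = serialize_keras_class_and_config('Stub', {}, obj=stub)
  cfg_b = serialize_keras_class_and_config('Stub', {}, obj=stub)
assert cfg_a is cfg_b              # one config shared by both call sites
assert SHARED_OBJECT_KEY in cfg_a  # ID attached on the second reference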
@keras_export('keras.utils.register_keras_serializable')
@@ -234,7 +432,19 @@ def get_registered_object(name, custom_objects=None, module_objects=None):
@keras_export('keras.utils.serialize_keras_object')
def serialize_keras_object(instance):
"""Serialize a Keras object into a JSON-compatible representation."""
"""Serialize a Keras object into a JSON-compatible representation.
Calls to `serialize_keras_object` while underneath the
`SharedObjectSavingScope` context manager will cause any objects re-used
across multiple layers to be saved with a special shared object ID. This
allows the network to be re-created properly during deserialization.
Args:
instance: The object to serialize.
Returns:
A dict-like, JSON-compatible representation of the object's config.
"""
_, instance = tf_decorator.unwrap(instance)
if instance is None:
return None
@@ -265,7 +475,8 @@ def serialize_keras_object(instance):
serialization_config[key] = item
name = get_registered_name(instance.__class__)
return serialize_keras_class_and_config(name, serialization_config)
return serialize_keras_class_and_config(
name, serialization_config, instance)
if hasattr(instance, '__name__'):
return get_registered_name(instance)
raise ValueError('Cannot serialize', instance)
@@ -286,8 +497,9 @@ def class_and_config_for_serialized_keras_object(
custom_objects=None,
printable_module_name='object'):
"""Returns the class name and config for a serialized keras object."""
if (not isinstance(config, dict) or 'class_name' not in config or
'config' not in config):
if (not isinstance(config, dict)
or 'class_name' not in config
or 'config' not in config):
raise ValueError('Improper config format: ' + str(config))
class_name = config['class_name']
@@ -341,7 +553,24 @@ def deserialize_keras_object(identifier,
module_objects=None,
custom_objects=None,
printable_module_name='object'):
"""Turns the serialized form of a Keras object back into an actual object."""
"""Turns the serialized form of a Keras object back into an actual object.
Calls to `deserialize_keras_object` while underneath the
`SharedObjectLoadingScope` context manager will cause any already-seen shared
objects to be returned as-is rather than creating a new object.
Args:
identifier: the serialized form of the object.
module_objects: A dictionary of custom objects to look the name up in.
Generally, module_objects is provided by mid-level library implementers.
custom_objects: A dictionary of custom objects to look the name up in.
Generally, custom_objects is provided by the user.
printable_module_name: A human-readable string representing the type of the
object. Printed in case of exception.
Returns:
The deserialized object.
"""
if identifier is None:
return None
@@ -351,25 +580,39 @@ def deserialize_keras_object(identifier,
(cls, cls_config) = class_and_config_for_serialized_keras_object(
config, module_objects, custom_objects, printable_module_name)
# If this object has already been loaded (i.e. it's shared between multiple
# objects), return the already-loaded object.
shared_object_id = config.get(SHARED_OBJECT_KEY)
shared_object = _shared_object_loading_scope().get(shared_object_id) # pylint: disable=assignment-from-none
if shared_object is not None:
return shared_object
if hasattr(cls, 'from_config'):
arg_spec = tf_inspect.getfullargspec(cls.from_config)
custom_objects = custom_objects or {}
if 'custom_objects' in arg_spec.args:
return cls.from_config(
deserialized_obj = cls.from_config(
cls_config,
custom_objects=dict(
list(_GLOBAL_CUSTOM_OBJECTS.items()) +
list(custom_objects.items())))
with CustomObjectScope(custom_objects):
return cls.from_config(cls_config)
else:
with CustomObjectScope(custom_objects):
deserialized_obj = cls.from_config(cls_config)
else:
# Then `cls` may be a function returning a class.
# in this case by convention `config` holds
# the kwargs of the function.
custom_objects = custom_objects or {}
with CustomObjectScope(custom_objects):
return cls(**cls_config)
deserialized_obj = cls(**cls_config)
# Add object to shared objects, in case we find it referenced again.
_shared_object_loading_scope().set(shared_object_id, deserialized_obj)
return deserialized_obj
elif isinstance(identifier, six.string_types):
object_name = identifier
if custom_objects and object_name in custom_objects:
View File
@@ -23,6 +23,7 @@ from functools import partial
import numpy as np
from tensorflow.python import keras
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.platform import test
@@ -384,5 +385,63 @@ class SliceArraysTest(test.TestCase):
[None, None, None])
# object() alone isn't compatible with WeakKeyDictionary, which we use to
# track shared configs.
class MaybeSharedObject(object):
pass
class SharedObjectScopeTest(test.TestCase):
def test_shared_object_saving_scope_single_object_doesnt_export_id(self):
with generic_utils.SharedObjectSavingScope() as scope:
single_object = MaybeSharedObject()
self.assertIsNone(scope.get_config(single_object))
single_object_config = scope.create_config({}, single_object)
self.assertIsNotNone(single_object_config)
self.assertNotIn(generic_utils.SHARED_OBJECT_KEY,
single_object_config)
def test_shared_object_saving_scope_shared_object_exports_id(self):
with generic_utils.SharedObjectSavingScope() as scope:
shared_object = MaybeSharedObject()
self.assertIsNone(scope.get_config(shared_object))
scope.create_config({}, shared_object)
first_object_config = scope.get_config(shared_object)
second_object_config = scope.get_config(shared_object)
self.assertIn(generic_utils.SHARED_OBJECT_KEY,
first_object_config)
self.assertIn(generic_utils.SHARED_OBJECT_KEY,
second_object_config)
self.assertIs(first_object_config, second_object_config)
def test_shared_object_loading_scope_noop(self):
# Test that, without a context manager scope, adding configs will do
# nothing.
obj_id = 1
obj = MaybeSharedObject()
generic_utils._shared_object_loading_scope().set(obj_id, obj)
self.assertIsNone(generic_utils._shared_object_loading_scope().get(obj_id))
def test_shared_object_loading_scope_returns_shared_obj(self):
obj_id = 1
obj = MaybeSharedObject()
with generic_utils.SharedObjectLoadingScope() as scope:
scope.set(obj_id, obj)
self.assertIs(scope.get(obj_id), obj)
def test_nested_shared_object_saving_scopes(self):
my_obj = MaybeSharedObject()
with generic_utils.SharedObjectSavingScope() as scope_1:
scope_1.create_config({}, my_obj)
with generic_utils.SharedObjectSavingScope() as scope_2:
# Nesting saving scopes should return the original scope and should
# not clear any objects we're tracking.
self.assertIs(scope_1, scope_2)
self.assertIsNotNone(scope_2.get_config(my_obj))
self.assertIsNotNone(scope_1.get_config(my_obj))
self.assertIsNone(generic_utils._shared_object_saving_scope())
if __name__ == '__main__':
test.main()
View File
@@ -21,7 +21,6 @@ cuda_py_test(
python_version = "PY3",
tags = [
"no_pip",
"no_rocm",
],
deps = [
":mnist_testing_utils",
View File
@@ -2118,7 +2118,6 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
# allocations at a lower level.
@test_util.assert_no_new_pyobjects_executing_eagerly
def test_functions_cleaned(self):
self.skipTest("TODO(b/175152958): The test is leaking function definitions")
if sys.version_info.major < 3:
self.skipTest("Not working in Python 2")
root = module.Module()
View File
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/dlpack.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
@@ -670,6 +671,10 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
return output;
});
m.def("TFE_ContextListFunctionNames", [](py::handle& ctx) {
return tensorflow::unwrap(tensorflow::InputTFE_Context(ctx))
->ListFunctionNames();
});
m.def("TFE_ContextEnableRunMetadata", [](py::handle& ctx) {
TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx));
});
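A hedged sketch of exercising the new binding from Python; reaching into the private `_handle` of the eager context is an internal detail shown only for illustration:
# Hedged sketch: list the functions registered with the eager context.
# Accessing the private `_handle` is an internal detail, not a public API.
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context

ctx = context.context()
ctx.ensure_initialized()
names = pywrap_tfe.TFE_ContextListFunctionNames(ctx._handle)  # pylint: disable=protected-access
print(sorted(names))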
View File
@@ -25,14 +25,18 @@ from __future__ import print_function
import os
import threading
import time
from typing import Any, List, Optional, Text
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
@@ -40,13 +44,14 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
"""Saves checkpoints every N steps or seconds."""
def __init__(self,
checkpoint_dir,
save_secs=None,
save_steps=None,
saver=None,
checkpoint_basename="model.ckpt",
scaffold=None,
listeners=None):
checkpoint_dir: Text,
save_secs: Optional[int] = None,
save_steps: Optional[int] = None,
saver: Optional[saver_lib.Saver] = None,
checkpoint_basename: Text = "model.ckpt",
scaffold: Optional[monitored_session.Scaffold] = None,
listeners: Optional[List[
basic_session_run_hooks.CheckpointSaverListener]] = None):
"""Initializes a `CheckpointSaverHook`.
Args:
@@ -98,7 +103,7 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
for l in self._listeners:
l.begin()
def after_create_session(self, session, coord):
def after_create_session(self, session: session_lib.Session, coord: Any):
global_step = session.run(self._global_step_tensor)
# We do write graph and saver_def at the first call of before_run.
@@ -122,10 +127,11 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
self._save(session, global_step)
self._timer.update_last_triggered_step(global_step)
def before_run(self, run_context): # pylint: disable=unused-argument
return SessionRunArgs(self._global_step_tensor)
def before_run(self, run_context: Any): # pylint: disable=unused-argument
return session_run_hook.SessionRunArgs(self._global_step_tensor)
def after_run(self, run_context, run_values):
def after_run(self, run_context: session_run_hook.SessionRunContext,
run_values: Any):
global_step = run_context.session.run(self._global_step_tensor)
if self._timer.should_trigger_for_step(global_step):
self._timer.update_last_triggered_step(global_step)
@@ -133,7 +139,7 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
if self._save(run_context.session, global_step):
run_context.request_stop()
def end(self, session):
def end(self, session: session_lib.Session):
if self._save_thread:
logging.info("Waiting for any pending checkpoints to finish.")
self._save_thread.join()
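A hedged usage sketch for the hook (TF1-style graph mode; the checkpoint directory and step counts are placeholders):
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)  # stand-in for a real training step
hook = AsyncCheckpointSaverHook('/tmp/ckpts', save_steps=100)
with tf.train.MonitoredTrainingSession(hooks=[hook]) as sess:
  for _ in range(300):
    sess.run(train_op)  # checkpoints are written off the critical path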
View File
@@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Generator, Optional, Text
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
@@ -70,10 +72,18 @@ def _get_custom_getter():
@tf_export(v1=['tpu.bfloat16_scope'])
@tf_contextlib.contextmanager
def bfloat16_scope(name=None):
def bfloat16_scope(
name: Optional[Text] = None
) -> Generator[variable_scope.variable_scope, None, None]:
"""Scope class for bfloat16 variables so that the model uses custom getter.
This enables variables to be read as bfloat16 type when using get_variable.
Arguments:
name: Name to use for scope.
Yields:
a variable scope.
"""
if name is None:
name = ''
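A hedged sketch of the scope in use (TF1 graph mode; the variable name and shape are placeholders): variables requested as bfloat16 inside the scope are stored as float32 and cast on read by the custom getter:
import tensorflow.compat.v1 as tf

with tf.tpu.bfloat16_scope() as scope:
  w = tf.get_variable('w', shape=[128], dtype=tf.bfloat16)
  # `w` reads as bfloat16 here; the underlying storage stays float32.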
View File
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Callable, Optional, Text, Union
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
@@ -28,13 +30,13 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import functional_ops
def _TextLineDataset(filename):
def _TextLineDataset(filename: Text) -> dataset_ops.Dataset:
buffer_size = 8 * 1024 * 1024 # 8 MiB per file
dataset = readers.TextLineDataset(filename, buffer_size=buffer_size)
return dataset
def _TFRecordDataset(filename):
def _TFRecordDataset(filename: Text) -> dataset_ops.Dataset:
buffer_size = 8 * 1024 * 1024 # 8 MiB per file
dataset = readers.TFRecordDataset(filename, buffer_size=buffer_size)
return dataset
@@ -47,15 +49,17 @@ _FILETYPE_MAP = {
}
def StreamingFilesDataset(files,
filetype=None,
file_reader_job=None,
worker_job=None,
num_epochs=None,
filename_shuffle_buffer_size=None,
num_parallel_reads=None,
batch_transfer_size=None,
sloppy=None):
def StreamingFilesDataset(
files: Union[Text, dataset_ops.Dataset],
filetype: Optional[Union[Text, Callable[[Text],
dataset_ops.Dataset]]] = None,
file_reader_job: Optional[Text] = None,
worker_job: Optional[Text] = None,
num_epochs: Optional[int] = None,
filename_shuffle_buffer_size: Optional[Union[int, bool]] = None,
num_parallel_reads: Optional[int] = None,
batch_transfer_size: Optional[Union[int, bool]] = None,
sloppy: bool = True) -> dataset_ops.Dataset:
"""StreamingFilesDataset constructs a dataset to stream from workers (GCE VM).
Because Cloud TPUs are allocated over the network, a Cloud TPU cannot read
@@ -126,9 +130,6 @@ def StreamingFilesDataset(files,
if batch_transfer_size is None:
batch_transfer_size = 256
if sloppy is None:
sloppy = True
if file_reader_job == 'coordinator':
file_reader_device = '/job:coordinator/task:0'
else:
View File
@@ -20,6 +20,7 @@ from __future__ import print_function
import enum
import math
from typing import List, Optional, Text, Tuple
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -66,7 +67,7 @@ class DeviceAssignment(object):
`DeviceAssignment` directly.
"""
def __init__(self, topology, core_assignment):
def __init__(self, topology: Topology, core_assignment: np.ndarray):
"""Constructs a `DeviceAssignment` object.
Args:
@@ -104,22 +105,22 @@
self._core_assignment, topology)
@property
def topology(self):
def topology(self) -> Topology:
"""A `Topology` that describes the TPU topology."""
return self._topology
@property
def num_cores_per_replica(self):
def num_cores_per_replica(self) -> int:
"""The number of cores per replica."""
return self._num_cores_per_replica
@property
def num_replicas(self):
def num_replicas(self) -> int:
"""The number of replicas of the computation."""
return self._num_replicas
@property
def core_assignment(self):
def core_assignment(self) -> np.ndarray:
"""The logical to physical core mapping.
Returns:
@@ -129,11 +130,11 @@
"""
return self._core_assignment
def coordinates(self, replica, logical_core):
def coordinates(self, replica: int, logical_core: int) -> Tuple: # pylint:disable=g-bare-generic
"""Returns the physical topology coordinates of a logical core."""
return tuple(self.core_assignment[replica, logical_core, :])
def lookup_replicas(self, task_id, logical_core):
def lookup_replicas(self, task_id: int, logical_core: int) -> List[int]:
"""Lookup replica ids by task number and logical core.
Args:
@@ -153,31 +154,38 @@
"Can not find any replica in task: {} contains logical_core: {} ".
format(task_id, logical_core))
def tpu_ordinal(self, replica=0, logical_core=0):
def tpu_ordinal(self, replica: int = 0, logical_core: int = 0) -> int:
"""Returns the ordinal of the TPU device assigned to a logical core."""
coordinates = self.coordinates(replica, logical_core)
return self._topology.tpu_device_ordinal_at_coordinates(coordinates)
def host_device(self, replica=0, logical_core=0, job=None):
def host_device(self,
replica: int = 0,
logical_core: int = 0,
job: Optional[Text] = None) -> Text:
"""Returns the CPU device attached to a logical core."""
coordinates = self.coordinates(replica, logical_core)
return self._topology.cpu_device_name_at_coordinates(coordinates, job=job)
def tpu_device(self, replica=0, logical_core=0, job=None):
def tpu_device(self,
replica: int = 0,
logical_core: int = 0,
job: Optional[Text] = None) -> Text:
"""Returns the name of the TPU device assigned to a logical core."""
coordinates = self.coordinates(replica, logical_core)
return self._topology.tpu_device_name_at_coordinates(coordinates, job=job)
@staticmethod
def build(topology,
computation_shape=None,
computation_stride=None,
num_replicas=1):
def build(topology: Topology,
computation_shape: Optional[np.ndarray] = None,
computation_stride: Optional[np.ndarray] = None,
num_replicas: int = 1) -> "DeviceAssignment":
return device_assignment(topology, computation_shape, computation_stride,
num_replicas)
def _open_ring_2d(x_size, y_size, z_coord):
def _open_ring_2d(x_size: int, y_size: int,
z_coord: int) -> List[Tuple[int, int, int]]:
"""Ring-order of a X by Y mesh, with a fixed Z coordinate.
For example, in a 4x4 mesh, this returns the following order.
@@ -213,7 +221,8 @@ def _open_ring_2d(x_size, y_size, z_coord):
return ret
def _ring_3d(x_size, y_size, z_size):
def _ring_3d(x_size: int, y_size: int,
z_size: int) -> List[Tuple[int, int, int]]:
"""Ring-order of a X by Y by Z mesh.
Constructs the 3d ring from 2d rings that are stacked in the Z dimension and
@@ -325,11 +334,13 @@
MESH = 2
def device_assignment(topology,
computation_shape=None,
computation_stride=None,
num_replicas=1,
device_order_mode=DeviceOrderMode.AUTO):
def device_assignment(
topology: Topology,
computation_shape: Optional[np.ndarray] = None,
computation_stride: Optional[np.ndarray] = None,
num_replicas: int = 1,
device_order_mode: DeviceOrderMode = DeviceOrderMode.AUTO
) -> DeviceAssignment:
"""Computes a device_assignment of a computation across a TPU topology.
Attempts to choose a compact grid of cores for locality.
@@ -341,11 +352,12 @@
optimal packing.
Args:
topology: A `Topology` object that describes the TPU cluster topology.
To obtain a TPU topology, evaluate the `Tensor` returned by
topology: A `Topology` object that describes the TPU cluster topology. To
obtain a TPU topology, evaluate the `Tensor` returned by
`initialize_system` using `Session.run`. Either a serialized
`TopologyProto` or a `Topology` object may be passed. Note: you must
evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor` here.
evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor`
here.
computation_shape: A rank 1 int32 numpy array with size equal to the
topology rank, describing the shape of the computation's block of cores.
If None, the `computation_shape` is `[1] * topology_rank`.
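Following the signatures above, a hedged end-to-end sketch; the cluster setup is abbreviated and the computation shape assumes a four-dimensional (x, y, z, core) topology:
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)

# Two replicas, one core each; shape/stride are rank-1 over [x, y, z, core].
da = device_assignment(topology, computation_shape=[1, 1, 1, 1],
                       num_replicas=2)
print(da.num_replicas, da.num_cores_per_replica)
print(da.tpu_device(replica=0), da.host_device(replica=1))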
View File
@@ -20,7 +20,7 @@ from __future__ import print_function
from __future__ import unicode_literals
import functools
from typing import Any, Dict, Callable, List, Optional, Text, Tuple
from typing import Any, Dict, Callable, Iterable, List, Optional, Text, Tuple, Union
from absl import logging
@@ -229,7 +229,6 @@ class TPUEmbedding(tracking.AutoTrackable):
model = model_fn(...)
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
feature_config=feature_config,
batch_size=1024,
optimizer=tf.tpu.experimental.embedding.SGD(0.1))
checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
checkpoint.restore(...)
@@ -244,7 +243,7 @@ class TPUEmbedding(tracking.AutoTrackable):
def __init__(
self,
feature_config: Any,
feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable], # pylint:disable=g-bare-generic
optimizer: Optional[tpu_embedding_v2_utils._Optimizer], # pylint:disable=protected-access
pipeline_execution_with_tensor_core: bool = False):
"""Creates the TPUEmbedding mid level API object.
View File
@@ -19,15 +19,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Any, Callable, Iterable, List, Optional, Union
from tensorflow.python.compiler.xla import xla
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.tpu import tensor_tracer
from tensorflow.python.tpu import tpu_feed
from tensorflow.python.tpu import tpu_function
from tensorflow.python.types import core as core_types
def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
def while_loop(condition: Callable[..., Any],
body: Callable[..., Any],
inputs: Optional[List[Any]] = None,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
name: Any = None) -> Any:
"""Builds a training loop for TPUs.
The set of loop-carried tensors corresponds to `inputs`. Both
@@ -41,10 +49,10 @@
Args:
condition: a Python function that builds the loop condition.
body: a Python function that builds the loop body.
inputs: a list of initial values passed into the training loop, or
None (equivalent to an empty list).
infeed_queue: if not None, the infeed queue from which to append a tuple
of arguments as inputs to condition.
inputs: a list of initial values passed into the training loop, or None
(equivalent to an empty list).
infeed_queue: if not None, the infeed queue from which to append a tuple of
arguments as inputs to condition.
name: (Deprecated) Does nothing.
Returns:
@@ -178,7 +186,12 @@
condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1)
def repeat(n, body, inputs=None, infeed_queue=None, name=None):
def repeat(
n: int,
body: Callable[..., Union[core_types.TensorLike, Iterable]], # pylint:disable=g-bare-generic
inputs: Optional[List[core_types.TensorLike]] = None,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
name: Any = None) -> List[core_types.TensorLike]:
"""Builds a training loop that executes a fixed number of iterations.
The set of loop-carried tensors corresponds to `inputs`.
@@ -188,11 +201,12 @@
Args:
n: the number of loop iterations
body: a Python function that builds the loop body.
inputs: a list of initial values passed into the training loop or
None (equivalent to an empty list).
infeed_queue: if not None, the infeed queue from which to append a tuple
of arguments as inputs to condition.
inputs: a list of initial values passed into the training loop or None
(equivalent to an empty list).
infeed_queue: if not None, the infeed queue from which to append a tuple of
arguments as inputs to condition.
name: (Deprecated) Does nothing.
Returns:
The final values of the loop-carried tensors.
Raises:
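A hedged sketch of the call shape for `repeat`; the body is a stand-in, and in practice the loop is built inside a TPU computation (e.g. under `tpu.rewrite`):
def body(total):
  return total + 1.0  # stand-in for one training step

def computation():
  # Threads the loop-carried value through ten iterations of `body`.
  return repeat(10, body, inputs=[0.0])

# Typically compiled for TPU, e.g. tpu.rewrite(computation) in a TF1 graph;
# shown here only to illustrate the call signature.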
View File
@@ -138,7 +138,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'depth\', \'name\', \'separator\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'depth\', \'name\', \'separator\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'_X_\'], "
}
member_method {
name: "adapt"
View File
@@ -138,7 +138,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'depth\', \'name\', \'separator\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'depth\', \'name\', \'separator\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'_X_\'], "
}
member_method {
name: "adapt"
View File
@@ -261,6 +261,7 @@ function install_macos_pip_deps {
${PIP_CMD} install $USER_FLAG 'grpcio ~= 1.34.0'
${PIP_CMD} install $USER_FLAG 'portpicker ~= 1.3.1'
${PIP_CMD} install $USER_FLAG 'scipy ~= 1.5.2'
${PIP_CMD} install $USER_FLAG --upgrade certifi
# LINT.ThenChange(:linux_pip_installations_orig)
# LINT.ThenChange(:linux_pip_installations)
View File
@@ -46,7 +46,7 @@ py_test(
tags = [
"no_oss_py2",
"no_pip",
"no_rocm",
"no_rocm", # No need to rerun this test for ROCm config.
"no_windows", # numpy prints differently on windows.
"noasan",
"nomsan",