diff --git a/.bazelrc b/.bazelrc index d06e0836184..a1f323c142d 100644 --- a/.bazelrc +++ b/.bazelrc @@ -365,13 +365,14 @@ build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010" build:rbe_linux_cuda_nvcc --config=rbe_linux -build:rbe_linux_cuda_nvcc --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain" -build:rbe_linux_cuda_nvcc --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain-linux-x86_64" -build:rbe_linux_cuda_nvcc --extra_execution_platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010,@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010" -build:rbe_linux_cuda_nvcc --host_platform="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010" -build:rbe_linux_cuda_nvcc --platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010" -build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/cuda10.1-cudnn7" -build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/tensorrt6.0" +build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain" +build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" +build:rbe_linux_cuda_nvcc --host_platform="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" +build:rbe_linux_cuda_nvcc --platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" +build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda" +build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt" +build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl" build:rbe_linux_cuda_nvcc --repo_env=TF_NEED_TENSORRT=1 build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_VERSION=10 build:rbe_linux_cuda_nvcc --repo_env=TF_CUDNN_VERSION=7 diff --git a/README.md b/README.md index 56baa0740c3..e95fea22c56 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ $ python 3 >>> hello = tf.constant('Hello, TensorFlow!') >>> hello.numpy() -'Hello, TensorFlow!' +b'Hello, TensorFlow!' ``` For more examples, see the diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 67a2dde6c27..29dba253fee 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -418,21 +418,11 @@ void TensorHandleSilentCopy(bool async, auto op = tensorflow::down_cast<tensorflow::OperationInterface*>( matmul->operation.get()); - if (!async) { - // The input handles should never change since they have been mirrored. 
- ASSERT_EQ(op->GetInput(0), arg0); - ASSERT_EQ(op->GetInput(1), arg1); - } else { - if (cpu_op) { - ASSERT_EQ(op->GetInput(0), arg0); - // The GPU handle should be replaced with a CPU copy - ASSERT_NE(op->GetInput(1), arg1); - } else { - // The CPU handle should be replaced with a GPU copy - ASSERT_NE(op->GetInput(0), arg0); - ASSERT_EQ(op->GetInput(1), arg1); - } - } + + // The input handles should never change since they have been mirrored. + EXPECT_EQ(op->GetInput(0), arg0); + EXPECT_EQ(op->GetInput(1), arg1); + TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(retvals[0]); TFE_DeleteTensorHandle(hgpu); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index c283328403b..acbd2d27a45 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -14,6 +14,10 @@ package_group( includes = [ "//tensorflow/compiler/tf2xla:internal", ], + packages = [ + "//tensorflow/compiler/tests/...", + "//tensorflow/python/...", + ], ) package_group( diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 6023507cac7..6753ab9e728 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -46,7 +46,7 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project #include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Diagnostics.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc index 084bd26d28e..ac20ab68eaa 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc @@ -42,7 +42,7 @@ limitations under the License. #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/ToolOutputFile.h" #include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project #include "mlir/IR/Location.h" // TF:llvm-project @@ -165,7 +165,7 @@ constexpr size_t kInitialBufferSize = 10240; // `isSigned` is set to false for other types. static StatusOr<tflite::TensorType> GetTFLiteType(Type type, bool is_signed = true) { - if (!is_signed && type.isInteger(8)) { + if (!is_signed && type.isSignlessInteger(8)) { return tflite::TensorType_UINT8; } if (!is_signed) { diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index e73f6b732eb..83e372e5732 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -26,7 +26,7 @@ limitations under the License. 
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Matchers.h" // TF:llvm-project @@ -275,7 +275,7 @@ Attribute ConstFoldBinaryOp( return ConstFoldBinaryOp<FloatAttr>(result_type, operands[0], operands[1], float_calculate, is_commutative); - if (elemType.isa<IntegerType>()) + if (elemType.isSignlessInteger()) return ConstFoldBinaryOp<IntegerAttr>(result_type, operands[0], operands[1], int_calculate, is_commutative); @@ -1560,7 +1560,7 @@ OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) { limit_tensor.getType().getRank() == 0 && delta_tensor.getType().getRank() == 0); Type elem_type = getType().cast<ShapedType>().getElementType(); - if (elem_type.isa<IntegerType>()) { + if (elem_type.isSignlessInteger()) { auto start_attr = start_tensor.getValue<IntegerAttr>({}); auto limit_attr = limit_tensor.getValue<IntegerAttr>({}); auto delta_attr = delta_tensor.getValue<IntegerAttr>({}); @@ -1662,7 +1662,7 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) { // Do not try to fold elements attr of a quant type because // DenseElementsAttr does not support it. - if (!getType().cast<ShapedType>().getElementType().isIntOrFloat()) + if (!getType().cast<ShapedType>().getElementType().isSignlessIntOrFloat()) return nullptr; assert(perm_tensor.getType().getRank() == 1); diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 36a1e93dc26..c4d5c6ce98f 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -3100,7 +3100,9 @@ def LstmMandatoryInputsConstraint : PredOpTrait< "mandatory operands element types should match", // TODO(ashwinm): Replace the indices with input tensor names when that // support is available. - TCopVTEtAreSameAt<[0, 2, 3, 4, 6, 7, 8, 13, 14, 15, 18, 19]>>; + Or<[ + TCopVTEtAreSameAt<[0, 2, 3, 4, 6, 7, 8, 13, 14, 15, 18, 19]>, + Neg<TypeIsPred<"input", F32>>]>>; def LstmOptionalPeepholeWeightConstraint : PredOpTrait< "the optional peephole weights should all be specified or none", @@ -3126,7 +3128,7 @@ def LstmProjectionWeightBiasConstraint : PredOpTrait< def LstmResultConstraint : PredOpTrait< "the input and result tensor elemental types must be same", - TCresVTEtIsSameAsOp<0, 0>>; + TFL_TCresVTEtIsSameAsOp<0, 0>>; // This is the basic kernel type LSTM op. // TODO(b/142417845): Refactor this part to return its tflite node name as @@ -3195,47 +3197,47 @@ Ba et al. 
“Layer Normalization” }]; let arguments = ( - ins TFL_TensorOf<[F32]>:$input, + ins TFL_TensorOf<[F32, QI8]>:$input, // Weights - TFL_TensorOfOrNone<[F32, I8]>:$input_to_input_weights, - TFL_TensorOf<[F32, I8]>:$input_to_forget_weights, - TFL_TensorOf<[F32, I8]>:$input_to_cell_weights, - TFL_TensorOf<[F32, I8]>:$input_to_output_weights, + TFL_TensorOfOrNone<[F32, I8, QI8]>:$input_to_input_weights, + TFL_TensorOf<[F32, I8, QI8]>:$input_to_forget_weights, + TFL_TensorOf<[F32, I8, QI8]>:$input_to_cell_weights, + TFL_TensorOf<[F32, I8, QI8]>:$input_to_output_weights, // Recurrent weights - TFL_TensorOfOrNone<[F32, I8]>:$recurrent_to_input_weights, - TFL_TensorOf<[F32, I8]>:$recurrent_to_forget_weights, - TFL_TensorOf<[F32, I8]>:$recurrent_to_cell_weights, - TFL_TensorOf<[F32, I8]>:$recurrent_to_output_weights, + TFL_TensorOfOrNone<[F32, I8, QI8]>:$recurrent_to_input_weights, + TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_forget_weights, + TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_cell_weights, + TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_output_weights, // Cell weights - TFL_TensorOfOrNone<[F32, I8]>:$cell_to_input_weights, + TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_input_weights, // Optional input - TFL_TensorOfOrNone<[F32, I8]>:$cell_to_forget_weights, + TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_forget_weights, // Optional input - TFL_TensorOfOrNone<[F32, I8]>:$cell_to_output_weights, + TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_output_weights, // Bias - TFL_TensorOfOrNone<[F32]>:$input_gate_bias, - TFL_TensorOf<[F32]>:$forget_gate_bias, - TFL_TensorOf<[F32]>:$cell_bias, - TFL_TensorOf<[F32]>:$output_gate_bias, + TFL_TensorOfOrNone<[F32, QI32]>:$input_gate_bias, + TFL_TensorOf<[F32, QI32]>:$forget_gate_bias, + TFL_TensorOf<[F32, QI32]>:$cell_bias, + TFL_TensorOf<[F32, QI32]>:$output_gate_bias, // Projection weight and bias - TFL_TensorOfOrNone<[F32, I8]>:$projection_weights, + TFL_TensorOfOrNone<[F32, I8, QI8]>:$projection_weights, // Optional input - TFL_TensorOfOrNone<[F32]>:$projection_bias, + TFL_TensorOfOrNone<[F32, QI32]>:$projection_bias, // Stateful activation and cell states. TFL_StatefulTensor:$input_activation_state, TFL_StatefulTensor:$input_cell_state, // Layer norm coefficients - TFL_TensorOfOrNone<[F32]>:$input_layer_norm_coefficients, - TFL_TensorOfOrNone<[F32]>:$forget_layer_norm_coefficients, - TFL_TensorOfOrNone<[F32]>:$cell_layer_norm_coefficients, - TFL_TensorOfOrNone<[F32]>:$output_layer_norm_coefficients, + TFL_TensorOfOrNone<[F32, QI16]>:$input_layer_norm_coefficients, + TFL_TensorOfOrNone<[F32, QI16]>:$forget_layer_norm_coefficients, + TFL_TensorOfOrNone<[F32, QI16]>:$cell_layer_norm_coefficients, + TFL_TensorOfOrNone<[F32, QI16]>:$output_layer_norm_coefficients, // Attributes TFL_AFAttr:$fused_activation_function, diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc index 45e87e63475..617f968b958 100644 --- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc +++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project #include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/AffineExpr.h" // TF:llvm-project #include "mlir/IR/AffineMap.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index b2355b2ae6e..5f52c892421 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -26,7 +26,7 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project #include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index ed998510328..9bb1d677df2 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -26,7 +26,7 @@ limitations under the License. #include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project #include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project #include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project @@ -191,7 +191,7 @@ struct QuantizationPattern : public RewritePattern { auto ele_type = operand.getType().cast<TensorType>().getElementType(); if (auto op_inst = dyn_cast_or_null<DQ>(operand.getDefiningOp())) { inputs.push_back(op_inst.input()); - } else if (ele_type.isa<IntegerType>()) { + } else if (ele_type.isSignlessInteger()) { // If the operand is an integer tensor, then it doesn't require the // DQ op in the pattern. inputs.push_back(operand); @@ -225,7 +225,7 @@ struct QuantizationPattern : public RewritePattern { auto user = llvm::cast<Q>(*result.user_begin()); outputs_replaced.insert({user.output(), enumerated_result.index()}); output_types.push_back(user.getType()); - } else if (result_ele_type.template isa<IntegerType>()) { + } else if (result_ele_type.isSignlessInteger()) { // If the result is an integer tensor, then it doesn't require the // D op in the pattern. outputs_replaced.insert({result, enumerated_result.index()}); diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc b/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc index 7c2846231c9..0c746d0c943 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc +++ b/tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc @@ -26,7 +26,7 @@ limitations under the License. 
#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/PatternMatch.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir b/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir index a6d6ec52234..5fe5fbfb3ee 100644 --- a/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir +++ b/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir @@ -27,6 +27,20 @@ func @testDilatedConvWithNonZeroSTBPadding(%arg0: tensor<1x128x128x3xf32>, %arg1 // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32> } +func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> { + %cst = constant dense<[2, 2]> : tensor<2xi32> + %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> + %1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", dilations = [1, 2, 2, 1], strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> + %2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + return %2 : tensor<1x128x128x8xf32> + + // CHECK-LABEL: testDilatedConvWithNonTrivialDilations + // CHECK: [[STB:%.*]] = "tf.SpaceToBatchND" + // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D" + // CHECK-NEXT: [[RESULT:%.*]] = "tf.BatchToSpaceND" + // CHECK-NEXT: return [[RESULT]] +} + func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> @@ -104,7 +118,7 @@ func @testDilatedDepthWiseConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<3> : tensor<i32> + %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32> %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> @@ -115,7 +129,7 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten // CHECK-LABEL: testDilatedConvWithExpandSqueeze1 // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>) - // CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32> + // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32> // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[CONV:%.*]] = 
"tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32> @@ -125,7 +139,7 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<3> : tensor<i32> + %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32> %2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> @@ -136,7 +150,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, % // CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze1 // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>) - // CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32> + // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32> // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32> @@ -146,7 +160,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, % func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<3> : tensor<i32> + %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32> %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32> @@ -157,7 +171,7 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten // CHECK-LABEL: testDilatedConvWithExpandSqueeze2 // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>) - // CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32> + // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32> // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[CONV:%.*]] = 
"tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32> @@ -167,7 +181,7 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<3> : tensor<i32> + %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32> %2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32> @@ -178,7 +192,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, % // CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze2 // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>) - // CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32> + // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32> // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32> @@ -188,7 +202,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, % func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<3> : tensor<i32> + %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32> %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> @@ -200,7 +214,7 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten // CHECK-LABEL: testDilatedConvWithExpandSqueeze3 // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>) - // CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32> + // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32> // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[CONV:%.*]] = 
"tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32> @@ -210,7 +224,7 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<3> : tensor<i32> + %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32> %2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> @@ -222,10 +236,29 @@ func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, % // CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze3 // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>) - // CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32> + // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32> // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32> } + +func @testDilatedConvWithDifferentExpandSqueezeAxis(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128x1xf32> { + %cst = constant dense<[2, 2]> : tensor<2xi32> + %cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32> + %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> + %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32> + %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> + %3 = "tf.Squeeze"(%2) {squeeze_dims = [2]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64x1xf32> + %4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64x1xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x1xf32> + return %4 : tensor<1x128x128x1xf32> + + // CHECK-LABEL: testDilatedConvWithDifferentExpandSqueezeAxis + // CHECK: [[STB:%.*]] = "tf.SpaceToBatchND" + // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims" + // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D" + // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze" + // CHECK-NEXT: [[RESULT:%.*]] = 
"tf.BatchToSpaceND" + // CHECK-NEXT: return [[RESULT]] +} diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index da58b3704d0..57e2340dd37 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -593,6 +593,21 @@ func @testUnidirectionalSequenceLstmWithInvalidNoneType(%arg0: tensor<? x f32>, return %0 : tensor<?xf32> } +// ----- + +// CHECK-LABEL: testLstmQuantizedType +func @testLstmQuantizedType(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg9: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg10: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg11: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg12: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform<i32:f32, 9.9999999747524271E-7>>, %arg15: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg16: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg17: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg18: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg19: tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, %arg20: tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> { + %cst = constant unit + %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ( { + }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 9.9999999747524271E-7>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, 
tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> + return %0 : tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> + // CHECK: %[[RES0:.*]] = constant unit + // CHECK: %[[RES1:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %[[RES0]], %[[RES0]], %[[RES0]], %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ( { + // CHECK-NEXT: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 9.9999999747524271E-7>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> + // CHECK: return %[[RES1]] +} + + // ----- // CHECK-LABEL: testLstm diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir index 448a4f9eb5f..0dcce6bf4e8 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir @@ -154,7 +154,7 @@ func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3x4xf3 // ----- module { -func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} { +func @inference_standard_lstm_time_major(%arg0: 
tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32> %1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32> %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32> @@ -165,7 +165,7 @@ func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32> } -// CHECK: func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} { +// CHECK: func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { // CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64> // CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32> // CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64> @@ -189,7 +189,7 @@ func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor // ----- module { -func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = 
"lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} { +func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} { %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32> %1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32> %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<8x8x40xf32>, tensor<10x40xf32>) -> tensor<8x8x10xf32> @@ -200,7 +200,7 @@ func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: te return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32> } -// CHECK: func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} { +// CHECK: func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} { // CHECK: [[VAL_6:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64> // CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<3xi64>) -> tensor<8x8x8xf32> // CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64> @@ -224,3 +224,84 @@ func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: te // CHECK: return [[VAL_24]] : tensor<8x8x10xf32> // CHECK: } } + +// ----- + +module { +func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes 
{tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} { + %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32> + %1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32> + %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32> + %3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32> + %4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x?x10xf32> + %5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32> + %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32> + return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32> +} + +// CHECK: func @inference_standard_lstm_time_major_go_backwards([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} { +// CHECK: [[VAL_6:%.*]] = constant dense<0> : tensor<1xi32> +// CHECK: [[VAL_7:%.*]] = "tf.ReverseV2"([[VAL_0]], [[VAL_6]]) : (tensor<?x8x8xf32>, tensor<1xi32>) -> tensor<?x8x8xf32> +// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64> +// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_8]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32> +// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64> +// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_10]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32> +// CHECK: [[VAL_12:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> +// CHECK: [[VAL_14:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_12]], [[VAL_13]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>) +// CHECK: [[VAL_15:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> +// CHECK: [[VAL_17:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_15]], [[VAL_16]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) +// CHECK: [[VAL_18:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> 
tensor<i32> +// CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) +// CHECK: [[VAL_21:%.*]] = constant unit +// CHECK: [[VAL_22:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) ( { +// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x?x10xf32> +// CHECK: return [[VAL_23:%.*]] : tensor<8x?x10xf32> +// CHECK: } + +} + +// ----- + +module { +func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} { + %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32> + %1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32> + %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<8x8x40xf32>, tensor<10x40xf32>) -> tensor<8x8x10xf32> + %3 = "tf.Add"(%2, %arg1) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32> + %4 = "tf.Add"(%2, %arg2) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32> + %5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32> + %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32> + return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32> +} + +// CHECK: func @inference_standard_lstm_non_time_major_go_backwards([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} { +// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0, 2]> 
: tensor<3xi64> +// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<3xi64>) -> tensor<8x8x8xf32> +// CHECK: [[VAL_8:%.*]] = constant dense<0> : tensor<1xi32> +// CHECK: [[VAL_9:%.*]] = "tf.ReverseV2"([[VAL_7]], [[VAL_8]]) : (tensor<8x8x8xf32>, tensor<1xi32>) -> tensor<8x8x8xf32> +// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64> +// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_10]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32> +// CHECK: [[VAL_12:%.*]] = constant dense<[1, 0]> : tensor<2xi64> +// CHECK: [[VAL_13:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_12]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32> +// CHECK: [[VAL_14:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK: [[VAL_15:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> +// CHECK: [[VAL_16:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_14]], [[VAL_15]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>) +// CHECK: [[VAL_17:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK: [[VAL_18:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> +// CHECK: [[VAL_19:%.*]]:4 = "tf.SplitV"([[VAL_13]], [[VAL_17]], [[VAL_18]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) +// CHECK: [[VAL_20:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK: [[VAL_21:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> +// CHECK: [[VAL_22:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_20]], [[VAL_21]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) +// CHECK: [[VAL_23:%.*]] = constant unit +// CHECK: [[VAL_24:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_16]]#0, [[VAL_16]]#1, [[VAL_16]]#2, [[VAL_16]]#3, [[VAL_19]]#0, [[VAL_19]]#1, [[VAL_19]]#2, [[VAL_19]]#3, [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_22]]#0, [[VAL_22]]#1, [[VAL_22]]#2, [[VAL_22]]#3, [[VAL_23]], [[VAL_23]], [[VAL_1]], [[VAL_2]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) ( { +// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> +// CHECK: [[VAL_25:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64> +// CHECK: [[VAL_26:%.*]] = "tf.Transpose"([[VAL_27:%.*]], [[VAL_25]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32> +// CHECK: return [[VAL_26]] : tensor<8x8x10xf32> +// CHECK: } + +} + diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index 0472bd6abcf..30fe391762f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h index c3d3df14e0b..65bed845bae 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h +++ b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // TF:llvm-project #include "mlir/IR/TypeUtilities.h" // TF:llvm-project #include "mlir/Pass/Pass.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace mlir { @@ -80,6 +81,17 @@ class ConvertTFDilatedConvOp : public OpRewritePattern<Conv2dOpTy> { template <typename Conv2dOpTy> PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite( Conv2dOpTy op, PatternRewriter& rewriter) const { + // Make sure Conv2D has 'VALID' padding. + if (op.template getAttrOfType<StringAttr>("padding").getValue() != "VALID") { + return Pattern::matchFailure(); + } + // Make sure dilations are all ones if set. + const ArrayAttr& dilations = + op.template getAttrOfType<ArrayAttr>("dilations"); + if (dilations && !TFIntListIsAllOnes(dilations)) { + return Pattern::matchFailure(); + } + // Check if the ConvOp is preceded by a `Expand` op and succeeded by a // `Squeeze` op. Operation* prev_op = op.getOperation()->getPrevNode(); @@ -90,6 +102,7 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite( TF::ExpandDimsOp expand_op; TF::SqueezeOp squeeze_op; + int64_t expand_axis; // Expand + Squeeze op. if (llvm::isa<TF::ExpandDimsOp>(prev_op)) { if (!llvm::isa<TF::SqueezeOp>(next_op)) { @@ -99,6 +112,22 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite( expand_op = llvm::cast<TF::ExpandDimsOp>(prev_op); squeeze_op = llvm::cast<TF::SqueezeOp>(next_op); + // Make sure that the axis in `expand_op` is constant. + if (auto const_op = + llvm::dyn_cast<TF::ConstOp>(expand_op.dim().getDefiningOp())) { + expand_axis = + (*const_op.value().cast<DenseElementsAttr>().getIntValues().begin()) + .getSExtValue(); + } else { + return Pattern::matchFailure(); + } + // Make sure that the `squeeze_dims` is equal to `expand_axis`. + auto squeeze_dims = squeeze_op.squeeze_dims(); + if (squeeze_dims.size() != 1 || + squeeze_dims[0].cast<IntegerAttr>().getInt() != expand_axis) { + return Pattern::matchFailure(); + } + // Update previous/next op pointer. prev_op = prev_op->getPrevNode(); if (!prev_op) return Pattern::matchFailure(); @@ -108,10 +137,14 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite( // SpaceToBatchND op. if (!llvm::isa<TF::SpaceToBatchNDOp>(prev_op)) return Pattern::matchFailure(); + // TODO(b/149936532): Check `padding` input, currently ignored. TF::SpaceToBatchNDOp stb_op = llvm::cast<TF::SpaceToBatchNDOp>(prev_op); // Pad op. TF::PadOp pad_op; + // TODO(b/149936532): Currently we just ignore the PadOp. However note that + // in real scenarios this may not always be correct: user can put a PadOp here + // with non-trivial consequences. 
if (llvm::isa<TF::PadOp>(next_op)) { pad_op = llvm::cast<TF::PadOp>(next_op); next_op = next_op->getNextNode(); @@ -119,6 +152,7 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite( } // BatchToSpaceND + BiasAdd. + // TODO(b/149936532): Check the `crops` input, currently ignored. TF::BatchToSpaceNDOp bts_op; TF::BiasAddOp biasadd_op; bool final_op_is_bts = true; @@ -146,14 +180,10 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite( if (!dilations_attr.hasValue()) return Pattern::matchFailure(); op.setAttr("dilations", dilations_attr.getValue()); - // Here we need to set the correct padding for Conv op. In TF, the conv op - // inserted after 'SpaceToBatch' always has 'VALID' padding. This might - // become a problem here if the original Conv op has 'SAME' padding. When - // the original conv has 'SAME' padding, TF will set a non-zero padding for - // the 'SpaceToBatch' op, so we rely on this information to check if we need - // to change the padding from 'VALID' to 'SAME' (a.k.a when we see non-zero - // values in `stb_op.paddings`, we change the current Conv's padding to - // 'SAME'). + // Padding is set to 'SAME' when `stb_op` has non-zero paddings. + // TODO(b/149936532): This assumption only holds when the input width & height + // is multiple of dilation width & height. We should fix it in order to + // support other use cases. auto stb_paddings = stb_op.paddings(); ElementsAttr stb_paddings_attr; if (matchPattern(stb_paddings, m_Constant(&stb_paddings_attr))) { @@ -175,7 +205,8 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite( auto input_shape = stb_op.input().getType().cast<ShapedType>().getShape(); SmallVector<int64_t, 4> expand_shape(input_shape.begin(), input_shape.end()); - expand_shape.push_back(1); + expand_shape.insert(expand_shape.begin() + expand_axis, 1); + auto expand_result_type = RankedTensorType::get( expand_shape, getElementTypeOrSelf(stb_op.input())); expand_op.getResult().setType(expand_result_type); @@ -208,7 +239,7 @@ ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape( ElementsAttr stb_bs_attr, bts_bs_attr; if (!matchPattern(stb_block_shape, m_Constant(&stb_bs_attr)) || !matchPattern(bts_block_shape, m_Constant(&bts_bs_attr))) { - // Returns failure status if block shape is not a constant. + // Returns failure status if block_shape is not a constant. return {}; } // Check that the block_shape of `stb_op` and `bts_op` are equal. @@ -217,9 +248,8 @@ ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape( if (stb_bs_attr.getValue({i}) != bts_bs_attr.getValue({i})) return {}; } - // TODO(haoliang): support 1-D dilated conv. + // Set dilation factor. if (stb_bs_attr.getNumElements() < 2) return {}; - int dilation_h_factor = stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt(); int dilation_w_factor = diff --git a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc index e07cea8535e..3582046f13f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc +++ b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" #include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Block.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc index e31b143ab43..f3a15b7ebd3 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc @@ -15,7 +15,7 @@ limitations under the License. #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringMap.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Block.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index fd09d2e5c24..683905d06c7 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -16,7 +16,7 @@ limitations under the License. // TFLite legalization patterns include "mlir/IR/OpBase.td" -include "mlir/Dialect/StandardOps/Ops.td" +include "mlir/Dialect/StandardOps/IR/Ops.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" @@ -341,7 +341,7 @@ def : Pat<(TF_MatrixDiagOp $diagonal), (TFL_MatrixDiagOp $diagonal)>; class I32VectorElementsAttr<int len> : ElementsAttrBase< CPred<"$_self.isa<DenseIntElementsAttr>() &&" "$_self.cast<DenseIntElementsAttr>().getType()." - "getElementType().isInteger(32)">, + "getElementType().isSignlessInteger(32)">, "32-bit int elements attribute of shape [" # len # "]"> { let storageType = [{ DenseIntElementsAttr }]; diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 77ea344b3a4..cf24ed7e0f4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -255,7 +255,7 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite( ShapedType shape_type = shape.getType().cast<ShapedType>(); // The tfl reshape's #2 operand needs to i32 tensor type, so we have to cast. - if (!shape_type.getElementType().isInteger(32)) { + if (!shape_type.getElementType().isSignlessInteger(32)) { auto new_shape = shape_type.getShape(); IntegerType new_ele_type = rewriter.getIntegerType(32); ShapedType new_type = RankedTensorType::get(new_shape, new_ele_type); diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc index 7d1dbbb9fcc..ea44a34eb2b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc @@ -15,7 +15,7 @@ limitations under the License. // Converts TF While to TFL While with single call in body and cond. 
-#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Operation.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc b/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc index 3349261af02..4fde08bc1cf 100644 --- a/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc +++ b/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc @@ -20,7 +20,7 @@ limitations under the License. #include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/Pass/Pass.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 1b240e2e674..00159644185 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -32,7 +32,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Block.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project @@ -335,8 +335,9 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> { ConversionPatternRewriter &rewriter) const override { Type dtype = op.element_dtype(); if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() || - dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) || - dtype.isInteger(32) || dtype.isInteger(64))) { + dtype.isInteger(1) || dtype.isSignlessInteger(8) || + dtype.isSignlessInteger(16) || dtype.isSignlessInteger(32) || + dtype.isSignlessInteger(64))) { op.emitError( "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit " "integer or 16-bit/32-bit/64-bit float type during TF Lite " diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 112e7f788ce..dbc12a85b67 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -31,7 +31,7 @@ limitations under the License. 
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Matchers.h" // TF:llvm-project #include "mlir/IR/PatternMatch.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 71017fe2801..0ad5be055dc 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -16,7 +16,7 @@ limitations under the License. // This is the optimization pattern definition file for TensorFlow Lite. include "mlir/IR/OpBase.td" -include "mlir/Dialect/StandardOps/Ops.td" +include "mlir/Dialect/StandardOps/IR/Ops.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/post_quantize_patterns.td index 283b29ea005..ecceba5316e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/post_quantize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize_patterns.td @@ -16,7 +16,7 @@ limitations under the License. // This is the quantization pattern definition file for TensorFlow Lite. include "mlir/IR/OpBase.td" -include "mlir/Dialect/StandardOps/Ops.td" +include "mlir/Dialect/StandardOps/IR/Ops.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" // Both Quantize and Dequantize ops have side effects, so we have to define diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index 7181877085d..98f9c73f791 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -23,7 +23,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td index 07dd8ab4455..22bcc563f7b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td @@ -16,7 +16,7 @@ limitations under the License. // This is the quantization pattern definition file for TensorFlow Lite. include "mlir/IR/OpBase.td" -include "mlir/Dialect/StandardOps/Ops.td" +include "mlir/Dialect/StandardOps/IR/Ops.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" // Quantize attribute $0 by using quantization parameter from %1. 
diff --git a/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc b/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc index 17125bffd85..c8aa67084ce 100644 --- a/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc +++ b/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc @@ -18,7 +18,7 @@ limitations under the License. #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/StringMap.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Block.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/tensorlist_patterns.td b/tensorflow/compiler/mlir/lite/transforms/tensorlist_patterns.td index ff024ad0463..b0435b7cf4c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/tensorlist_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/tensorlist_patterns.td @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ include "mlir/IR/OpBase.td" -include "mlir/Dialect/StandardOps/Ops.td" +include "mlir/Dialect/StandardOps/IR/Ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc index 5a7397ed9c9..13afa1bf9b8 100644 --- a/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc @@ -20,7 +20,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/CommandLine.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Identifier.h" // TF:llvm-project #include "mlir/IR/Location.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc index 5589fc04c2d..8ed5b0e0341 100644 --- a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc +++ b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/CommandLine.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Identifier.h" // TF:llvm-project #include "mlir/IR/Location.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc b/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc index a9cc483df76..3d4bbdfa13c 100644 --- a/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc @@ -38,7 +38,7 @@ FloatAttr GetSingleElementAsFloatOrSelf(Attribute attr) { IntegerAttr ExtractSingleElementAsInteger(ElementsAttr attr) { if (attr.getType().getNumElements() != 1 || - !attr.getType().getElementType().isa<IntegerType>()) { + !attr.getType().getElementType().isSignlessInteger()) { return {}; } SmallVector<uint64_t, 8> index(attr.getType().getRank(), 0); diff --git a/tensorflow/compiler/mlir/lite/utils/attribute_utils.h b/tensorflow/compiler/mlir/lite/utils/attribute_utils.h index 5a11690d15f..7c0ff910db1 100644 --- a/tensorflow/compiler/mlir/lite/utils/attribute_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/attribute_utils.h @@ -19,7 +19,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_ -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc index 6d8bfab0e6c..400d504b8f1 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc @@ -21,7 +21,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project @@ -95,6 +95,14 @@ Value Transpose2D(OpBuilder* builder, Value value_to_transpose, return Transpose(builder, value_to_transpose, perm, type, location); } +Value Reverse(OpBuilder* builder, Value value_to_reverse, int axis, + RankedTensorType type, mlir::Location location) { + auto axis_op = CreateI32SplatConst(builder, {1}, axis, location); + // The result type will be the same as the input. + return builder->create<TF::ReverseV2Op>(location, type, value_to_reverse, + axis_op); +} + ArrayRef<int64_t> GetRankedTensorShape(Value value) { return value.getType().cast<RankedTensorType>().getShape(); } @@ -615,6 +623,16 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) { final_input_type = final_inputs.getType().dyn_cast<RankedTensorType>(); } + // Handle go_backwards: + // LSTM in Keras semantic will reverse the input sequence if it's go_backwards + auto go_backwards_attr = func_op.getAttrOfType<BoolAttr>("tf.go_backwards"); + + if (go_backwards_attr != nullptr && go_backwards_attr.getValue()) { + // We assume input is already in {time, batch, size} layout. 
+ final_inputs = + Reverse(builder, final_inputs, 0, final_input_type, func_op.getLoc()); + } + int batch = final_input_type.getDimSize(1); int time = final_input_type.getDimSize(0); diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc index b229206a4e4..0593bd150c7 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.cc b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.cc index a12cad15256..4067cfb04b9 100644 --- a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.cc @@ -17,7 +17,7 @@ limitations under the License. #include <vector> -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h index 917ae93f6a8..635922d5cbb 100644 --- a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_ -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/utils/validators.h b/tensorflow/compiler/mlir/lite/utils/validators.h index e1ae4392881..fa1304c68e0 100644 --- a/tensorflow/compiler/mlir/lite/utils/validators.h +++ b/tensorflow/compiler/mlir/lite/utils/validators.h @@ -19,7 +19,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_ -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/StandardTypes.h" // TF:llvm-project namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 58e6cb005c1..466f882133f 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -90,7 +90,7 @@ gentbl( td_file = "ir/tf_saved_model_ops.td", td_srcs = [ "@llvm-project//mlir:include/mlir/IR/OpBase.td", - "@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td", + "@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td", ], ) @@ -114,7 +114,7 @@ gentbl( td_file = "ir/tf_executor_ops.td", td_srcs = [ "@llvm-project//mlir:include/mlir/IR/OpBase.td", - "@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td", + "@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td", ], ) @@ -138,7 +138,7 @@ gentbl( td_file = "ir/tf_device_ops.td", td_srcs = [ "@llvm-project//mlir:include/mlir/IR/OpBase.td", - "@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td", + "@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td", ], ) @@ -281,6 +281,7 @@ cc_library( "transforms/generated_canonicalize.inc", "transforms/generated_optimize.inc", "transforms/graph_pruning.cc", + "transforms/launch_to_device_attribute.cc", "transforms/layout_optimization.cc", "transforms/mark_function_visibility.cc", "transforms/materialize_mlir_passthrough_op.cc", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index c6144ec21e3..85d87a56f01 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -26,7 +26,7 @@ limitations under the License. #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/Dialect/Traits.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 25ba5cc2cb1..77e098c37e5 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -3111,6 +3111,70 @@ cublas. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [NoSideEffect, AllTypesMatch<["input", "band"]>]> { + let summary = [{ +Copy a tensor setting everything outside a central band in each innermost matrix +to zero. + }]; + + let description = [{ +The `band` part is computed as follows: +Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a +tensor with the same shape where + +`band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`. + +The indicator function + +`in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && + (num_upper < 0 || (n-m) <= num_upper)`. 
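Read as a predicate, the same condition can be written directly in C++ (a sketch; the helper name is illustrative):

    #include <cstdint>

    // Returns true when element (m, n) of an innermost matrix lies inside the
    // band delimited by num_lower/num_upper (negative means "keep everything").
    bool InBand(int64_t m, int64_t n, int64_t num_lower, int64_t num_upper) {
      return (num_lower < 0 || (m - n) <= num_lower) &&
             (num_upper < 0 || (n - m) <= num_upper);
    }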
+ +For example: + +``` +# if 'input' is [[ 0, 1, 2, 3] + [-1, 0, 1, 2] + [-2, -1, 0, 1] + [-3, -2, -1, 0]], + +tf.matrix_band_part(input, 1, -1) ==> [[ 0, 1, 2, 3] + [-1, 0, 1, 2] + [ 0, -1, 0, 1] + [ 0, 0, -1, 0]], + +tf.matrix_band_part(input, 2, 1) ==> [[ 0, 1, 0, 0] + [-1, 0, 1, 0] + [-2, -1, 0, 1] + [ 0, -2, -1, 0]] +``` + +Useful special cases: + +``` + tf.matrix_band_part(input, 0, -1) ==> Upper triangular part. + tf.matrix_band_part(input, -1, 0) ==> Lower triangular part. + tf.matrix_band_part(input, 0, 0) ==> Diagonal. +``` + }]; + + let arguments = (ins + TF_Tensor:$input, + TF_I32OrI64Tensor:$num_lower, + TF_I32OrI64Tensor:$num_upper + ); + + let results = (outs + TF_Tensor:$band + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tindex = TF_DerivedOperandTypeAttr<1>; + + let verifier = [{ + return Verify(*this); + }]; +} + def TF_MatrixDiagOp : TF_Op<"MatrixDiag", [NoSideEffect]> { let summary = [{ Returns a batched diagonal tensor with a given batched diagonal values. diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index f3fdab674e4..92e6d522125 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -85,7 +85,7 @@ class TF_TensorFlowType <string name, string description> : // Any tensor element type allowed in TensorFlow ops def TF_ElementType : Type<Or<[AnyFloat.predicate, - AnyInteger.predicate, + AnySignlessInteger.predicate, AnyComplex.predicate, TF_TFDialectType.predicate]>, "tf.dtype">; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 86de5578553..8d4c284bcf8 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -35,7 +35,7 @@ limitations under the License. 
#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/Dialect/Traits.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project @@ -1506,6 +1506,29 @@ void LogicalNotOp::getCanonicalizationPatterns( LogicalNotOfLess, LogicalNotOfLessEqual>(context); } +//===----------------------------------------------------------------------===// +// MatrixBandPartOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(MatrixBandPartOp op) { + if (!HasRankAtLeast(op.input(), 2)) { + return op.emitOpError() + << "requires `input` to have rank of at least 2, but found " + << op.input().getType(); + } + if (!IsOfRankOrUnranked(op.num_lower(), 0)) { + return op.emitOpError() + << "requires `num_lower` to have 0 dimensions, but found " + << op.num_lower().getType(); + } + if (!IsOfRankOrUnranked(op.num_upper(), 0)) { + return op.emitOpError() + << "requires `num_upper` to have 0 dimensions, but found " + << op.num_upper().getType(); + } + return success(); +} + //===----------------------------------------------------------------------===// // MaxOp //===----------------------------------------------------------------------===// @@ -2104,7 +2127,8 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type, } Type element_type = result_ranked_type.getElementType(); - if (!element_type.isInteger(32) && !element_type.isInteger(64)) + if (!element_type.isSignlessInteger(32) && + !element_type.isSignlessInteger(64)) return op->emitOpError("requires int32 or int64 return type for result") << variadic_idx_str; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h index 2898338f8eb..4059aba209f 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h @@ -91,7 +91,7 @@ class TensorFlowType : public Type { // Returns true if the specified type is a valid TensorFlow element type. static inline bool IsValidTFElementType(Type type) { return type.isa<ComplexType>() || type.isa<FloatType>() || - type.isa<IntegerType>() || type.isa<TensorFlowType>(); + type.isSignlessInteger() || type.isa<TensorFlowType>(); } // Returns true if this is a valid TensorFlow tensor type. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/launch_to_device_attribute.mlir b/tensorflow/compiler/mlir/tensorflow/tests/launch_to_device_attribute.mlir new file mode 100644 index 00000000000..6e57c3aa6ca --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/launch_to_device_attribute.mlir @@ -0,0 +1,123 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-launch-to-device-attribute | FileCheck %s --dump-input=fail + + +// Tests single TensorFlow op is hoisted out and has the correct device assigned +// by parent `tf_device.launch`. 
+// CHECK-LABEL: func @single_op_launch +func @single_op_launch() { + tf_executor.graph { + %0:5 = tf_executor.island { + %a = "tf.opA"() : () -> tensor<i1> + %launch:2 = "tf_device.launch"() ( { + %b:2 = "tf.opB"(%a) : (tensor<i1>) -> (tensor<i32>, tensor<f32>) + tf_device.return %b#1, %b#0 : tensor<f32>, tensor<i32> + }) {device = "CPU:0"} : () -> (tensor<f32>, tensor<i32>) + %c = "tf.opC"() : () -> tensor<i1> + tf_executor.yield %a, %launch#0, %launch#1, %c : tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1> + } + tf_executor.fetch + } + return +} + +// CHECK: %[[A:.*]] = "tf.opA" +// CHECK: %[[B:.*]]:2 = "tf.opB"(%[[A]]) +// CHECK-SAME: device = "CPU:0" +// CHECK: %[[C:.*]] = "tf.opC" +// CHECK-NOT: "tf_device.launch" +// CHECK: tf_executor.yield %[[A]], %[[B]]#1, %[[B]]#0, %[[C]] + + +// Tests multiple TensorFlow ops are hoisted out and all have the correct device +// assigned by parent `tf_device.launch`. +// CHECK-LABEL: func @multi_op_launch +func @multi_op_launch() { + tf_executor.graph { + %0:5 = tf_executor.island { + %a = "tf.opA"() : () -> tensor<i1> + %launch:2 = "tf_device.launch"() ( { + %b = "tf.opB"(%a) : (tensor<i1>) -> tensor<i32> + %c = "tf.opC"(%b) : (tensor<i32>) -> tensor<f32> + tf_device.return %c, %b : tensor<f32>, tensor<i32> + }) {device = "CPU:0"} : () -> (tensor<f32>, tensor<i32>) + %d = "tf.opD"() : () -> tensor<i1> + tf_executor.yield %a, %launch#0, %launch#1, %d : tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1> + } + tf_executor.fetch + } + return +} + +// CHECK: %[[A:.*]] = "tf.opA" +// CHECK: %[[B:.*]] = "tf.opB"(%[[A]]) +// CHECK-SAME: device = "CPU:0" +// CHECK: %[[C:.*]] = "tf.opC"(%[[B]]) +// CHECK-SAME: device = "CPU:0" +// CHECK: %[[D:.*]] = "tf.opD" +// CHECK-NOT: "tf_device.launch" +// CHECK: tf_executor.yield %[[A]], %[[C]], %[[B]], %[[D]] + + +// ----- + + +// Tests ops are hoisted out and devices are set only if the `tf_device.launch` +// contains TensorFlow ops. +func @non_tf_dialect_op_launch() { + tf_executor.graph { + %0:5 = tf_executor.island { + %a = "tf.opA"() : () -> tensor<i1> + // expected-error@+1 {{'tf_device.launch' op must contain only 'tf' dialect ops}} + %launch:2 = "tf_device.launch"() ( { + %b = "tf.opB"(%a) : (tensor<i1>) -> tensor<i32> + %c = "unknown.opC"(%b) : (tensor<i32>) -> tensor<f32> + tf_device.return %c, %b : tensor<f32>, tensor<i32> + }) {device = "CPU:0"} : () -> (tensor<f32>, tensor<i32>) + %d = "tf.opD"() : () -> tensor<i1> + tf_executor.yield %a, %launch#0, %launch#1, %d : tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1> + } + tf_executor.fetch + } + return +} + + +// ----- + + +// Tests TensorFlow op with conflicting `device` attribute compared to parent +// `tf_device.launch`. +func @conflicting_device() { + tf_executor.graph { + %0 = tf_executor.island { + // expected-error@+1 {{'tf_device.launch' op inner 'tf' dialect op has conflicting 'device' attribute, got 'GPU:0' but expected 'CPU:0'}} + "tf_device.launch"() ( { + "tf.opA"() {device = "GPU:0"} : () -> () + tf_device.return + }) {device = "CPU:0"} : () -> () + tf_executor.yield + } + tf_executor.fetch + } + return +} + + +// ----- + + +// Tests TensorFlow op with bad `device` attribute already set. 
+func @bad_tf_device_attr() { + tf_executor.graph { + %0 = tf_executor.island { + // expected-error@+1 {{'tf_device.launch' op inner 'tf' dialect op has bad 'device' attribute}} + "tf_device.launch"() ( { + "tf.opA"() {device = 0 : i32} : () -> () + tf_device.return + }) {device = "CPU:0"} : () -> () + tf_executor.yield + } + tf_executor.fetch + } + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 9b29c5c1d92..319660ae4bb 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -854,6 +854,78 @@ func @testInvalidIfOp(tensor<i1>, tensor<*xf32>) -> tensor<2xf32> { // ----- +// Test valid tf.MatrixBandPart +// CHECK-LABEL: func @testValidMatrixBandPartOp +func @testValidMatrixBandPartOp(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> { + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16> + return %0 : tensor<64x64xbf16> +} + +// ----- + +// Test valid tf.MatrixBandPart +// CHECK-LABEL: func @testValidMatrixBandPartOp3D +func @testValidMatrixBandPartOp3D(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64x64xbf16> { + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64x64xbf16> + return %0 : tensor<64x64x64xbf16> +} + +// ----- + +// Test valid tf.MatrixBandPart +// CHECK-LABEL: func @testValidMatrixBandPartOpUnranked +func @testValidMatrixBandPartOpUnranked(%arg0: tensor<*xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> { + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<*xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16> + return %0 : tensor<*xbf16> +} + +// ----- + +// Test invalid tf.MatrixBandPart +func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> { + // expected-error @+1 {{op failed to verify that all of {input, band} have same type}} + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16> + return %0 : tensor<64x64xbf16> +} + +// ----- + +// Test invalid tf.MatrixBandPart +func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> { + // expected-error @+1 {{op failed to verify that all of {input, band} have same type}} + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16> + return %0 : tensor<*xbf16> +} + +// ----- + +// Test invalid tf.MatrixBandPart +func @testInvalidMatrixBandPartOp(%arg0: tensor<i64>, %arg1: tensor<64x64xi64>, %arg2: tensor<i64>) -> tensor<i64> { + // expected-error @+1 {{op requires `input` to have rank of at least 2, but found 'tensor<i64>'}} + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<i64>, tensor<64x64xi64>, tensor<i64>) -> tensor<i64> + return %0 : tensor<i64> +} + +// ----- + +// Test invalid tf.MatrixBandPart +func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64xi64>, %arg1: tensor<32xi64>, %arg2: tensor<i64>) -> tensor<64x64xi64> { + // expected-error @+1 {{op requires `num_lower` to have 0 dimensions, but found 'tensor<32xi64>'}} + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xi64>, tensor<32xi64>, tensor<i64>) -> tensor<64x64xi64> + return %0 : tensor<64x64xi64> +} + +// 
----- + +// Test invalid tf.MatrixBandPart +func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64xi64>, %arg1: tensor<i64>, %arg2: tensor<32xi64>) -> tensor<64x64xi64> { + // expected-error @+1 {{op requires `num_upper` to have 0 dimensions, but found 'tensor<32xi64>'}} + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xi64>, tensor<i64>, tensor<32xi64>) -> tensor<64x64xi64> + return %0 : tensor<64x64xi64> +} + +// ----- + //===--------------------------------------------------------------------===// // tf.{|Stateful}PartitionedCall //===--------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors_interprocedural.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors_interprocedural.mlir new file mode 100644 index 00000000000..afc95865236 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors_interprocedural.mlir @@ -0,0 +1,220 @@ +// RUN: tf-opt -tf-saved-model-optimize-global-tensors -split-input-file %s | FileCheck %s --dump-input=fail +//===----------------------------------------------------------------------===// +// Immutability. +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: module attributes {tf_saved_model.semantics} +module attributes {tf_saved_model.semantics} { + + // Test case: This test exercises marking a global tensor as immutable after it propagates + // via set of chained calls -> f -> f_callee -> f_callee_callee + + // CHECK: "tf_saved_model.global_tensor"() { + // CHECK-NOT: is_mutable + // CHECK-SAME: } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> () + + func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []}) + attributes {tf_saved_model.exported_names = ["f"]} { + %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>) + return %val : tensor<f32> + } + + func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { + %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>) + return %val : tensor<f32> + } + + func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { + %val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<f32> + return %val : tensor<f32> + } +} + +// ----- + +// CHECK-LABEL: module attributes {tf_saved_model.semantics} +module attributes {tf_saved_model.semantics} { + + // Test case: + // This test exercises trying to mark immutable when same func is called by multiple callers + // with different global tensors. 
+ // CHECK: "tf_saved_model.global_tensor"() { + // CHECK-NOT: is_mutable + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> () + + // CHECK: "tf_saved_model.global_tensor"() { + // CHECK-NOT: is_mutable + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v2", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> () + + func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []}) + attributes {tf_saved_model.exported_names = ["f"]} { + %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_common} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>) + return %val : tensor<f32> + } + + func @f2(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v2}) -> (tensor<f32> {tf_saved_model.index_path = []}) + attributes {tf_saved_model.exported_names = ["f2"]} { + %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_common} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>) + return %val : tensor<f32> + } + + func @f_common(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { + %val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<f32> + return %val : tensor<f32> + } +} + +// ----- + +// CHECK-LABEL: module attributes {tf_saved_model.semantics} +module attributes {tf_saved_model.semantics} { + + // Test case: This test exercises immutability without explicit use + // via ReadVariableOp + + // CHECK: "tf_saved_model.global_tensor"() { + // CHECK-NOT: is_mutable + // CHECK-SAME: } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> () + + func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []}) + attributes {tf_saved_model.exported_names = ["f"]} { + %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>) + %val_2 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>) + return %val_2 : tensor<f32> + } + + func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { + %cst_1 = constant dense<2.0> : tensor<f32> + return %cst_1 : tensor<f32> + } +} + +// ----- + +//===----------------------------------------------------------------------===// +// Test case: Test mutation detection propagates across function calls +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: module attributes {tf_saved_model.semantics} +module attributes {tf_saved_model.semantics} { + // CHECK: "tf_saved_model.global_tensor"() { + // CHECK-SAME: is_mutable + // CHECK-SAME: } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> () + + // CHECK: func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) + func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []}) + attributes {tf_saved_model.exported_names = ["f"]} { + %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.resource<tensor<f32>>>) -> 
(tensor<f32>) + return %val : tensor<f32> + } + + // CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> + func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { + %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>) + return %val : tensor<f32> + } + + // CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> + func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { + %c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32> + "tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> () + return %c0 : tensor<f32> + } +} + +// ----- + +// CHECK-LABEL: module attributes {tf_saved_model.semantics} +module attributes {tf_saved_model.semantics} { + + // Test case: The inter-procedural analysis with different types of + // TF call ops + + // CHECK: "tf_saved_model.global_tensor"() { + // CHECK-SAME: is_mutable + // CHECK-SAME: } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> () + + // CHECK: func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) + func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []}) + attributes {tf_saved_model.exported_names = ["f"]} { + %val = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>) + return %val : tensor<f32> + } + + // CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> + func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { + %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>) + return %val : tensor<f32> + } + + // CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> + func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { + %c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32> + "tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> () + return %c0 : tensor<f32> + } +} + +// ----- + +// CHECK-LABEL: module attributes {tf_saved_model.semantics} +module attributes {tf_saved_model.semantics} { + + // Test case: The inter-procedural analysis does not recurse infinitely + + // CHECK: "tf_saved_model.global_tensor"() { + // CHECK-NOT: is_mutable + // CHECK-SAME: } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> () + + func @exported_f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []}) + attributes {tf_saved_model.exported_names = ["exported_f"]} { + %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>) + return %val : tensor<f32> + } + + + // CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> + func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { + %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @g} : (tensor<*x!tf.resource>) -> (tensor<f32>) + return %val : tensor<f32> + } + + // CHECK: func @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32> + func 
@g(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { + %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f} : (tensor<*x!tf.resource>) -> (tensor<f32>) + return %val : tensor<f32> + } +} + +// ----- + +// CHECK-LABEL: module attributes {tf_saved_model.semantics} +module attributes {tf_saved_model.semantics} { + + // Test case: Inter-procedural analysis with resource usage in an + // unknown op, we assume mutating behavior and propagate that. + + // CHECK: "tf_saved_model.global_tensor"() { + // CHECK-SAME: is_mutable + // CHECK-SAME: } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> () + + func @exported_f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []}) + attributes {tf_saved_model.exported_names = ["exported_f"]} { + %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>) + return %val : tensor<f32> + } + + + // CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> + func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { + %c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32> + "tf.AssignAddVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> () + return %c0 : tensor<f32> + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc index f181924d0a6..0fef58ebb8a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc @@ -17,7 +17,7 @@ limitations under the License. // `tf_device.launch` with equivalent `tf_device.launch_func` operations. #include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Block.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td index a95a319d0a4..bac7b9ba01c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ include "mlir/IR/OpBase.td" -include "mlir/Dialect/StandardOps/Ops.td" +include "mlir/Dialect/StandardOps/IR/Ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" // Here, the element type can be any integer or float type. But, note that only diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc index 9660367cb68..ad844883453 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Twine.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Visitors.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc index 57ea1822b5b..01901d8b5a4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc @@ -16,7 +16,7 @@ limitations under the License. #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Twine.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/SymbolTable.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc index 44309a5e019..4d5ad5ad423 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc @@ -31,7 +31,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Block.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc index 31898ec1048..8cfa69c396e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc @@ -16,7 +16,7 @@ limitations under the License. // This transformation pass transforms functional control flow operations in the // standard TensorFlow dialect to MLIR Control Flow Graph (CFG) form. -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Operation.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc new file mode 100644 index 00000000000..9a196aef54b --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc @@ -0,0 +1,135 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This pass hoists a `tf_device.launch` body and assigns a `device` attribute +// to each TensorFlow dialect op in the body based on the `device` attribute on +// the `tf_device.launch`. If a TensorFlow dialect op already has a device +// attribute, that attribute will be overwritten with the `tf_device.launch` +// device. +// +// For example: +// %island:5 = tf_executor.island { +// %a = "tf.opA"() : () -> tensor<i1> +// %launch:2 = "tf_device.launch"() ( { +// %b = "tf.opB"() : () -> tensor<i32> +// %c = "tf.opC"() : () -> tensor<f32> +// tf_device.return %c, %b : tensor<f32>, tensor<i32> +// }) {device = "CPU:0"} : () -> (tensor<f32>, tensor<i32>) +// %d = "tf.opD"() : () -> tensor<i1> +// tf_executor.yield %a, %launch#0, %launch#1, %d : +// tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1> +// } +// +// Will be transformed into: +// %island:5 = tf_executor.island { +// %a = "tf.opA"() : () -> tensor<i1> +// %b = "tf.opB"() {device = "CPU:0"} : () -> tensor<i32> +// %c = "tf.opC"() {device = "CPU:0"} : () -> tensor<f32> +// %d = "tf.opD"() : () -> tensor<i1> +// tf_executor.yield %a, %c, %b, %d : +// tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1> +// } + +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Block.h" // TF:llvm-project +#include "mlir/IR/Dialect.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/Visitors.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" + +namespace mlir { +namespace TFDevice { +namespace { +constexpr char kDeviceAttr[] = "device"; + +struct LaunchToDeviceAttributePass + : public FunctionPass<LaunchToDeviceAttributePass> { + void runOnFunction() override; +}; + +LogicalResult HoistOpsAndAnnotateWithDevice(const Dialect* tf_dialect, + tf_device::LaunchOp launch) { + // Forward launch inner op results to launch op results. + launch.replaceAllUsesWith(launch.GetBody().getTerminator()->getOperands()); + + // For all inner ops of the TensorFlow dialect, assign the launch device as a + // `device` attribute. + auto body = launch.GetBody().without_terminator(); + for (Operation& op : body) { + if (op.getDialect() != tf_dialect) + return launch.emitOpError() << "must contain only 'tf' dialect ops"; + + auto device_attr = op.getAttr(kDeviceAttr); + if (!device_attr) { + op.setAttr(kDeviceAttr, launch.deviceAttr()); + continue; + } + + if (auto device_str_attr = device_attr.dyn_cast<StringAttr>()) { + if (launch.device() != device_str_attr.getValue()) + return launch.emitOpError() + << "inner 'tf' dialect op has conflicting 'device' attribute, " + "got '" + << device_str_attr.getValue() << "' but expected '" + << launch.device() << "'"; + } else { + return launch.emitOpError() + << "inner 'tf' dialect op has bad 'device' attribute"; + } + } + + // Move all inner ops of the launch to the block containing the launch. 
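  // A sketch of what the splice call below does, assuming standard
  // llvm::iplist semantics: getOperations() exposes the block's ops as an
  // intrusive list, and splice(pos, other, first, last) detaches the range
  // [first, last) from `other` and re-links it immediately before `pos`
  // without copying. `pos` here is the launch op's own iterator, so the
  // hoisted (and now device-annotated) ops end up directly above the launch,
  // which is erased afterwards.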
+ Operation* launch_op = launch.getOperation(); + launch_op->getBlock()->getOperations().splice( + launch_op->getIterator(), launch.GetBody().getOperations(), body.begin(), + body.end()); + + launch.erase(); + + return success(); +} + +void LaunchToDeviceAttributePass::runOnFunction() { + const Dialect* tf_dialect = getContext().getRegisteredDialect("tf"); + if (!tf_dialect) { + signalPassFailure(); + getFunction().emitError() << "'tf' dialect is not registered"; + } + + auto result = getFunction().walk([&](tf_device::LaunchOp launch) { + if (failed(HoistOpsAndAnnotateWithDevice(tf_dialect, launch))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + + if (result.wasInterrupted()) return signalPassFailure(); +} + +} // anonymous namespace + +std::unique_ptr<OpPassBase<FuncOp>> CreateLaunchToDeviceAttributePass() { + return std::make_unique<LaunchToDeviceAttributePass>(); +} + +static PassRegistration<LaunchToDeviceAttributePass> pass( + "tf-launch-to-device-attribute", + "Hoists and annotates device launch inner ops with associated device " + "attribute"); + +} // namespace TFDevice +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td index ec0ac5e3c1e..1074f9e1926 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ include "mlir/IR/OpBase.td" -include "mlir/Dialect/StandardOps/Ops.td" +include "mlir/Dialect/StandardOps/IR/Ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" // Here, the element type can be any integer or float type. But, note that only @@ -122,7 +122,7 @@ def LowerSparseSoftmaxCrossEntropyWithLogitsOp : Pattern< //===----------------------------------------------------------------------===// def ComplexTensor : TensorOf<[AnyComplex]>; -def RealTensor : TensorOf<[AnyInteger, AnyFloat]>; +def RealTensor : TensorOf<[AnySignlessInteger, AnyFloat]>; def : Pat<(TF_SquareOp $val), (TF_MulOp $val, $val)>; @@ -179,7 +179,7 @@ def LowerL2LossOp : // Pad op patterns. //===----------------------------------------------------------------------===// -def : Pat<(TF_PadOp TensorOf<[AnyInteger, AnyFloat]>:$input, $paddings), +def : Pat<(TF_PadOp TensorOf<[AnySignlessInteger, AnyFloat]>:$input, $paddings), (TF_PadV2Op $input, $paddings, (TF_ConstOp (GetScalarOfType<0> $input)))>; @@ -224,6 +224,6 @@ def CreateTFShapeOp : NativeCodeCall< // TODO(hinsu): Support inputs of TensorList types. def LowerZerosLikeOp : - Pat<(TF_ZerosLikeOp:$src_op TensorOf<[AnyInteger, AnyFloat]>:$input), + Pat<(TF_ZerosLikeOp:$src_op TensorOf<[AnySignlessInteger, AnyFloat]>:$input), (TF_BroadcastToOp (TF_ConstOp (GetScalarOfType<0> $input)), (CreateTFShapeOp $src_op, $input, /*use 32bit*/ConstBoolAttrFalse))>; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc index 36d7712eb2c..acaf7974280 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc @@ -14,7 +14,7 @@ limitations under the License. 
==============================================================================*/ #include <iostream> -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Operation.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td index 87467238e57..0fb62cb064d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ include "mlir/IR/OpBase.td" -include "mlir/Dialect/StandardOps/Ops.td" +include "mlir/Dialect/StandardOps/IR/Ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" def IsDataFormatNHWC : ConstantAttr<TF_ConvnetDataFormatAttr, "NHWC">; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc index 2a61b0cf0d1..1aaceb8ecc7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc @@ -15,16 +15,27 @@ limitations under the License. // This pass optimizes tf_saved_model.global_tensor ops. +#include <cstddef> #include <map> #include <set> #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/STLExtras.h" +#include "mlir/Analysis/CallInterfaces.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project #include "mlir/IR/Module.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/SymbolTable.h" // TF:llvm-project +#include "mlir/IR/Types.h" // TF:llvm-project #include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { namespace tf_saved_model { @@ -45,8 +56,142 @@ struct GlobalTensorUse { using GlobalTensorUsesMap = std::map<GlobalTensorOp, std::vector<GlobalTensorUse>>; +static bool IsResourceType(Type type) { + if (auto tensor_type = type.dyn_cast<TensorType>()) { + return tensor_type.getElementType().isa<TF::ResourceType>(); + } + return false; +} + +static bool IsResource(Value value) { return IsResourceType(value.getType()); } + +class ResourceAnalyzer { + public: + explicit ResourceAnalyzer(ModuleOp module) { + SymbolTable symbol_table(module); + for (auto func : module.getOps<FuncOp>()) { + AnalyzeFunc(func, symbol_table); + } + } + + bool IsPotentiallyWritten(Value resource) const { + assert(IsResource(resource)); + auto it = resource_infos_.find(resource); + if (it == resource_infos_.end()) { + return false; + } + return it->second.potentially_written; + } + + private: + // Analyze the specified func for resource mutating operations, 
namely + // TF::AssignVariableOp, if so, set the resource associated as "potentially + // written". Do this recursively across the chain of funcs via call or control + // flow ops. + // TODO(ashwinm): Move to iterative traversal. + LogicalResult AnalyzeFunc(FuncOp func, const SymbolTable& symbol_table) { + // Avoid infinite recursion. + if (!discovered_.insert(func).second) { + return success(); + } + + func.walk([&](Operation* op) { + if (isa<TF::ReadVariableOp>(op) || isa<ReturnOp>(op)) { + return; + } + if (auto assign_variable = dyn_cast<TF::AssignVariableOp>(op)) { + SetPotentiallyWritten(assign_variable.resource()); + return; + } + if (auto call = dyn_cast<CallOpInterface>(op)) { + if (auto sym = op->getAttrOfType<SymbolRefAttr>("f")) { + PropagatePotentiallyWrittenUpFromCallee( + sym.cast<FlatSymbolRefAttr>().getValue(), call.getArgOperands(), + symbol_table); + } + return; + } + if (auto if_op = dyn_cast<TF::IfOp>(op)) { + for (auto callee : {if_op.then_branch(), if_op.else_branch()}) { + PropagatePotentiallyWrittenUpFromCallee(callee, if_op.input(), + symbol_table); + } + return; + } + if (auto while_op = dyn_cast<TF::WhileOp>(op)) { + for (auto callee : {while_op.cond(), while_op.body()}) { + PropagatePotentiallyWrittenUpFromCallee(callee, while_op.input(), + symbol_table); + } + return; + } + // For all other ops, we assume it mutates all resources it uses, so + // this errs on the side of being conservative. We should improve + // this by using either a property or a trait that clearly + // identifies ops with resource mutating behavior. + if (PropagatePotentiallyWrittenWithinUnhandledOp(op)) { + return; + } + }); + return success(); + } + + // If an op is not one of the handled ones, we assume all resource usages + // within its purview are mutating in nature. + bool PropagatePotentiallyWrittenWithinUnhandledOp(Operation* op) { + for (auto operand : op->getOperands()) { + if (IsResource(operand)) { + SetPotentiallyWritten(operand); + return true; + } + } + bool uses_resources = false; + visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) { + if (IsResource(operand->get())) { + SetPotentiallyWritten(operand->get()); + uses_resources = true; + } + }); + return uses_resources; + } + + // Given a funcOp associated with the callee and operands from the + // corresponding callOp, propagate the potentially written decision to the + // callOp's operands, if the corresponding func's arguments are potentially + // written resources. + void PropagatePotentiallyWrittenUpFromCallee( + StringRef callee, Operation::operand_range propagate_to, + const SymbolTable& symbol_table) { + auto func = symbol_table.lookup<FuncOp>(callee); + AnalyzeFunc(func, symbol_table); + for (auto t : llvm::zip(func.getArguments(), propagate_to)) { + if (!IsResource(std::get<0>(t))) { + continue; + } + if (IsPotentiallyWritten(std::get<0>(t))) { + SetPotentiallyWritten(std::get<1>(t)); + } + } + } + + void SetPotentiallyWritten(Value resource) { + assert(IsResource(resource)); + resource_infos_[resource].potentially_written = true; + } + struct ResourceInfo { + bool potentially_written = false; + }; + // Key: Resource Value's + // Value: Information we know about that Value. + // Note that these Value's are in general in different functions. + DenseMap<Value, ResourceInfo> resource_infos_; + // The set of func's we already discovered. 
+ DenseSet<FuncOp> discovered_; +}; + bool IsImmutable(GlobalTensorOp global_tensor, - ArrayRef<GlobalTensorUse> global_tensor_uses) { + ArrayRef<GlobalTensorUse> global_tensor_uses, + const ResourceAnalyzer& resource_analyzer) { // Global tensor is already known to be immutable. if (!global_tensor.is_mutable()) { return false; @@ -57,17 +202,11 @@ bool IsImmutable(GlobalTensorOp global_tensor, return false; } - // Check the uses to see if this global tensor is only used in a way that - // is compatible with being immutable. - // Right now, this uses a very simple algorithm that only checks the top-level - // func for tf.ReadVariableOp. If the resource is passed into other functions - // or control flow, we fail to prove it is freezable even though we could. + // A global tensor is immutable if the resource analyzer deems it so. for (auto& global_tensor_use : global_tensor_uses) { auto arg = global_tensor_use.func.getArgument(global_tensor_use.arg_index); - for (auto user : arg.getUsers()) { - if (!isa<TF::ReadVariableOp>(user)) { - return false; - } + if (resource_analyzer.IsPotentiallyWritten(arg)) { + return false; } } return true; @@ -96,11 +235,12 @@ static GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) { // Removes `is_mutable` attribute from tf_saved_model.global_tensor ops where we // can prove it is safe to do so. void MarkGlobalTensorsImmutable( - ModuleOp module, const GlobalTensorUsesMap& global_tensor_uses_map) { + ModuleOp module, const GlobalTensorUsesMap& global_tensor_uses_map, + const ResourceAnalyzer& resource_analyzer) { for (const auto& kv : global_tensor_uses_map) { auto global_tensor = kv.first; const auto& global_tensor_uses = kv.second; - if (IsImmutable(global_tensor, global_tensor_uses)) { + if (IsImmutable(global_tensor, global_tensor_uses, resource_analyzer)) { global_tensor.removeAttr("is_mutable"); } } @@ -137,17 +277,15 @@ void EraseUnusedBoundInputs(ModuleOp module) { } void OptimizeGlobalTensorsPass::runOnModule() { - // This analysis could be much more elaborate, including tracking global - // tensors interprocedurally and uses in a wide variety of ops. But I don't - // know if we need that complexity. auto module = getModule(); - EraseUnusedBoundInputs(module); - // Figure out which func's use each tf_saved_model.global_tensor. + ResourceAnalyzer resource_analyzer(module); + GlobalTensorUsesMap global_tensor_uses = CreateGlobalTensorUsesMap(module); - MarkGlobalTensorsImmutable(module, global_tensor_uses); + MarkGlobalTensorsImmutable(module, global_tensor_uses, resource_analyzer); + EraseUnusedGlobalTensors(module, global_tensor_uses); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 16d8ddfb900..17af8c3cfbb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -186,6 +186,10 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateParallelExecuteToIslandsPass(); // same data across replicas. std::unique_ptr<OpPassBase<ModuleOp>> CreateAnnotateParameterReplicationPass(); +// Creates a pass that hoists a `tf_device.launch` body and assigns a `device` +// attribute to each TensorFlow dialect op in the body based on the `device` +// attribute on the `tf_device.launch`. 
+std::unique_ptr<OpPassBase<FuncOp>> CreateLaunchToDeviceAttributePass(); } // namespace TFDevice namespace TFTPU { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc index 13dda3ed0d0..d3cc508a490 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc @@ -27,7 +27,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project #include "mlir/IR/StandardTypes.h" // TF:llvm-project #include "mlir/Pass/Pass.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index 8dc21feca90..dee5e8b079f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -22,7 +22,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Block.h" // TF:llvm-project #include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index c44f0f97fd6..b3474e2faf1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -25,7 +25,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Block.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Diagnostics.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc index 384b66bc737..eb57e8ff742 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc @@ -26,7 +26,7 @@ limitations under the License. 
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc index 8136db7d164..32cb2e02930 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc @@ -19,7 +19,7 @@ limitations under the License. #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Operation.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc b/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc index 672ba418489..b89b3d8e6b2 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc @@ -22,7 +22,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/Support/Debug.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Operation.h" // TF:llvm-project #include "mlir/IR/Value.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc index 96a7fcbb5ba..7755f5f2259 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc @@ -21,7 +21,7 @@ limitations under the License. #include "llvm/ADT/SmallString.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Operation.h" // TF:llvm-project #include "mlir/IR/Value.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 529c2517508..0ae02ed63b6 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -28,7 +28,7 @@ limitations under the License. 
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 3a624490da1..1f0f8e2b9de 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -42,7 +42,7 @@ limitations under the License. #include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Analysis/Verifier.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h index 79a302b066b..4a67b7fae76 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_ -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc index a97bca9fc3d..2ee3893eac9 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "llvm/Support/Debug.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Operation.h" // TF:llvm-project #include "mlir/Pass/Pass.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index cce5fde2883..ef1e52ee5c0 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/OpDefinition.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index a64b7ecfdb3..f8c118ac9d9 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project #include "mlir/IR/Identifier.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index bc9bdf49a39..f00e880f36b 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -19,7 +19,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc index f8eabeb046d..82304f95e33 100644 --- a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/hlo_module_importer.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Location.h" // TF:llvm-project #include "mlir/IR/OperationSupport.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index 41ef8690735..b011b6069c7 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -1130,7 +1130,7 @@ Type SliceOp::InferOutputTypes(Builder* builder, Value operand, // Illegal attributes. 
ShapedType attr_ty = start_indices.getType(); if (attr_ty.getRank() != 1 || attr_ty.getNumElements() != rank || - !attr_ty.getElementType().isInteger(64) || + !attr_ty.getElementType().isSignlessInteger(64) || limit_indices.getType() != attr_ty || strides.getType() != attr_ty) return ty; diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index 269e1cc8897..42b42d99380 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -52,7 +52,7 @@ def HLO_FpTensor : TensorOf<[AnyFloat]>; def HLO_PredTensor : TensorOf<[HLO_Pred]>; -def HLO_Tensor : TensorOf<[AnyFloat, AnyInteger, AnyComplex]>; +def HLO_Tensor : TensorOf<[AnyFloat, AnySignlessInteger, AnyComplex]>; def HLO_ComplexTensor : TensorOf<[AnyComplex]>; @@ -64,13 +64,13 @@ def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>; // an index type (as it stores indices) but that is currently disallowed in // MLIR. def HLO_DimensionTensor : ShapedContainerType< - [AnyInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>, + [AnySignlessInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>, "a 1D tensor of dimensions">; // In general, static shaped tensor constraints should be avoided unless // it is for a legacy op which is only correct with static shapes. def HLO_StaticShapeTensor : StaticShapeTensorOf<[ - AnyFloat, AnyInteger, AnyComplex]>; + AnyFloat, AnySignlessInteger, AnyComplex]>; //===----------------------------------------------------------------------===// // XLA combined type definitions. @@ -784,7 +784,7 @@ def HLO_ScalarsToDimensionTensorOp : HLO_Op<"scalars_to_dimension_tensor", compute shape arguments to dynamic operations. }]; - let arguments = (ins Variadic<AnyInteger>); + let arguments = (ins Variadic<AnySignlessInteger>); let results = (outs HLO_DimensionTensor); // Cannot be exported to legacy formats. diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h index 120b035e5d0..3e3570f5b54 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h +++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h @@ -40,7 +40,7 @@ static ElementsAttr getSplat(Builder* b, Value val, T constant) { // Handle integer elements. Attribute elementAttr; - if (valElementType.isa<IntegerType>()) + if (valElementType.isSignlessInteger()) elementAttr = b->getIntegerAttr(valElementType, constant); else if (valElementType.isa<FloatType>()) elementAttr = b->getFloatAttr(valElementType, constant); diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td index 794fee181a6..3a675f20d92 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td @@ -42,7 +42,7 @@ def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>; // Any integer or floating-point tensor types def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>; -def LHLO_Buffer : MemRefOf<[AnyFloat, AnyInteger]>; +def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger]>; def LHLO_TupleBuffer : NestedTupleOf<[LHLO_Buffer]>; diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 8fa7d809024..92614755ec3 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -27,7 +27,7 @@ limitations under the License. 
#include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project #include "mlir/IR/Location.h" // TF:llvm-project @@ -1221,7 +1221,7 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module, << "requires arg " << padding_arg_index << " to be a scalar for use as a dynamic parameter"; - if (!mlir::getElementTypeOrSelf(padding_arg_type).isa<IntegerType>()) + if (!mlir::getElementTypeOrSelf(padding_arg_type).isSignlessInteger()) return entry_func.emitError() << "requires arg " << padding_arg_index << " to be of an int type for use as a dynamic parameter"; diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir index 33d5884a882..d43ca3b6bb2 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir @@ -210,7 +210,6 @@ func @broadcast(%operand: memref<5x7x1xf32>, %result: memref<7x10x6x4x5xf32>) { // ----- -// CHECK-DAG: #[[OPERAND_MPA:.*]] = affine_map<(d0, d1, d2) -> (0)> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @broadcast_scalar func @broadcast_scalar(%operand: memref<f32>, %result: memref<7x10x6xf32>) { @@ -219,9 +218,10 @@ func @broadcast_scalar(%operand: memref<f32>, %result: memref<7x10x6xf32>) { : (memref<f32>, memref<7x10x6xf32>) -> () return } -// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[OPERAND:[a-zA-Z0-9_]*]]: f32, %[[RESULT:.*]]: f32): -// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[RESULT:.*]]: f32): +// CHECK-NEXT: %[[CONST:.*]] = load %{{.*}} : memref<f32> +// CHECK-NEXT: linalg.yield %[[CONST]] : f32 // ----- diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc index 1384abed91c..29d399c68fa 100644 --- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc @@ -16,7 +16,7 @@ limitations under the License. // This file implements logic for lowering HLO dialect to LHLO dialect. #include "absl/memory/memory.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h b/tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h index 7c6d162632f..d2a1f47e540 100644 --- a/tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h +++ b/tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h @@ -16,7 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_ #define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_ -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Location.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc index 8351f94d172..72ea2e18ec0 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc @@ -18,7 +18,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Block.h" // TF:llvm-project #include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index da135ea1860..8f955d6944a 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -24,7 +24,7 @@ limitations under the License. #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/Dialect/Traits.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Diagnostics.h" // TF:llvm-project @@ -119,7 +119,7 @@ static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef<int32_t> values, Type GetSumAccumulationType(Type input_type) { MLIRContext *ctx = input_type.getContext(); if (input_type.isBF16() || input_type.isF16()) return FloatType::getF32(ctx); - if (input_type.isInteger(8) || input_type.isInteger(16)) + if (input_type.isSignlessInteger(8) || input_type.isSignlessInteger(16)) return IntegerType::get(32, ctx); return input_type; } @@ -1274,7 +1274,7 @@ class ConvertMaxPoolOp : public OpRewritePattern<TF::MaxPoolOp> { PatternRewriter &rewriter) const override { Type element_type = op.input().getType().cast<TensorType>().getElementType(); - if (!element_type.isIntOrFloat()) return matchFailure(); + if (!element_type.isSignlessIntOrFloat()) return matchFailure(); Location loc = op.getLoc(); ConstOp init = GetMinValueForType(element_type, loc, &rewriter); @@ -2248,7 +2248,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern<OpTy> { Type input_element_type = input_type.getElementType(); // TODO(bixia): Clarify whether tf.ArgMax supports complex data types. If // tf.ArgMax doesn't support complex data types, this check can be removed. 
- if (!input_element_type.isIntOrFloat()) return this->matchFailure(); + if (!input_element_type.isSignlessIntOrFloat()) return this->matchFailure(); Location loc = op.getLoc(); Value init_value = diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc index 58e98a881e9..265466ef3a4 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc @@ -28,7 +28,7 @@ limitations under the License. #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/iterator_range.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 872a288c259..519ba9235f1 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -16,7 +16,7 @@ limitations under the License. // This is the legalization pattern definition file for TF to XLA. include "mlir/IR/OpBase.td" -include "mlir/Dialect/StandardOps/Ops.td" +include "mlir/Dialect/StandardOps/IR/Ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" @@ -504,8 +504,8 @@ def : Pat<(TF_SignOp $x), )>; def BothElementTypesSameWidthIntOrFloat : Constraint<CPred< - "getElementTypeOrSelf($0.getType()).isIntOrFloat() && " - "getElementTypeOrSelf($1.getType()).isIntOrFloat() && " + "getElementTypeOrSelf($0.getType()).isSignlessIntOrFloat() && " + "getElementTypeOrSelf($1.getType()).isSignlessIntOrFloat() && " "getElementTypeOrSelf($0.getType()).getIntOrFloatBitWidth() == " "getElementTypeOrSelf($1.getType()).getIntOrFloatBitWidth()">, "element types must be integers or floats of same width">; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc index 5ee6010c3a8..3c15f0be7e8 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc @@ -16,7 +16,7 @@ limitations under the License. // This file implements logic for lowering XLA dialect to Standard dialect. #include "llvm/ADT/StringSwitch.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project #include "mlir/IR/PatternMatch.h" // TF:llvm-project #include "mlir/Pass/Pass.h" // TF:llvm-project @@ -45,8 +45,8 @@ class CompareIConvert : public OpRewritePattern<xla_hlo::CompareOp> { // Broadcasting not supported by this rewrite. 
if (lhs_type.getShape() != rhs_type.getShape()) return matchFailure(); - if (!lhs_type.getElementType().isa<IntegerType>() || - !rhs_type.getElementType().isa<IntegerType>()) + if (!lhs_type.getElementType().isSignlessInteger() || + !rhs_type.getElementType().isSignlessInteger()) return matchFailure(); auto comparison_direction = op.comparison_direction(); @@ -113,7 +113,8 @@ class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> { PatternRewriter &rewriter) const override { auto output_type = op.getType().cast<ShapedType>(); // TODO(prakalps): Handle FP and ComplexType iota ops. - if (!output_type.getElementType().isa<IntegerType>()) return matchFailure(); + if (!output_type.getElementType().isSignlessInteger()) + return matchFailure(); auto output_size = output_type.getNumElements(); auto dimension = op.iota_dimension().getSExtValue(); auto max_dim_size = output_type.getDimSize(dimension); diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td index a15b28193cd..c0f6c2c3541 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td @@ -16,7 +16,7 @@ limitations under the License. // This is the legalization pattern definition file for XLA to StandardOps. include "mlir/IR/OpBase.td" -include "mlir/Dialect/StandardOps/Ops.td" +include "mlir/Dialect/StandardOps/IR/Ops.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc index b0f6b83038a..2c550465302 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc @@ -17,7 +17,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "mlir/Dialect/AffineOps/AffineOps.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Location.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc index e991b186d72..c9245d93e56 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include "mlir/Dialect/GPU/GPUDialect.h" // TF:llvm-project #include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:llvm-project #include "mlir/Dialect/LoopOps/LoopOps.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td b/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td index d8a5ae6c6de..dcb0ab20e9e 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td @@ -17,7 +17,7 @@ limitations under the License. // equivalent real value operations. include "mlir/IR/OpBase.td" -include "mlir/Dialect/StandardOps/Ops.td" +include "mlir/Dialect/StandardOps/IR/Ops.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc b/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc index c956cd6b277..f18607dfffb 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc @@ -17,7 +17,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project #include "mlir/IR/Location.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h index 0e33a8646ad..6554942954e 100644 --- a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h +++ b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h @@ -18,7 +18,7 @@ limitations under the License. 
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" @@ -206,7 +206,7 @@ inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op, const auto& lhs = args[0]; const auto& rhs = args[1]; Type element_type = lhs.getType(); - if (element_type.isa<IntegerType>()) { + if (element_type.isSignlessInteger()) { Optional<CmpIPredicate> predicate = getCmpPredicate<CmpIPredicate>(xla_op.comparison_direction()); assert(predicate.hasValue() && "expected valid comparison direction"); @@ -288,8 +288,8 @@ template <> inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>( xla_lhlo::ConvertOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args, OpBuilder* b) { - const Type& sourceType = args.front().getType(); - const Type& targetType = result_types.front(); + Type sourceType = args.front().getType(); + Type targetType = result_types.front(); if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) { return b->create<mlir::SIToFPOp>(xla_op.getLoc(), result_types, args, @@ -307,7 +307,7 @@ inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>( // No conversion is needed for the same width floats return args.front(); } - if (sourceType.isa<IntegerType>() && targetType.isa<IntegerType>()) { + if (sourceType.isSignlessInteger() && targetType.isSignlessInteger()) { IntegerType src = sourceType.cast<IntegerType>(); IntegerType res = targetType.cast<IntegerType>(); if (src.getWidth() > res.getWidth()) { diff --git a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc index 13467be41d9..4c20a589ce0 100644 --- a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc +++ b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc @@ -15,7 +15,7 @@ limitations under the License. #include <numeric> -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/Operation.h" // TF:llvm-project #include "mlir/IR/PatternMatch.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts_pass.cc b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts_pass.cc index 596b67f0eed..644fffcc7ea 100644 --- a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts_pass.cc +++ b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts_pass.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/Operation.h" // TF:llvm-project #include "mlir/IR/PatternMatch.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc index 46df664dc42..7f7060fef64 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc @@ -19,7 +19,7 @@ limitations under the License. #include "llvm/ADT/APInt.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:llvm-project #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/AffineExpr.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project @@ -83,7 +83,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> { emitError(loc, "lhlo to linalg conversion expects ranked args"); return ConversionPattern::matchFailure(); } - if (!argType.getElementType().isIntOrFloat()) { + if (!argType.getElementType().isSignlessIntOrFloat()) { return ConversionPattern::matchFailure(); } @@ -171,7 +171,7 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> { auto loc = lhlo_op.getLoc(); auto argType = lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>(); - if (!argType || !argType.getElementType().isIntOrFloat() || + if (!argType || !argType.getElementType().isSignlessIntOrFloat() || (argType.getRank() != 0)) { return ConversionPattern::matchFailure(); } @@ -208,6 +208,9 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> { auto resultType = getXLAOpResultType<isLHLO>(op); if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op)) return ConversionPattern::matchFailure(); + // TODO(b/150203558) Enable once tiling/fusion works in this case. + if (isLHLO && (operandType.getRank() == 0)) + return ConversionPattern::matchFailure(); ArrayAttr indexingMapsAttr = static_cast<const Derived&>(*this).getIndexingMapsAttr(op, &rewriter); if (!indexingMapsAttr) return ConversionPattern::matchFailure(); @@ -278,6 +281,52 @@ class BroadcastInDimConverter } }; +// Special case for scalar broadcast in lhlo. +// TODO(b/150203558) Remove once the bug is fixed. +class ScalarBroadcastInDimConverter + : public OpConversionPattern<xla_lhlo::BroadcastInDimOp> { + public: + using OpConversionPattern<xla_lhlo::BroadcastInDimOp>::OpConversionPattern; + + PatternMatchResult matchAndRewrite( + xla_lhlo::BroadcastInDimOp broadcastOp, ArrayRef<Value> args, + ConversionPatternRewriter& rewriter) const final { + auto operandMemrefType = + broadcastOp.operand().getType().dyn_cast<MemRefType>(); + // Only support scalar operands. 
+ if (operandMemrefType.getRank() != 0) return matchFailure(); + auto resultMemrefType = + broadcastOp.output().getType().dyn_cast<MemRefType>(); + if (!operandMemrefType || !resultMemrefType) return matchFailure(); + auto broadcastDims = broadcastOp.broadcast_dimensions(); + if (!broadcastDims.hasValue()) return matchFailure(); + + unsigned nloops = resultMemrefType.getRank(); + SmallVector<Attribute, 1> indexingMaps{ + AffineMapAttr::get(rewriter.getMultiDimIdentityMap(nloops))}; + auto loc = broadcastOp.getLoc(); + auto linalgOp = rewriter.create<linalg::GenericOp>( + loc, ArrayRef<Type>{}, broadcastOp.output(), + rewriter.getI64IntegerAttr(0), // args_in + rewriter.getI64IntegerAttr(1), // args_out + rewriter.getArrayAttr(indexingMaps), + GetNParallelLoopsAttrs(nloops, &rewriter), + /*doc=*/nullptr, /*fun=*/nullptr, /*library_call=*/nullptr); + + // Add a block to the region. + auto* region = &linalgOp.region(); + auto* block = rewriter.createBlock(region, region->end()); + block->addArguments(resultMemrefType.getElementType()); + + rewriter.setInsertionPointToEnd(block); + auto scalar = + rewriter.create<LoadOp>(loc, broadcastOp.operand(), llvm::None); + rewriter.create<linalg::YieldOp>(loc, scalar.getResult()); + rewriter.eraseOp(broadcastOp); + return matchSuccess(); + } +}; + template <typename OpTy, bool isLHLO = true> class TransposeConverter : public DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy, @@ -385,7 +434,7 @@ class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> { if (!resultMemrefType) return matchFailure(); auto resultElementType = resultMemrefType.getElementType(); - if (!resultElementType.isIntOrFloat()) return matchFailure(); + if (!resultElementType.isSignlessIntOrFloat()) return matchFailure(); // Construct the indexing maps needed for linalg.generic ops. 
unsigned nloops = resultMemrefType.getRank(); @@ -502,6 +551,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter<xla_lhlo::SubOp>, PointwiseToLinalgConverter<xla_lhlo::TanhOp>, ReshapeAddRemoveDimConverter<xla_lhlo::ReshapeOp>, + ScalarBroadcastInDimConverter, ScalarPointwiseToStandardConverter<xla_lhlo::AddOp>, SliceConverter >(context); diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index d517b5d0bdd..f3ee4e38f31 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1514,6 +1514,7 @@ cuda_py_test( "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], xla_enable_strict_auto_jit = False, + xla_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client", @@ -1534,6 +1535,7 @@ cuda_py_test( "no_rocm", ], xla_enable_strict_auto_jit = False, + xla_enabled = True, deps = [ ":test_utils", "//tensorflow/core:protos_all_py", @@ -1558,6 +1560,7 @@ cuda_py_test( "no_rocm", ], xla_enable_strict_auto_jit = False, + xla_enabled = True, deps = [ ":test_utils", "//tensorflow/core:protos_all_py", @@ -1650,6 +1653,7 @@ cuda_py_test( "no_rocm", ], xla_enable_strict_auto_jit = False, + xla_enabled = True, deps = [ ":lstm", ":xla_test", diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index a03980f20ba..0ed81b7e9e5 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -695,7 +695,7 @@ class EagerFunctionTest(xla_test.XLATestCase): wholly_compiled_f = def_function.function(f) op_by_op_f = def_function.function(f, experimental_compile=False) - x = constant_op.constant([0.0, 2.0], name='data') + x = array_ops.identity([0.0, 2.0], name='data') # When function is wholly compiled, all outputs will be on the # device on which it is run. 
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 86da500b1dd..ba5d7e9d788 100755 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -620,6 +620,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@llvm-project//llvm:core", ], @@ -1124,6 +1125,7 @@ cc_library( ":cudnn_batchnorm_rewriter", ":cudnn_pad_for_convolutions", ":fusion_merger", + ":gemm_rewriter", ":gpu_constants", ":gpu_conv_algorithm_picker", ":gpu_conv_padding_legalization", @@ -1140,9 +1142,13 @@ cc_library( ":ir_emitter", ":multi_output_fusion", ":partition_assignment", + ":reduction_degenerate_dim_remover", + ":reduction_dimension_grouper", + ":reduction_layout_normalizer", ":stream_assignment", ":stream_executor_util", ":target_constants", + ":tree_reduction_rewriter", ":variadic_op_splitter", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status_macros", @@ -1232,19 +1238,13 @@ cc_library( ":cudnn_pad_for_convolutions", ":cusolver_rewriter", ":gemm_algorithm_picker", - ":gemm_rewriter", ":gpu_compiler", - ":gpu_conv_algorithm_picker", ":gpu_conv_padding_legalization", ":gpu_conv_rewriter", ":gpu_layout_assignment", ":ir_emission_utils", - ":reduction_degenerate_dim_remover", - ":reduction_dimension_grouper", - ":reduction_layout_normalizer", ":stream_executor_util", ":target_constants", - ":tree_reduction_rewriter", "@com_google_absl//absl/base", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/types:optional", diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc index 97013804271..b78748edb7e 100644 --- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc @@ -87,39 +87,6 @@ Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( return Status::OK(); } -Status AMDGPUCompiler::OptimizeHloPostLayoutAssignment( - HloModule* hlo_module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) { - HloPassPipeline pipeline("post-layout_assignment"); - pipeline.AddInvariantChecker<HloVerifier>( - /*layout_sensitive=*/true, - /*allow_mixed_precision=*/false, - LayoutAssignment::InstructionCanChangeLayout); - - pipeline.AddPass<ReductionDegenerateDimRemover>(); - pipeline.AddPass<ReductionLayoutNormalizer>(); - pipeline.AddPass<ReductionDimensionGrouper>(); - - // The LayoutAssignment pass may leave behind kCopy instructions which are - // duplicate or NOPs, so remove them with algebraic simplification and CSE. - AlgebraicSimplifierOptions options; - options.set_is_layout_sensitive(true); - pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options); - - // Rewrite GEMMs into custom calls. - pipeline.AddPass<GemmRewriter>(); - - pipeline.AddPass<GpuConvAlgorithmPicker>(stream_exec, device_allocator); - - // Clean up new_tuple described above. 
- pipeline.AddPass<TupleSimplifier>(); - - pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); - - return Status::OK(); -} - AMDGPUCompiler::AMDGPUCompiler() : GpuCompiler(stream_executor::rocm::kROCmPlatformId, amdgpu::kTargetTriple, amdgpu::kDataLayout) {} diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h index d1a74a7822e..acc5e021e3d 100644 --- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h @@ -37,10 +37,6 @@ class AMDGPUCompiler : public GpuCompiler { HloModule* hlo_module, se::StreamExecutor* stream_exec, se::DeviceMemoryAllocator* device_allocator) override; - Status OptimizeHloPostLayoutAssignment( - HloModule* hlo_module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) override; - GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override; StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary( diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index e4c57203543..51b30e238e9 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -46,7 +46,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/alias_passthrough_params.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" +#include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" @@ -61,10 +63,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h" +#include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h" +#include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/gpu/target_constants.h" #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" +#include "tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/variadic_op_splitter.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" @@ -106,6 +112,7 @@ limitations under the License. 
#include "tensorflow/core/platform/subprocess.h" #include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/util/env_var.h" namespace xla { namespace gpu { @@ -333,6 +340,81 @@ Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) { return pipeline.Run(hlo_module).status(); } +// TODO(cheshire): Duplication with gpu_conv_algorithm picker, figure out a +// right way to share this. +static bool RequireDeterminism() { + bool deterministic_ops = false; + TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS", + /*default_val=*/false, + &deterministic_ops)); + return deterministic_ops; +} + +Status GpuCompiler::OptimizeHloPostLayoutAssignment( + HloModule* hlo_module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) { + HloPassPipeline pipeline("post-layout_assignment"); + /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after + * fixing the ticket. */ + pipeline.AddInvariantChecker<HloVerifier>( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); + + pipeline.AddPass<ReductionDegenerateDimRemover>(); + pipeline.AddPass<ReductionLayoutNormalizer>(); + pipeline.AddPass<ReductionDimensionGrouper>(); + + // The LayoutAssignment pass may leave behind kCopy instructions which are + // duplicate or NOPs, so remove them with algebraic simplification and CSE. + AlgebraicSimplifierOptions options; + options.set_is_layout_sensitive(true); + pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options); + + if (RequireDeterminism() || + hlo_module->config().debug_options().xla_gpu_deterministic_reductions()) { + pipeline.AddPass<HloPassFix<GpuTreeReductionRewriter>>(); + } + + // Rewrite GEMMs into custom calls. + pipeline.AddPass<GemmRewriter>(); + + // Choose the fastest algorithm for each conv. + // + // We pick the algorithm before fusion so we can generate better HLO. After + // GpuConvRewriter, our convolutions are CustomCalls which return a + // tuple (conv_result, scratch_memory), and the each conv uses 0 bytes of + // scratch: + // + // customcall = (f32[...], f32[0]) + // return gte(customcall, 0) + // + // The algorithm picker then chooses the best algorithm, and potentially + // increases the scratch space. It replaces customcall with new_tuple, + // giving us the following: + // + // new_customcall = (f32[...], f32[N]) + // new_tuple = tuple(gte(new_customcall, 0), constant f32[0]) + // return gte(new_tuple, 0) + // + // The new tuple and gte instructions then be simplified away, because + // nobody is expected to use the scratch value. + // + // However, if we were to run GpuConvAlgorithmPicker after fusion + // the gte(customcall, 0) would probably already be into a fusion node. We + // can't simplify across HloComputation boundaries, so in this case we + // wouldn't be able to simplify away the new_tuple bits. + pipeline.AddPass<GpuConvAlgorithmPicker>(stream_exec, device_allocator); + + // Clean up new_tuple described above. 
+ pipeline.AddPass<TupleSimplifier>(); + + pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + + return Status::OK(); +} + StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses( std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec, se::DeviceMemoryAllocator* device_allocator) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index 901d994d4ad..b52af5392d1 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -64,7 +64,7 @@ class GpuCompiler : public LLVMCompiler { virtual Status OptimizeHloPostLayoutAssignment( HloModule* hlo_module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) = 0; + se::DeviceMemoryAllocator* device_allocator); virtual HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() { return diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 8646b4b5016..566d4f0e463 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/stream_executor/device_description.h" namespace xla { namespace gpu { @@ -126,7 +127,9 @@ bool IsCublasGemm(const HloInstruction& hlo) { } std::array<int64, 3> GetReductionTiling( - const ReductionDimensions& reduction_dimensions) { + const ReductionDimensions& reduction_dimensions, + int smallest_input_dtype_bits, + const stream_executor::DeviceDescription* device_description) { if (reduction_dimensions.is_row_reduction) { int64 tile_z = std::min(reduction_dimensions.dimensions[0], int64{8}); if (reduction_dimensions.dimensions[1] == 1) { @@ -137,7 +140,17 @@ std::array<int64, 3> GetReductionTiling( 0) { return {tile_z, 1, 64}; } - return {tile_z, 1, 8}; + int cc_major = 0, cc_minor = 0; + if (device_description != nullptr) { + device_description->cuda_compute_capability(&cc_major, &cc_minor); + } + int unroll_x = 8; + if (cc_major >= 6 && smallest_input_dtype_bits == 16) { + unroll_x = 16; + } else if (cc_major >= 6 && smallest_input_dtype_bits == 8) { + unroll_x = 64; + } + return {tile_z, 1, unroll_x}; } // Column reduction. diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 82b10a50c39..8a2385a242b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/device_description.h" // TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they // don't belong in "ir_emission_utils". @@ -193,8 +194,12 @@ ReductionDimensions GetReductionKindAndContiguousComponents( // Get tiling per thread for the given reduction in dimensions [D, H, W] per // thread. +// If the device isn't known pass null for device_description and you will get +// non-optimized value. 
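Note: GetReductionTiling (implementation earlier in this file's hunk, declaration updated just below) now derives the row-reduction unroll factor from the narrowest input element type and the CUDA compute capability, falling back to the previous value of 8 when the device is unknown. The following is a standalone restatement of that selection for reference; the function name is illustrative, and the thresholds are copied from the implementation above.

    // Sketch only: the row-reduction unroll factor chosen above, restated as a
    // free function. cc_major == 0 corresponds to "device unknown".
    inline int RowReductionUnrollX(int cc_major, int smallest_input_dtype_bits) {
      if (cc_major >= 6 && smallest_input_dtype_bits == 16) return 16;  // 16-bit inputs
      if (cc_major >= 6 && smallest_input_dtype_bits == 8) return 64;   // 8-bit inputs
      return 8;  // default, matches the pre-existing tiling
    }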
std::array<int64, 3> GetReductionTiling( - const ReductionDimensions& reduction_dimensions); + const ReductionDimensions& reduction_dimensions, + int smallest_input_dtype_bits, + const stream_executor::DeviceDescription* device_description); // Emits call to "vprintf" with given format and arguments. llvm::Value* EmitPrintf(absl::string_view fmt, diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index f99e43cc06d..99bfc9185a7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -3097,9 +3097,20 @@ ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo( << " " << reduction_dimensions.dimensions[0] << " " << reduction_dimensions.dimensions[1] << " " << reduction_dimensions.dimensions[2]; + auto get_dtype_bits = [](const HloInstruction* i) { + return primitive_util::BitWidth(i->shape().element_type()); + }; + // For fusion with multiple inputs, use the smallest input dtype to + // select the reduction_tiling. + int smallest_input_dtype_bits = get_dtype_bits(first_reduce->operand(0)); + for (xla::HloInstruction* input : unnested_hlo->operands()) { + smallest_input_dtype_bits = + std::min(get_dtype_bits(input), smallest_input_dtype_bits); + } std::array<int64, 3> reduction_tiling = - GetReductionTiling(reduction_dimensions); + GetReductionTiling(reduction_dimensions, smallest_input_dtype_bits, + &ir_emitter_context_->device_description()); bool dilated_x = reduction_dimensions.is_row_reduction || !IsUnrollingColumnReductionBeneficial(unnested_hlo, input_shape, diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index a1a901f0b94..6d036094a69 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -27,19 +27,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.h" #include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h" -#include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" -#include "tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h" -#include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h" -#include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/gpu/target_constants.h" -#include "tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" @@ -55,7 +49,6 @@ limitations under the License. 
#include "tensorflow/core/platform/cuda_libdevice_path.h" #include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/profiler/lib/traceme.h" -#include "tensorflow/core/util/env_var.h" #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" #include "tensorflow/stream_executor/gpu/asm_compiler.h" @@ -152,83 +145,20 @@ Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( return Status::OK(); } -// TODO(cheshire): Duplication with gpu_conv_algorithm picker, figure out a -// right way to share this. -static bool RequireDeterminism() { - bool deterministic_ops = false; - TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS", - /*default_val=*/false, - &deterministic_ops)); - return deterministic_ops; -} - Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( HloModule* hlo_module, se::StreamExecutor* stream_exec, se::DeviceMemoryAllocator* device_allocator) { - HloPassPipeline pipeline("post-layout_assignment"); - /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after - * fixing the ticket. */ - pipeline.AddInvariantChecker<HloVerifier>( - /*layout_sensitive=*/true, - /*allow_mixed_precision=*/false, - LayoutAssignment::InstructionCanChangeLayout); - - pipeline.AddPass<ReductionDegenerateDimRemover>(); - pipeline.AddPass<ReductionLayoutNormalizer>(); - pipeline.AddPass<ReductionDimensionGrouper>(); - - // The LayoutAssignment pass may leave behind kCopy instructions which are - // duplicate or NOPs, so remove them with algebraic simplification and CSE. - AlgebraicSimplifierOptions options; - options.set_is_layout_sensitive(true); - pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options); - - if (RequireDeterminism() || - hlo_module->config().debug_options().xla_gpu_deterministic_reductions()) { - pipeline.AddPass<HloPassFix<GpuTreeReductionRewriter>>(); - } + TF_RETURN_IF_ERROR(GpuCompiler::OptimizeHloPostLayoutAssignment( + hlo_module, stream_exec, device_allocator)); + HloPassPipeline pipeline("nvptx post-layout_assignment"); // Pad the dimensions of matrices in dot operations to multiples of 8. if (IsVoltaOrLater(*stream_exec)) { pipeline.AddPass<CublasGemmPadForTensorCores>(); } - // Rewrite GEMMs into custom calls. - pipeline.AddPass<GemmRewriter>(); - - // Choose the fastest algorithm for each conv. - // - // We pick the algorithm before fusion so we can generate better HLO. After - // GpuConvRewriter, our convolutions are CustomCalls which return a - // tuple (conv_result, scratch_memory), and the each conv uses 0 bytes of - // scratch: - // - // customcall = (f32[...], f32[0]) - // return gte(customcall, 0) - // - // The algorithm picker then chooses the best algorithm, and potentially - // increases the scratch space. It replaces customcall with new_tuple, - // giving us the following: - // - // new_customcall = (f32[...], f32[N]) - // new_tuple = tuple(gte(new_customcall, 0), constant f32[0]) - // return gte(new_tuple, 0) - // - // The new tuple and gte instructions then be simplified away, because - // nobody is expected to use the scratch value. - // - // However, if we were to run GpuConvAlgorithmPicker after fusion - // the gte(customcall, 0) would probably already be into a fusion node. We - // can't simplify across HloComputation boundaries, so in this case we - // wouldn't be able to simplify away the new_tuple bits. - pipeline.AddPass<GpuConvAlgorithmPicker>(stream_exec, device_allocator); // Find the fastest algorithm for GEMMs. 
pipeline.AddPass<GemmAlgorithmPicker>(stream_exec, device_allocator); - - // Clean up new_tuple described above. - pipeline.AddPass<TupleSimplifier>(); - - pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc index e1ab7bb9646..79f5c0fd901 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform.cc @@ -30,7 +30,7 @@ limitations under the License. namespace stream_executor { namespace interpreter { -XlaInterpreterPlatform::XlaInterpreterPlatform(const string& name, +XlaInterpreterPlatform::XlaInterpreterPlatform(const std::string& name, const Platform::Id& id) : name_(name), id_(id) {} @@ -40,7 +40,7 @@ Platform::Id XlaInterpreterPlatform::id() const { return id_; } int XlaInterpreterPlatform::VisibleDeviceCount() const { return 1; } -const string& XlaInterpreterPlatform::Name() const { return name_; } +const std::string& XlaInterpreterPlatform::Name() const { return name_; } port::StatusOr<std::unique_ptr<DeviceDescription>> XlaInterpreterPlatform::DescriptionForDevice(int ordinal) const { diff --git a/tensorflow/compiler/xla/service/interpreter/platform.h b/tensorflow/compiler/xla/service/interpreter/platform.h index ff9c5d07f8d..da037bf17bc 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.h +++ b/tensorflow/compiler/xla/service/interpreter/platform.h @@ -31,14 +31,14 @@ class XlaInterpreterPlatform : public Platform { public: XlaInterpreterPlatform() : XlaInterpreterPlatform("Interpreter", kXlaInterpreterPlatformId) {} - XlaInterpreterPlatform(const string& name, const Platform::Id& id); + XlaInterpreterPlatform(const std::string& name, const Platform::Id& id); ~XlaInterpreterPlatform() override; Platform::Id id() const override; int VisibleDeviceCount() const override; - const string& Name() const override; + const std::string& Name() const override; port::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice( int ordinal) const override; @@ -60,7 +60,7 @@ class XlaInterpreterPlatform : public Platform { private: // This platform's name. - string name_; + std::string name_; // This platform's id. Platform::Id id_; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc index aa28a36c945..b3e4002a898 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc @@ -31,7 +31,7 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/AffineOps/AffineOps.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/AffineExpr.h" // TF:llvm-project #include "mlir/IR/AffineMap.h" // TF:llvm-project #include "mlir/IR/StandardTypes.h" // TF:llvm-project diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc index bb67305c344..184d8d202c3 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h" #include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/StandardTypes.h" // TF:llvm-project #include "mlir/IR/Types.h" // TF:llvm-project diff --git a/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.cc b/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.cc index 5d67d7dcf7f..bd64c18680c 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h" -#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" namespace mlir { namespace { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index 176e7af5d8c..ca26ae4e756 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -31,7 +31,7 @@ limitations under the License. #include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:llvm-project #include "mlir/Dialect/Linalg/Passes.h" // TF:llvm-project #include "mlir/Dialect/LoopOps/LoopOps.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc index 13009992ab5..75c7c284881 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc @@ -16,7 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc index e383169399c..e471ba192e1 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc @@ -21,7 +21,7 @@ limitations under the License. #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // TF:llvm-project #include "mlir/Dialect/GPU/GPUDialect.h" // TF:llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/Attributes.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project #include "mlir/IR/Location.h" // TF:llvm-project diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index 8f89c4655a3..c35f05ebf45 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -60,8 +60,8 @@ int main(int argc, char** argv) { LOG(FATAL) << "local_client_aot_test_helper TARGET_CPU"; } - string triple_string; - string target_cpu = argv[1]; + std::string triple_string; + std::string target_cpu = argv[1]; if (target_cpu == "k8") { triple_string = "x86_64-none-linux-gnu"; } else if (target_cpu == "darwin") { diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 3913d604881..5c71f539efa 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -112,10 +112,7 @@ class ReduceTest : public ClientLibraryTestBase { std::unique_ptr<GlobalData> input_global_data = client_->TransferToServer(input_literal).ConsumeValueOrDie(); - float expected = 0.0; - for (float item : input_data) { - expected += item; - } + float expected = absl::c_accumulate(input_data, 0.0f); ComputeAndCompareR0<float>(&builder, expected, {input_global_data.get()}, ErrorSpec(0.001)); } diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 4bae89cda01..b02eb89ebfc 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -106,6 +106,9 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") # buildifier: disable=same-origin-load # Placeholder: load("//tensorflow:tensorflow.bzl", "tf_portable_proto_lib") +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_monitoring_deps") + # For platform specific build config load( "//tensorflow/core/platform:build_config.bzl", @@ -1988,7 +1991,6 @@ cc_library( "//tensorflow/core/platform:hash", "//tensorflow/core/platform:load_library", "//tensorflow/core/platform:logger", - "//tensorflow/core/platform:monitoring", "//tensorflow/core/platform:mutex", "//tensorflow/core/platform:notification", "//tensorflow/core/platform:net", @@ -2026,7 +2028,7 @@ cc_library( "@zlib", "@double_conversion//:double-conversion", "@com_google_protobuf//:protobuf", - ] + tf_protos_all_impl() + 
tf_protos_grappler_impl() + tf_protos_profiler_impl(), + ] + tf_protos_all_impl() + tf_protos_grappler_impl() + tf_protos_profiler_impl() + tf_monitoring_deps(), # Alwayslink causes a cc_binary to "always link" in the # srcs for a given cc_library, even if they are unreferenced, see: # https://docs.bazel.build/versions/master/be/c-cpp.html#cc_library.alwayslink diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc index b17278fb365..c7583c374f2 100644 --- a/tensorflow/core/common_runtime/device_mgr.cc +++ b/tensorflow/core/common_runtime/device_mgr.cc @@ -45,7 +45,7 @@ StaticDeviceMgr::StaticDeviceMgr(std::vector<std::unique_ptr<Device>> devices) } const auto& t = d->device_type(); device_type_counts_[t]++; - if (cpu_device_ == nullptr && t == "CPU") { + if (cpu_device_ == nullptr && t == "CPU" && d->parsed_name().id == 0) { cpu_device_ = d.get(); } } diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index 66b6bdde1c8..968713f8acd 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -2813,4 +2813,66 @@ TEST_F(DirectSessionCollectiveTest, ASSERT_EQ(key1, key2); } +// Accesses the cancellation manager for the step after the step has been +// cancelled. +class StatefulOutputRequiredOp : public OpKernel { + public: + explicit StatefulOutputRequiredOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + // The op counts the number of outputs required in the current subgraph, + // and emits that number on each of its required outputs. + Tensor count_outputs_required_t(0LL); + int64& count_outputs_required = count_outputs_required_t.scalar<int64>()(); + for (int i = 0; i < num_outputs(); ++i) { + if (ctx->output_required(i)) ++count_outputs_required; + } + for (int i = 0; i < num_outputs(); ++i) { + if (ctx->output_required(i)) ctx->set_output(i, count_outputs_required_t); + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("StatefulOutputRequired").Device(DEVICE_CPU), + StatefulOutputRequiredOp); +REGISTER_OP("StatefulOutputRequired") + .Output("results : num_outs * int64") + .Attr("num_outs : int = 5") + .SetIsStateful(); + +TEST(DirectSessionTest, TestStatefulOutputRequiredOp) { + GraphDef graph; + // Creates a graph with a StatefulOutputRequired op with 5 outputs. + protobuf::TextFormat::ParseFromString( + R"proto( + node { name: 'n' op: 'StatefulOutputRequired' device: '/device:CPU:0' } + versions { producer: 9 } + )proto", + &graph); + + std::unique_ptr<Session> session(NewSession(SessionOptions())); + ASSERT_TRUE(session != nullptr); + TF_ASSERT_OK(session->Create(std::move(graph))); + + // As a stateful op, a single StatefulOutputRequired kernel will be created + // and shared across multiple subgraphs. We create 5 different subgraphs, + // fetching different prefixes of the output of the op. 
+ for (int num_outputs_required = 1; num_outputs_required <= 5; + ++num_outputs_required) { + std::vector<string> fetch_tensor_names; + fetch_tensor_names.reserve(num_outputs_required); + for (int output_idx = 0; output_idx < num_outputs_required; ++output_idx) { + fetch_tensor_names.push_back(strings::StrCat("n:", output_idx)); + } + std::vector<Tensor> fetch_tensors; + TF_ASSERT_OK(session->Run({}, fetch_tensor_names, {}, &fetch_tensors)); + ASSERT_EQ(num_outputs_required, fetch_tensors.size()); + for (const Tensor& t : fetch_tensors) { + ASSERT_EQ(num_outputs_required, t.scalar<int64>()()); + } + } + + TF_ASSERT_OK(session->Close()); +} + } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/dynamic_device_mgr.cc b/tensorflow/core/common_runtime/dynamic_device_mgr.cc index f7e2e27e4ab..4bea08bb021 100644 --- a/tensorflow/core/common_runtime/dynamic_device_mgr.cc +++ b/tensorflow/core/common_runtime/dynamic_device_mgr.cc @@ -194,7 +194,8 @@ Device* DynamicDeviceMgr::HostCPU() const { } cpu_device_ = nullptr; for (const auto& pair : dynamic_devices_) { - if (pair.first->device_type() == DEVICE_CPU) { + if (pair.first->device_type() == DEVICE_CPU && + pair.first->parsed_name().id == 0) { cpu_device_ = pair.first; break; } diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index 76e34173459..9211a022b24 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -135,6 +135,7 @@ tf_cuda_library( "//conditions:default": [ "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/profiler/lib:traceme", ], }), ) diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index fb051a9a583..5a053c2b51a 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -49,7 +49,6 @@ limitations under the License. #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/monitoring/gauge.h" -#include "tensorflow/core/platform/monitoring.h" #include "tensorflow/core/util/env_var.h" namespace tensorflow { @@ -102,7 +101,6 @@ EagerContext::EagerContext( // provided). For builds using "tensorflow/core/platform/default", this is // currently a no-op. eager_context_created->GetCell()->Set(true); - monitoring::StartExporter(); InitPrioritizedDeviceTypeList(); runner_ = [this](std::function<void()> closure) { this->thread_pool_->Schedule(std::move(closure)); @@ -167,6 +165,16 @@ std::vector<string> DevicesToString(const PrioritizedDeviceVector& devices) { return v; } +std::vector<string> DeviceTypesToString( + const PrioritizedDeviceTypeVector& types) { + std::vector<string> v; + v.reserve(types.size()); + for (const auto& p : types) { + v.push_back(p.first.type_string()); + } + return v; +} + // Selects the "best" device that both exists and is supported. // // The `existing` argument specifies the available devices in the system, in @@ -232,13 +240,17 @@ Status EagerContext::SelectDevice(DeviceNameUtils::ParsedName preferred, return errors::InvalidArgument( "Could not satisfy device specification '", preferred, "'. enable_soft_placement=", AllowSoftPlacement(), - ". All available devices [", + ". Supported device types [", + absl::StrJoin(DeviceTypesToString(supported), ", "), + "]. 
All available devices [", absl::StrJoin(DevicesToString(existing), ", "), "]."); } return errors::InvalidArgument( "No supported device found in available devices [", absl::StrJoin(DevicesToString(existing), ", "), - "]. enable_soft_placement=", AllowSoftPlacement(), "."); + "]. enable_soft_placement=", AllowSoftPlacement(), + ". Supported devices types [", + absl::StrJoin(DeviceTypesToString(supported), ", "), "]."); } void EagerContext::ResetClusterFLR( diff --git a/tensorflow/core/common_runtime/eager/copy_to_device_node.h b/tensorflow/core/common_runtime/eager/copy_to_device_node.h index 4c4a1e13266..1a1459a9f1c 100644 --- a/tensorflow/core/common_runtime/eager/copy_to_device_node.h +++ b/tensorflow/core/common_runtime/eager/copy_to_device_node.h @@ -27,15 +27,25 @@ namespace tensorflow { class CopyToDeviceNode : public EagerNode { public: CopyToDeviceNode(TensorHandle* src, TensorHandle* dst, Device* dstd, - const EagerContext& ctx) - : EagerNode(), src_(src), dst_(dst), dstd_(dstd), ctx_(ctx) { - src_->Ref(); - dst_->Ref(); + const EagerContext& ctx, bool async, bool mirror) + : EagerNode(), + src_(src), + dst_(dst), + dstd_(dstd), + ctx_(ctx), + async_(async), + mirror_(mirror) { + if (async_) { + src_->Ref(); + dst_->Ref(); + } } ~CopyToDeviceNode() override { - src_->Unref(); - dst_->Unref(); + if (async_) { + src_->Unref(); + dst_->Unref(); + } } Status Run() override { @@ -43,16 +53,20 @@ class CopyToDeviceNode : public EagerNode { MEMDEBUG_CACHE_OP(MEMDEBUG_CACHE_VAL ? MEMDEBUG_CACHE_VAL : "eager::CopyToDeviceNode"); TF_RETURN_IF_ERROR(src_->CopyToDevice(ctx_, dstd_, &tensor)); - return dst_->SetTensor(std::move(tensor), ctx_.CanonicalDevice(dstd_)); + if (!async_ && mirror_) { + return dst_->AddLocalMirror(std::move(tensor), dstd_); + } else { + return dst_->SetTensor(std::move(tensor), dstd_); + } } - void Abort(Status status) override { dst_->Poison(status); } + void Abort(Status status) override { dst_->Poison(status, dstd_); } string DebugString() const override { string out = "[CopyToDeviceNode]"; strings::StrAppend(&out, " src_tensor: ", src_->DebugString()); strings::StrAppend(&out, ", dst_tensor: ", dst_->DebugString()); - strings::StrAppend(&out, ", dst_device: ", dstd_->name()); + strings::StrAppend(&out, ", dst_device: ", dstd_ ? dstd_->name() : "[]"); return out; } @@ -63,6 +77,8 @@ class CopyToDeviceNode : public EagerNode { TensorHandle* dst_; Device* dstd_; const EagerContext& ctx_; + bool async_; + bool mirror_; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/eager_executor.cc b/tensorflow/core/common_runtime/eager/eager_executor.cc index 2e50bd9de49..d49c9a5064b 100644 --- a/tensorflow/core/common_runtime/eager/eager_executor.cc +++ b/tensorflow/core/common_runtime/eager/eager_executor.cc @@ -19,8 +19,17 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/util/env_var.h" namespace tensorflow { +namespace { +bool IsAsyncWaitForRemoteFunctionEnabled() { + bool enabled = true; + TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_ASYNC_WAIT_FOR_REMOTE_FUNCTION", + true, &enabled)); + return enabled; +} +} // namespace EagerExecutor::EagerExecutor(bool async) : next_node_id_(0), @@ -28,7 +37,10 @@ EagerExecutor::EagerExecutor(bool async) thread_(async ? 
tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "eager_async_executor", std::bind(&EagerExecutor::Run, this)) - : nullptr) {} + : nullptr), + last_eager_client_(nullptr), + enable_async_wait_for_remote_function_( + IsAsyncWaitForRemoteFunctionEnabled()) {} EagerExecutor::~EagerExecutor() { tensorflow::mutex_lock l(node_queue_mutex_); @@ -194,6 +206,7 @@ void EagerExecutor::ClearError() { DCHECK(node_queue_.empty()); status_ = tensorflow::Status::OK(); ok_ = true; + last_eager_client_ = nullptr; nodes_pending_.notify_all(); } @@ -327,6 +340,33 @@ Status EagerExecutor::RunItem(core::RefCountPtr<NodeItem> item, bool from_queue) { DVLOG(3) << "Running Node: [id " << item->id << "] " << item->node->DebugString(); + AsyncRemoteExecuteNode* async_remote_node = + item->node->AsAsyncRemoteExecuteNode(); + if (enable_async_wait_for_remote_function_) { + if (async_remote_node != nullptr) { + if (last_eager_client_ != nullptr && + async_remote_node->eager_client() != nullptr && + last_eager_client_ != async_remote_node->eager_client()) { + // Running a remote function, need to sync if the function is going to + // different device than last time we run remote distributed function. + DVLOG(3) << "Executing Sync Executor for node" << item->id; + tensorflow::Status status = async_remote_node->SyncExecutors(); + if (!status.ok()) { + NodeDone(item, status, from_queue); + return status; + } + last_eager_client_ = nullptr; + } + if (async_remote_node->eager_client() != nullptr && + async_remote_node->needs_remote_inputs() && + async_remote_node->allow_multiple_pending_requests()) { + // We are running remote distributed function, update + // last_remote_device_name_. + last_eager_client_ = async_remote_node->eager_client(); + } + } + } + AsyncEagerNode* async_node = item->node->AsAsync(); if (async_node == nullptr) { tensorflow::Status status = item->node->Run(); diff --git a/tensorflow/core/common_runtime/eager/eager_executor.h b/tensorflow/core/common_runtime/eager/eager_executor.h index e414ab4b6c5..375e9a6e6a7 100644 --- a/tensorflow/core/common_runtime/eager/eager_executor.h +++ b/tensorflow/core/common_runtime/eager/eager_executor.h @@ -39,6 +39,10 @@ limitations under the License. namespace tensorflow { class AsyncEagerNode; +class AsyncRemoteExecuteNode; +namespace eager { +class EagerClient; +} // A unit of execution for the EagerExecutor class below. Example subclasses // encapsulate execution of a TFE_Op, or copying a TFE_TensorHandle from one @@ -65,6 +69,7 @@ class EagerNode { // Returns nullptr iff this Eager node is synchronous. virtual AsyncEagerNode* AsAsync() { return nullptr; } + virtual AsyncRemoteExecuteNode* AsAsyncRemoteExecuteNode() { return nullptr; } virtual string DebugString() const = 0; @@ -86,6 +91,16 @@ class AsyncEagerNode : public EagerNode { } }; +class AsyncRemoteExecuteNode : public AsyncEagerNode { + public: + AsyncRemoteExecuteNode* AsAsyncRemoteExecuteNode() final { return this; } + + virtual const eager::EagerClient* eager_client() const = 0; + virtual bool needs_remote_inputs() const = 0; + virtual bool allow_multiple_pending_requests() const = 0; + virtual Status SyncExecutors() = 0; +}; + // A class for handling async execution (see TFE_ContextSetAsync). // Note that this class is thread-safe. // TODO(agarwal): TFE_OpAddInput may currently block if it tries to access the @@ -228,6 +243,11 @@ class EagerExecutor { // Thread object that calls the `Run` method in async mode.This thread runs // until state_ is set to kShuttingDown. 
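Note: the RunItem() change above only drains the executor (SyncExecutors()) when a remote function is about to run against a different eager client than the previous one, and it only remembers a client for nodes that ship remote inputs and allow pipelined requests. Below is a minimal sketch of those two decisions, written as free functions; RemoteNodeView and the function names are illustrative stand-ins for AsyncRemoteExecuteNode and eager::EagerClient. The thread_ and last_eager_client_ member declarations continue just below.

    // Sketch only: the scheduling decisions added to EagerExecutor::RunItem,
    // restated over an illustrative view of a remote node.
    struct RemoteNodeView {
      const void* client = nullptr;                 // target eager client
      bool needs_remote_inputs = false;
      bool allow_multiple_pending_requests = false;
    };

    // Sync pending work only when switching to a different (non-null) client.
    inline bool ShouldSyncBeforeRun(const void* last_client,
                                    const RemoteNodeView& node) {
      return last_client != nullptr && node.client != nullptr &&
             last_client != node.client;
    }

    // Remember the client only for pipelined remote functions with remote inputs.
    inline const void* NextLastClient(const void* last_client,
                                      const RemoteNodeView& node) {
      if (node.client != nullptr && node.needs_remote_inputs &&
          node.allow_multiple_pending_requests) {
        return node.client;
      }
      return last_client;
    }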
It is `nullptr` in sync mode. const std::unique_ptr<Thread> thread_; + + // Last device where remote function with remote inputs was executed. + const eager::EagerClient* last_eager_client_; + + const bool enable_async_wait_for_remote_function_; }; inline bool EagerExecutor::Async() const { return thread_ != nullptr; } diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 3496a39714d..4f1a9cfc2cc 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -769,9 +769,11 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, id, i, remote_task, output_dtypes[i], op_device, &ctx, &retvals[i]); if (!status.ok()) { for (int j = 0; j < i; ++j) { - retvals[j]->Poison(errors::Internal( - "Failed to construct unshaped remote tensor handle at index ", i, - " for op ", op->Name())); + retvals[j]->PoisonRemote( + errors::Internal( + "Failed to construct unshaped remote tensor handle at index ", + i, " for op ", op->Name()), + op_device, ctx.GetContextViewId()); } return status; } @@ -796,7 +798,8 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, << " (is async?: " << executor.Async() << ")."; std::unique_ptr<EagerNode> node(new eager::RemoteExecuteNode( - std::move(request), op_device, ctx.GetContextViewId(), eager_client.get(), + &op->EagerContext(), std::move(request), op_device, + ctx.GetContextViewId(), eager_client.get(), op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(), op->Inputs(), {retvals, num_outputs})); Status s = executor.AddOrExecute(std::move(node)); @@ -1057,11 +1060,15 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, return Status::OK(); } - // TODO(gjn): Need to add support for async execution. Note if receiver - // is local, we need to first add support in TensorHandle to wait on local - // mirrors. - if (mirror && !executor->Async()) { - TF_RETURN_IF_ERROR(h->AddEmptyLocalMirror(d)); + bool async = executor->Async(); + if (mirror) { + // We don't bother adding an empty local mirror in sync mode since we'll be + // executing the operation directly and be calling AddLocalMirror. A + // reference count is still needed which will be removed if the operation + // fails. + if (async) { + TF_RETURN_IF_ERROR(h->AddEmptyLocalMirror(d)); + } h->Ref(); *result = h; } else { @@ -1069,10 +1076,18 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, true, d, dstd, h->resource_device(), h->dtype, ctx, result)); } - // Note that `h` may not be currently ready. However execution order will - // make sure that `h` is ready before the copy is actually done. - std::unique_ptr<EagerNode> node(new CopyToDeviceNode(h, *result, dstd, *ctx)); - Status s = executor->AddOrExecute(std::move(node)); + Status s; + if (async) { + // Note that `h` may not be currently ready. However execution order will + // make sure that `h` is ready before the copy is actually done. + std::unique_ptr<EagerNode> node( + new CopyToDeviceNode(h, *result, d, *ctx, async, mirror)); + s = executor->AddOrExecute(std::move(node)); + } else { + CopyToDeviceNode node(h, *result, d, *ctx, async, mirror); + s = executor->SyncExecute(&node); + } + // Since the operation failed, we need to Unref any outputs that were // allocated. if (!s.ok()) { @@ -1121,7 +1136,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, // TODO(gjn): Need to add support for async execution. 
Note if receiver // is local, we need to first add support in TensorHandle to wait on local // mirrors. - if (mirror && !executor->Async()) { + if (mirror) { TF_RETURN_IF_ERROR(h->AddEmptyLocalMirror(d)); h->Ref(); *result = h; @@ -1159,7 +1174,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, } } - auto node = absl::make_unique<eager::RemoteCopyNode>( + auto node = std::make_unique<eager::RemoteCopyNode>( ctx, executor, h, result[0], device, recv_op_id); Status s = executor->AddOrExecute(std::move(node)); if (!s.ok()) { diff --git a/tensorflow/core/common_runtime/eager/execute_node.h b/tensorflow/core/common_runtime/eager/execute_node.h index 7e5340575c9..ed1bd956179 100644 --- a/tensorflow/core/common_runtime/eager/execute_node.h +++ b/tensorflow/core/common_runtime/eager/execute_node.h @@ -189,8 +189,10 @@ class AsyncExecuteNode : public EagerNode { } void Abort(Status status) override { + int i = 0; for (auto handle : retvals_) { - handle->Poison(status); + handle->Poison(status, ctx_->CanonicalDevice(kernel_->OutputDevice(i))); + ++i; } } diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index 4dbf5d6313d..47a7125ced8 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -319,21 +319,23 @@ Status TensorHandle::WaitReady(const char* caller) const { if (!IsReady()) { profiler::TraceMe activity(absl::StrCat(caller, " WaitReady"), profiler::TraceMeLevel::kInfo); - DVLOG(3) << "Waiting on TensorHandle " << this; tf_shared_lock l(mu_); mu_.Await(Condition(&is_ready_)); - DVLOG(3) << "TensorHandle ready: " << this; } return is_poisoned_; } Status TensorHandle::Tensor(const tensorflow::Tensor** t) const { + DVLOG(3) << "Tensor on TensorHandle: " << this; + TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Tensor")); return tensor_handle_data_->Tensor(t); } Status TensorHandle::TensorFromDevice(const Device* d, const tensorflow::Tensor** t) const { + DVLOG(3) << "TensorFromDevice on TensorHandle: " << this << " device: " << d; + if (d == absl::get<Device*>(device_)) { if (is_remote_) { return errors::Internal("Invalid Tensor call on remote handle: ", this); @@ -344,19 +346,20 @@ Status TensorHandle::TensorFromDevice(const Device* d, } tf_shared_lock l(mu_); - auto mirror = local_mirrors_.find(d); - if (mirror != local_mirrors_.end()) { - return mirror->second->Tensor(t); + auto elem = local_mirrors_.find(d); + if (elem == local_mirrors_.end()) { + return errors::Internal("Invalid device: ", d, + " in Tensor call to handle: ", this); } - auto empty_mirror = empty_local_mirrors_.find(d); - if (empty_mirror != empty_local_mirrors_.end()) { - // TODO(gjn): Add support for waiting on local mirrors - return errors::Internal("Attempted to get Tensor for empty mirror"); + // Check if the handle is non-empty, else wait. 
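Note: the "check if the handle is non-empty, else wait" pattern introduced by the comment immediately above (and repeated in TensorValue below) relies on EmptyLocalTensorHandleData gaining IsReady/SetReady/WaitReady/Poison later in this change. Below is a minimal sketch of that readiness idiom, assuming only tensorflow::mutex and Condition; the class and member names are illustrative. The TensorFromDevice wait itself continues directly below.

    // Sketch only: a mutex-guarded readiness flag that writers set (or poison)
    // and readers block on, the same idiom EmptyLocalTensorHandleData adopts.
    #include "tensorflow/core/lib/core/status.h"
    #include "tensorflow/core/platform/mutex.h"

    class ReadyFlag {
     public:
      void SetReady() {
        tensorflow::mutex_lock l(mu_);
        ready_ = true;
      }
      void Poison(tensorflow::Status status) {
        tensorflow::mutex_lock l(mu_);
        poisoned_ = status;
        ready_ = true;
      }
      // Blocks until SetReady or Poison has been called; returns the poison
      // status (OK if the flag was set normally).
      tensorflow::Status WaitReady() {
        tensorflow::tf_shared_lock l(mu_);
        mu_.Await(tensorflow::Condition(&ready_));
        return poisoned_;
      }

     private:
      mutable tensorflow::mutex mu_;
      bool ready_ = false;
      tensorflow::Status poisoned_;
    };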
+ auto& mirror = elem->second; + if (mirror.second == nullptr) { + TF_RETURN_IF_ERROR( + mirror.first->WaitReady("TensorHandle::TensorFromDevice")); } - return errors::Internal("Invalid device: ", d, - " in Tensor call to handle: ", this); + return mirror.second->Tensor(t); } Status TensorHandle::TensorValue(tensorflow::TensorValue* t, const Device* d) { @@ -371,19 +374,19 @@ Status TensorHandle::TensorValue(tensorflow::TensorValue* t, const Device* d) { } tf_shared_lock l(mu_); - auto mirror = local_mirrors_.find(d); - if (mirror != local_mirrors_.end()) { - return mirror->second->TensorValue(t); + auto elem = local_mirrors_.find(d); + if (elem == local_mirrors_.end()) { + return errors::Internal("Invalid device: ", d, + " in TensorValue call to handle: ", this); } - auto empty_mirror = empty_local_mirrors_.find(d); - if (empty_mirror != empty_local_mirrors_.end()) { - // TODO(gjn): Add support for waiting on local mirrors - return errors::Internal("Attempted to get TensorValue for empty mirror"); + // Check if the handle is non-empty, else wait. + auto& mirror = elem->second; + if (mirror.second == nullptr) { + TF_RETURN_IF_ERROR(mirror.first->WaitReady("TensorHandle::TensorValue")); } - return errors::Internal("Invalid device: ", d, - " in TensorValue call to handle: ", this); + return mirror.second->TensorValue(t); } TensorHandle::VariantDevice TensorHandle::DeviceOrHostCPU( @@ -506,53 +509,50 @@ Status TensorHandle::NumElements(int64* num_elements) const { } Status TensorHandle::Unprotect(const Device* d) { + DVLOG(3) << "Unprotect on TensorHandle: " << this << " device: " << d; + if (d == absl::get<Device*>(device_)) { return tensor_handle_data_->Unprotect(); } tf_shared_lock l(mu_); - auto mirror = local_mirrors_.find(d); - if (mirror != local_mirrors_.end()) { - return mirror->second->Unprotect(); + auto elem = local_mirrors_.find(d); + if (elem == local_mirrors_.end()) { + return errors::Internal("Invalid device: ", d, + " in Unprotect call to handle: ", this); } - auto empty_mirror = empty_local_mirrors_.find(d); - if (empty_mirror != empty_local_mirrors_.end()) { + // Check if the handle is non-empty + auto& mirror = elem->second; + if (mirror.second == nullptr) { return errors::Internal("Attempted to unprotect an empty mirror"); } - return errors::Internal("Invalid device: ", d, - " in Unprotect call to handle: ", this); + return mirror.second->Unprotect(); } bool TensorHandle::HasLocalMirror(const Device* d) const { - mutex_lock l(mu_); - auto mirror = local_mirrors_.find(d); - if (mirror != local_mirrors_.end()) { - return true; - } + DVLOG(3) << "HasLocalMirror on TensorHandle: " << this << " device: " << d; - auto empty_mirror = empty_local_mirrors_.find(d); - if (empty_mirror != empty_local_mirrors_.end()) { - return true; - } - - return false; + tf_shared_lock l(mu_); + return local_mirrors_.find(d) != local_mirrors_.end(); } Status TensorHandle::AddEmptyLocalMirror(const Device* d) { DVLOG(3) << "AddEmptyLocalMirror on TensorHandle: " << this << " device: " << d; + if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) { + return errors::Internal("Cannot add mirror for primary device."); + } + mutex_lock l(mu_); if (local_mirrors_.find(d) != local_mirrors_.end()) { return errors::Internal("Attempted to duplicate a local mirror."); } - auto ret = empty_local_mirrors_.insert(d); - if (!ret.second) { - return errors::Internal("Attempted to duplicate an empty local mirror."); - } + local_mirrors_[d] = + 
std::make_pair(std::make_unique<EmptyLocalTensorHandleData>(), nullptr); return Status::OK(); } @@ -560,6 +560,9 @@ Status TensorHandle::AddEmptyLocalMirror(const Device* d) { #if !defined(IS_MOBILE_PLATFORM) Status TensorHandle::RemoteAddress(const Device* d, int64* op_id, int32* output_num) const { + DVLOG(3) << "RemoteAddress on TensorHandle: " << this << " device: " << d + << " " << d->name(); + if (VariantDeviceIsCustom(device_) || d != absl::get<Device*>(device_)) { tf_shared_lock l(mu_); auto mirror = remote_mirrors_.find(d->name()); @@ -593,7 +596,8 @@ Status TensorHandle::RemoteAddress(const Device* d, int64* op_id, bool TensorHandle::HasRemoteMirror(const Device* d, uint64 context_view_id) const { - DVLOG(3) << "HasRemoteMirror on TensorHandle: " << this; + DVLOG(3) << "HasRemoteMirror on TensorHandle: " << this << " device: " << d + << " " << d->name(); tf_shared_lock l(mu_); auto mirror = remote_mirrors_.find(d->name()); @@ -619,7 +623,8 @@ bool TensorHandle::HasRemoteMirror(const Device* d, bool TensorHandle::HasResourceShapeMirror(const Device* d, uint64 context_view_id) const { - DVLOG(3) << "HasResourceShapeMirror on TensorHandle: " << this; + DVLOG(3) << "HasResourceShapeMirror on TensorHandle: " << this + << " device: " << d << " " << d->name(); tf_shared_lock l(mu_); auto mirror = resource_shape_mirrors_.find(d->name()); @@ -635,7 +640,8 @@ bool TensorHandle::HasResourceShapeMirror(const Device* d, Status TensorHandle::AddUnshapedRemoteMirror( std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d) { - DVLOG(3) << "AddUnshapedRemoteMirror on TensorHandle: " << this; + DVLOG(3) << "AddUnshapedRemoteMirror on TensorHandle: " << this + << " device: " << d << " " << d->name(); mutex_lock l(mu_); auto remote_mirror = remote_mirrors_.find(d->name()); @@ -704,7 +710,8 @@ Status TensorHandle::AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t, Status TensorHandle::SetRemoteShape(const TensorShape& shape, const Device* d, uint64 context_view_id) { - DVLOG(3) << "SetRemoteShape on TensorHandle: " << this << " device: " << d; + DVLOG(3) << "SetRemoteShape on TensorHandle: " << this << " device: " << d + << " " << d->name(); if (VariantDeviceIsCustom(device_) || d != absl::get<Device*>(device_)) { mutex_lock l(mu_); @@ -758,16 +765,57 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape, const Device* d, return Status::OK(); } + +void TensorHandle::PoisonRemote(Status status, const Device* d, + uint64 context_view_id) { + DVLOG(3) << "PoisonRemote on TensorHandle: " << this << " device: " << d + << " " << d->name(); + + if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) { + DCHECK(!is_async_ || !IsReady()) + << "PoisonRemote can only be called on non-ready handle: " << this; + + is_poisoned_ = status; + mutex_lock l(mu_); + is_ready_ = true; + } else { + tf_shared_lock l(mu_); + auto mirror = unshaped_remote_mirrors_.find(d->name()); + if (mirror != unshaped_remote_mirrors_.end()) { + if (mirror->second->context_view_id() == context_view_id) { + mirror->second->Poison(status); + } + } + } +} #endif +Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor, + const Device* d) { + if (d == absl::get<Device*>(device_)) { + return errors::Internal( + "Local mirror assign conflicts with primary device."); + } + + mutex_lock l(mu_); + auto elem = local_mirrors_.insert(std::make_pair( + d, std::make_pair(nullptr, + std::make_unique<LocalTensorHandleData>(tensor)))); + if (!elem.second) { + return errors::Internal("Attempted to set 
tensor for existing mirror."); + } + + return Status::OK(); +} + Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor, const Device* d) { + DVLOG(3) << "SetTensor on TensorHandle: " << this << " device: " << d; + if (d == absl::get<Device*>(device_)) { DCHECK(!is_remote_) << "SetTensor is not called on remote handles."; DCHECK(!is_async_ || !IsReady()) << "SetTensor is only called on non-ready handles."; - DVLOG(3) << "SetTensor on TensorHandle: " << this; - if (tensor.dtype() == DT_RESOURCE && tensor.NumElements() > 0) { auto& resource_handle = tensor.flat<class ResourceHandle>()(0); handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes(); @@ -779,37 +827,53 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor, const Device* d) { is_ready_ = true; } } else { - mutex_lock l(mu_); - if (local_mirrors_.find(d) != local_mirrors_.end()) { - return errors::Internal("Attempted to set tensor for existing mirror."); - } - - auto elem = empty_local_mirrors_.find(d); - if (elem == empty_local_mirrors_.end()) { + tf_shared_lock l(mu_); + auto elem = local_mirrors_.find(d); + if (elem == local_mirrors_.end()) { return errors::Internal( "Attempted to set tensor for non-existent local mirror."); } - local_mirrors_[d] = absl::make_unique<LocalTensorHandleData>(tensor); - empty_local_mirrors_.erase(elem); + + auto& mirror = elem->second; + if (mirror.second != nullptr) { + return errors::Internal("Attempted to set tensor for existing mirror."); + } + + mirror.second = absl::make_unique<LocalTensorHandleData>(tensor); + mirror.first->SetReady(); } return Status::OK(); } -void TensorHandle::Poison(Status status) { - DCHECK(!is_async_ || !IsReady()) - << "Poison(status) can only be called on non-ready handle: " << this; +void TensorHandle::Poison(Status status, const Device* d) { + DVLOG(3) << "Poison on TensorHandle: " << this << " device: " << d; - DVLOG(3) << "Poison on TensorHandle: " << this; + if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) { + DCHECK(!is_async_ || !IsReady()) + << "Poison can only be called on non-ready handle: " << this; - is_poisoned_ = status; - mutex_lock l(mu_); - is_ready_ = true; + is_poisoned_ = status; + mutex_lock l(mu_); + is_ready_ = true; + } else { + tf_shared_lock l(mu_); + auto elem = local_mirrors_.find(d); + DCHECK(elem != local_mirrors_.end()) + << "Attempted to poison non-existent local mirror, handle: " << this + << " device: " << d; + + auto& mirror = elem->second; + DCHECK(mirror.second == nullptr) << "Attempted to poison existing mirror."; + + mirror.first->Poison(status); + } } Status TensorHandle::CopyToDevice(const EagerContext& ctx, - tensorflow::Device* dstd, + tensorflow::Device* d, tensorflow::Tensor* output) { + tensorflow::Device* dstd = (d == nullptr) ? ctx.HostCPU() : d; tensorflow::Device* srcd = absl::get<Device*>(DeviceOrHostCPU(ctx)); const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr; const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr; diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index 817347adc92..1e5b741cbb7 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -153,6 +153,9 @@ class TensorHandle : public core::RefCounted { // Add an empty mirror placeholder for the specified device. The expectation // is this will be populated by a call to SetTensor. 
Status AddEmptyLocalMirror(const Device* d); + // Add a local mirror. This will fail if an empty local mirror was previously + // added. For that case, SetTensor should be used instead. + Status AddLocalMirror(tensorflow::Tensor&& tensor, const Device* d); #if !defined(IS_MOBILE_PLATFORM) bool HasRemoteMirror(const Device* d, uint64 context_view_id) const; @@ -176,6 +179,13 @@ class TensorHandle : public core::RefCounted { // were created without a known shape. Status SetRemoteShape(const TensorShape& shape, const Device* d, uint64 context_view_id); + + // Poisons either this handle or a remote mirror with error `status`. + // Poisoning means that the handle will become ready and methods trying + // to access the remote shape will return this error `status`. + // Exactly one of SetRemoteShape or PoisonRemote methods must be called on a + // unshaped handle on a remote device. + void PoisonRemote(Status status, const Device* d, uint64 context_view_id); #endif // Sets the `tensor` for this async non-ready handle making it ready. @@ -183,14 +193,14 @@ class TensorHandle : public core::RefCounted { // handles to make them ready. Status SetTensor(tensorflow::Tensor&& tensor, const Device* d); - // Poisons this non-ready handle with an error `status`. + // Poisons either this handle or a local mirror with error `status`. // Poisoning means that the handle will become ready and methods trying // to access the actual tensor or shape will return this error `status`. - // Exactly one of SetTensor, SetRemoteShape, or Poison methods must be called - // on a non-ready tensor. - void Poison(Status status); + // Exactly one of SetTensor or Poison methods must be called on a non-ready + // tensor for a specific device. + void Poison(Status status, const Device* d); - Status CopyToDevice(const EagerContext& ctx, tensorflow::Device* dstd, + Status CopyToDevice(const EagerContext& ctx, tensorflow::Device* d, tensorflow::Tensor* output); Status InferenceShape( @@ -265,9 +275,14 @@ class TensorHandle : public core::RefCounted { mutable mutex mu_; - std::map<const tensorflow::Device*, std::unique_ptr<LocalTensorHandleData>> + // Map of local mirrors. In sync mode the EmptyLocalTensorHandleData is + // nullptr. In async mode, we use the EmptyLocalTensorHandleData to manage + // waiting clients. Once the EmptyLocalTensorHandleData is "ready" only the + // LocalTensorHandleData should be used. + std::map<const tensorflow::Device*, + std::pair<std::unique_ptr<EmptyLocalTensorHandleData>, + std::unique_ptr<LocalTensorHandleData>>> local_mirrors_ GUARDED_BY(mu_); - std::set<const tensorflow::Device*> empty_local_mirrors_ GUARDED_BY(mu_); #if !defined(IS_MOBILE_PLATFORM) // TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica // variable is ready, since we could get the shape locally without remote copy diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_data.cc b/tensorflow/core/common_runtime/eager/tensor_handle_data.cc index b6d17e1ee1a..690c14e2ffd 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle_data.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle_data.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/profiler/lib/traceme.h" namespace tensorflow { @@ -104,6 +105,32 @@ Status EmptyLocalTensorHandleData::Unprotect() { return errors::Unavailable("Unable to unprotect an empty handle."); } +bool EmptyLocalTensorHandleData::IsReady() const { + tf_shared_lock l(mu_); + return is_ready_; +} + +void EmptyLocalTensorHandleData::SetReady() { + mutex_lock l(mu_); + is_ready_ = true; +} + +Status EmptyLocalTensorHandleData::WaitReady(const char* caller) const { + if (!IsReady()) { + profiler::TraceMe activity(absl::StrCat(caller, " WaitReady"), + profiler::TraceMeLevel::kInfo); + tf_shared_lock l(mu_); + mu_.Await(Condition(&is_ready_)); + } + return is_poisoned_; +} + +void EmptyLocalTensorHandleData::Poison(Status status) { + is_poisoned_ = status; + mutex_lock l(mu_); + is_ready_ = true; +} + string EmptyLocalTensorHandleData::DebugString() const { return "EmptyLocalTensorHandleData"; } diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_data.h b/tensorflow/core/common_runtime/eager/tensor_handle_data.h index 5e600cc8818..3a791d94315 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle_data.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle_data.h @@ -57,7 +57,9 @@ class LocalTensorHandleData : public TensorHandleData { Status NumElements(int64* num_elements) const override; Status Unprotect() override; - string DebugString() const override { return tensor_.DebugString(); } + string DebugString() const override { + return tensor_.DeviceSafeDebugString(); + } private: tensorflow::Tensor tensor_; @@ -87,7 +89,18 @@ class EmptyLocalTensorHandleData : public TensorHandleData { Status NumElements(int64* num_elements) const override; Status Unprotect() override; + bool IsReady() const; + void SetReady(); + Status WaitReady(const char* caller) const; + void Poison(Status status); + Status IsPoisoned() const { return is_poisoned_; } + string DebugString() const override; + + private: + mutable mutex mu_; + bool is_ready_ GUARDED_BY(mu_); + Status is_poisoned_; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 0be1d5df616..4a8d38c8e53 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -197,6 +197,10 @@ struct NodeItem { // Number of output edges. size_t num_output_edges; + // If non-null, contains an array of num_outputs bools, where the ith bool + // is true if and only if the ith output is consumed by another node. + std::unique_ptr<bool[]> outputs_required; + const EdgeInfo* output_edge_list() const { return output_edge_base(); } // ith output edge. @@ -708,16 +712,24 @@ Status ExecutorImpl::Initialize(const Graph& graph) { } // Record information about whether each output of the op is used. 
- std::vector<bool> used_outputs(n->num_outputs(), false); + std::unique_ptr<bool[]> outputs_required(new bool[n->num_outputs()]); + std::fill(&outputs_required[0], &outputs_required[n->num_outputs()], false); + size_t unused_outputs = n->num_outputs(); for (const Edge* e : n->out_edges()) { if (e->src_output() >= 0) { - used_outputs[e->src_output()] = true; + if (!outputs_required[e->src_output()]) { + --unused_outputs; + outputs_required[e->src_output()] = true; + } } } - for (bool used_output : used_outputs) { - if (!used_output) { - metrics::RecordUnusedOutput(n->type_string()); + if (unused_outputs > 0) { + for (int i = 0; i < n->num_outputs(); ++i) { + if (!outputs_required[i]) { + metrics::RecordUnusedOutput(n->type_string()); + } } + item->outputs_required = std::move(outputs_required); } } @@ -1823,6 +1835,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { params.is_input_dead = is_input_dead; params.output_attr_array = item.output_attrs(); params.forward_from_array = item.forward_from(); + params.outputs_required_array = item.outputs_required.get(); if (item.kernel_is_async) { // Asynchronous computes. @@ -2093,9 +2106,10 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, for (int i = 0; i < item.num_outputs; ++i) { const TensorValue val = ctx->release_output(i); if (val.tensor == nullptr) { - // Unless it's a Switch or a Recv, the node must produce a - // tensor value at i-th output. - if (!item.is_recv_or_switch) { + // Unless it's a Switch or a Recv, or the executor has marked the output + // as not required, the node must produce a tensor value at i-th output. + if (!(item.is_recv_or_switch || + (item.outputs_required && !item.outputs_required[i]))) { s.Update(errors::Internal("Missing ", i, "-th output from ", FormatNodeDefForError(item.kernel->def()))); } diff --git a/tensorflow/core/common_runtime/session.cc b/tensorflow/core/common_runtime/session.cc index 575fafdbcde..3817736020e 100644 --- a/tensorflow/core/common_runtime/session.cc +++ b/tensorflow/core/common_runtime/session.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/monitoring/gauge.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/monitoring.h" namespace tensorflow { namespace { @@ -71,7 +70,6 @@ Session* NewSession(const SessionOptions& options) { // provided). For builds using "tensorflow/core/platform/default", this is // currently a no-op. session_created->GetCell()->Set(true); - monitoring::StartExporter(); Session* out_session; s = NewSession(options, &out_session); if (!s.ok()) { @@ -93,7 +91,6 @@ Status NewSession(const SessionOptions& options, Session** out_session) { // provided). For builds using "tensorflow/core/platform/default", this is // currently a no-op. 
session_created->GetCell()->Set(true); - monitoring::StartExporter(); s = factory->NewSession(options, out_session); if (!s.ok()) { *out_session = nullptr; diff --git a/tensorflow/core/distributed_runtime/eager/eager_client.h b/tensorflow/core/distributed_runtime/eager/eager_client.h index 3b083f3cae6..5f260e477d6 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_client.h +++ b/tensorflow/core/distributed_runtime/eager/eager_client.h @@ -57,6 +57,8 @@ class EagerClient : public core::RefCounted { virtual void StreamingEnqueueAsync(const EnqueueRequest* request, EnqueueResponse* response, StatusCallback done) = 0; + + virtual bool allow_multiple_pending_requests() const = 0; }; // Simple wrapper class that can be used to retrieve EagerClients. diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index 686f471ca5e..41db645507b 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -96,6 +96,8 @@ class FakeEagerClient : public EagerClient { done(impl_->Enqueue(request, response)); } + bool allow_multiple_pending_requests() const override { return false; } + private: TestEagerServiceImpl* impl_; }; diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc index f5fc68a8e38..7bb001ef853 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc @@ -200,11 +200,12 @@ void RemoteCopyNode::RunRemoteRecv(EagerOperation* op, StatusCallback done) { auto* remote_op = request.add_queue()->mutable_operation(); PrepareRemoteOp(remote_op, op); remote_op->set_id(recv_op_id_); + uint64 context_view_id = ctx_->GetContextViewId(); core::RefCountPtr<eager::EagerClient> eager_client; Status status = ctx_->GetClient(recv_device_, &eager_client); if (!status.ok()) { - captured_state_->dst()->Poison(status); + captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id); done(status); return; } @@ -216,7 +217,7 @@ void RemoteCopyNode::RunRemoteRecv(EagerOperation* op, StatusCallback done) { // Blocks until send has completed. 
Status send_status = captured_state_->GetSendStatus(); if (!send_status.ok()) { - captured_state_->dst()->Poison(send_status); + captured_state_->dst()->PoisonRemote(send_status, recv_device_, context_view_id); done(send_status); return; } @@ -224,7 +225,6 @@ void RemoteCopyNode::RunRemoteRecv(EagerOperation* op, StatusCallback done) { EnqueueResponse* response = new EnqueueResponse; const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_; Device* recv_device = recv_device_; - uint64 context_view_id = ctx_->GetContextViewId(); eager_client->StreamingEnqueueAsync( &request, response, [captured_state, response, recv_device, context_view_id, @@ -241,7 +241,7 @@ void RemoteCopyNode::RunRemoteRecv(EagerOperation* op, StatusCallback done) { "Please file an issue with the TensorFlow Team."; } } else { - captured_state->dst()->Poison(s); + captured_state->dst()->PoisonRemote(s, recv_device, context_view_id); } done(s); delete response; @@ -254,8 +254,9 @@ void RemoteCopyNode::StartRecv(StatusCallback done) { EagerOperation op(ctx_); Status status = op.Reset("_Recv", /*raw_device_name=*/nullptr, /*remote=*/false, /*executor=*/nullptr); + Device* recv_device = ctx_->CanonicalDevice(recv_device_); if (!status.ok()) { - captured_state_->dst()->Poison(status); + captured_state_->dst()->Poison(status, recv_device); done(status); return; } @@ -276,12 +277,12 @@ void RemoteCopyNode::StartRecv(StatusCallback done) { std::vector<Tensor> outputs(1); status = RunLocalRecv(&op, &outputs); if (!status.ok()) { - captured_state_->dst()->Poison(status); + captured_state_->dst()->Poison(status, recv_device); done(status); return; } - status = captured_state_->dst()->SetTensor( - std::move(outputs[0]), ctx_->CanonicalDevice(recv_device_)); + status = + captured_state_->dst()->SetTensor(std::move(outputs[0]), recv_device); done(status); } else { // Handles captured_state_->dst_ internally. @@ -297,6 +298,7 @@ void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) { auto* send_tensor = request.add_queue()->mutable_send_tensor(); send_tensor->set_op_id(recv_op_id_); send_tensor->set_device_name(recv_device_->name()); + uint64 context_view_id = ctx_->GetContextViewId(); // AsProtoTensorContent doesn't work when the tensor is on the GPU, hence // copy it to the CPU before copying it out.
@@ -314,7 +316,7 @@ void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) { core::RefCountPtr<eager::EagerClient> eager_client; s = ctx_->GetClient(recv_device_, &eager_client); if (!s.ok()) { - captured_state_->dst()->Poison(s); + captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id); done(s); return; } @@ -322,7 +324,6 @@ void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) { const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_; captured_state->SetSrcShape(tensor.shape()); Device* recv_device = recv_device_; - uint64 context_view_id = ctx_->GetContextViewId(); eager_client->StreamingEnqueueAsync( &request, response, [captured_state, response, recv_device, context_view_id, @@ -336,7 +337,7 @@ void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) { << status.ToString(); } } else { - captured_state->dst()->Poison(s); + captured_state->dst()->PoisonRemote(s, recv_device, context_view_id); } done(s); delete response; @@ -378,7 +379,8 @@ void RemoteCopyNode::RunAsync(StatusCallback done) { void RemoteCopyNode::Abort(Status status) { if (!started_) { - captured_state_->dst()->Poison(status); + uint64 context_view_id = ctx_->GetContextViewId(); + captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id); } } diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc b/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc index f84e0ebb5ee..81547615706 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc @@ -81,7 +81,7 @@ void RemoteExecuteNode::RunAsync(StatusCallback done) { "Please file an issue with the TensorFlow Team."; } } else { - retvals[i]->Poison(status); + retvals[i]->PoisonRemote(status, device, context_view_id); } retvals[i]->Unref(); } diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h index de5d66a1efb..ed9f9c0ee0f 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h +++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h @@ -34,14 +34,16 @@ namespace eager { // RemoteExecuteNode is an implementation of EagerNode which enqueues // an operation via RPC in a remote EagerService. -class RemoteExecuteNode : public AsyncEagerNode { +class RemoteExecuteNode : public AsyncRemoteExecuteNode { public: - RemoteExecuteNode(std::unique_ptr<EnqueueRequest> request, Device* device, + RemoteExecuteNode(EagerContext* eager_context, + std::unique_ptr<EnqueueRequest> request, Device* device, uint64 context_view_id, EagerClient* eager_client, const NodeDef& ndef, FunctionLibraryDefinition* lib_def, const gtl::InlinedVector<TensorHandle*, 4>& inputs, absl::Span<TensorHandle*> retvals) - : AsyncEagerNode(), + : AsyncRemoteExecuteNode(), + eager_context_(eager_context), request_(std::move(request)), device_(device), context_view_id_(context_view_id), @@ -62,6 +64,16 @@ class RemoteExecuteNode : public AsyncEagerNode { handle->Ref(); } eager_client_->Ref(); + + needs_remote_inputs_ = false; + for (const TensorHandle* input : inputs_) { + // TODO(bramandia): Should this be op_device() instead? 
+ if (input->resource_device() != nullptr && + input->resource_device() != device_) { + needs_remote_inputs_ = true; + break; + } + } } ~RemoteExecuteNode() override { @@ -81,12 +93,24 @@ class RemoteExecuteNode : public AsyncEagerNode { void RunAsync(StatusCallback done) override; + Status SyncExecutors() override { return eager_context_->SyncExecutors(); } + void Abort(Status status) override { + int i = 0; for (auto handle : retvals_) { - handle->Poison(status); + handle->PoisonRemote(status, device_, context_view_id_); + ++i; } } + const EagerClient* eager_client() const override { return eager_client_; } + + bool needs_remote_inputs() const override { return needs_remote_inputs_; } + + bool allow_multiple_pending_requests() const override { + return eager_client_->allow_multiple_pending_requests(); + } + string DebugString() const override { string out = "[RemoteExecuteNode]"; strings::StrAppend(&out, " request: ", request_->DebugString()); @@ -95,9 +119,11 @@ class RemoteExecuteNode : public AsyncEagerNode { } private: + EagerContext* eager_context_; // Not owned, and must outlive this node. std::unique_ptr<EnqueueRequest> request_; Device* device_; // Not owned uint64 context_view_id_; + bool needs_remote_inputs_; EagerClient* eager_client_; // Not owned, and must outlive this node. const NodeDef ndef_; const FunctionLibraryDefinition* lib_def_; diff --git a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h index 9f7db52b447..34e3ec5f83d 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h +++ b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h @@ -75,6 +75,9 @@ class UnshapedRemoteTensorHandleData : public TensorHandleData { Status NumElements(int64* num_elements) const override; Status Unprotect() override; + void Poison(Status status) { is_poisoned_ = status; } + Status IsPoisoned() const { return is_poisoned_; } + string DebugString() const override; int64 op_id() const { return op_id_; } @@ -94,6 +97,7 @@ class UnshapedRemoteTensorHandleData : public TensorHandleData { void ReleaseRemoteTensorHandle() { delete_remote_tensor_ = false; } private: + Status is_poisoned_; // IDs required when this class is representing a remote tensor handle. const int64 op_id_; const int32 output_num_; diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc index 5ad48118ae9..c1811303bc9 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc @@ -115,6 +115,10 @@ class GrpcEagerClient : public EagerClient { } ~GrpcEagerClient() override { thread_->Unref(); } + bool allow_multiple_pending_requests() const override { + return EnableStreaming(); + } + #define CLIENT_METHOD(method) \ void method##Async(const method##Request* request, \ method##Response* response, StatusCallback done) \ diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc index 69559f27c29..fac28514115 100644 --- a/tensorflow/core/distributed_runtime/worker_session.cc +++ b/tensorflow/core/distributed_runtime/worker_session.cc @@ -14,8 +14,8 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/distributed_runtime/worker_session.h" +#include "tensorflow/core/lib/monitoring/collection_registry.h" #include "tensorflow/core/lib/monitoring/gauge.h" -#include "tensorflow/core/platform/monitoring.h" namespace tensorflow { @@ -123,7 +123,6 @@ WorkerSession::WorkerSession( // provided). For builds using "tensorflow/core/platform/default", this is // currently a no-op. worker_session_created->GetCell()->Set(true); - monitoring::StartExporter(); } Status WorkerSession::UpdateWorkerCacheAndDevices( @@ -170,7 +169,6 @@ WorkerSession::WorkerSession( // provided). For builds using "tensorflow/core/platform/default", this is // currently a no-op. worker_session_created->GetCell()->Set(true); - monitoring::StartExporter(); } WorkerSession::~WorkerSession() { diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 9e22321b42c..876c7551263 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -156,6 +156,18 @@ class OpKernel { // Returns a pointer to the tensor stored inside constant ops. virtual const Tensor* const_tensor() const { return nullptr; } + // Returns true if this kernel must produce its ith output. + // REQUIRES: 0 <= i < num_outputs(). + bool output_required(int i) const { return outputs_required_[i]; } + + // Hints whether or not the ith output must be produced when running the + // kernel. By default, all outputs are required. The kernel implementation + // may ignore the hint. + // REQUIRES: 0 <= i < num_outputs(). + void set_output_required(int i, bool is_required) { + outputs_required_[i] = is_required; + } + // Updates the dynamic cost estimate, which is used to determine whether this // op is expensive. The new cost estimate is a weighted average of the old // cost estimate and the latest cost. @@ -223,6 +235,7 @@ class OpKernel { const bool is_deferred_; bool expensive_; std::atomic_uint_fast64_t cost_estimate_; + std::vector<bool> outputs_required_; TF_DISALLOW_COPY_AND_ASSIGN(OpKernel); }; @@ -732,6 +745,10 @@ class OpKernelContext { // For tracking actively running deferred ops. std::function<void()> inc_num_deferred_ops_function; std::function<void()> dec_num_deferred_ops_function; + + // For implementing `OpKernelContext::output_required()`. If null, all + // outputs are required. + bool* outputs_required_array = nullptr; }; // params must outlive the OpKernelContext. @@ -941,10 +958,9 @@ class OpKernelContext { // should call allocate_output(index, ...), set_output(index, ...), // set_output_ref(index, ...), or set the status to a non-ok value. // If it returns false, it may output, but is not required to do so. - // TODO(mrry): Convert this to return Status, and implement a string - // name version. bool output_required(int index) const { - return true; // TODO(josh11b): implement + return !params_->outputs_required_array || + params_->outputs_required_array[index]; } // Allocation of tensors during kernel execution inside the Compute diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index cdb841e06d7..79e0cd2bcf1 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -17,6 +17,8 @@ limitations under the License. // all over the place, we should log an error and execute the original graph.
#ifdef INTEL_MKL +#include "tensorflow/core/graph/mkl_layout_pass.h" + #include <algorithm> #include <functional> #include <memory> @@ -34,6 +36,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/mkl_graph_util.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -43,9 +46,6 @@ limitations under the License. #include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/util/util.h" -#include "tensorflow/core/graph/mkl_graph_util.h" -#include "tensorflow/core/graph/mkl_layout_pass.h" - namespace tensorflow { // This pass implements rewriting of graph to support following scenarios: diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc index 9d7f37d2be6..19ed267b441 100644 --- a/tensorflow/core/kernels/concat_op.cc +++ b/tensorflow/core/kernels/concat_op.cc @@ -48,53 +48,60 @@ class ConcatBaseOp : public OpKernel { typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>> ConstMatrixVector; - explicit ConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {} + explicit ConcatBaseOp(OpKernelConstruction* c) + : OpKernel(c), + axis_attribute_name_(AxisArgName == NAME_IS_AXIS + ? "axis" + : AxisArgName == NAME_IS_CONCAT_DIM + ? "concat_dim" + : "<invalid>") { + int unused; + OP_REQUIRES_OK( + c, InputRange(axis_attribute_name_, &axis_input_index_, &unused)); + OP_REQUIRES_OK(c, InputRange("values", &values_input_start_index_, + &values_input_end_index_)); + } void Compute(OpKernelContext* c) override { - const Tensor* concat_dim_tensor; - const char* axis_attribute_name = - AxisArgName == NAME_IS_AXIS - ? "axis" - : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>"; - OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor)); + const Tensor& concat_dim_tensor = c->input(axis_input_index_); + // TODO(rmlarsen): Disallow legacy use of length-1 vectors as scalars. 
OP_REQUIRES(c, - (TensorShapeUtils::IsScalar(concat_dim_tensor->shape()) || - (TensorShapeUtils::IsVector(concat_dim_tensor->shape()) && - concat_dim_tensor->shape().dim_size(0) == 1)), + (TensorShapeUtils::IsScalar(concat_dim_tensor.shape()) || + (TensorShapeUtils::IsVector(concat_dim_tensor.shape()) && + concat_dim_tensor.shape().dim_size(0) == 1)), errors::InvalidArgument( - axis_attribute_name, + axis_attribute_name_, " tensor should be a scalar integer, but got shape ", - concat_dim_tensor->shape().DebugString())); + concat_dim_tensor.shape().DebugString())); int64 concat_dim; // In case of ConcatV2, "axis" could be int32 or int64 if (AxisArgName == NAME_IS_AXIS) { OP_REQUIRES( c, - (concat_dim_tensor->dtype() == DT_INT32 || - concat_dim_tensor->dtype() == DT_INT64), - errors::InvalidArgument(axis_attribute_name, + (concat_dim_tensor.dtype() == DT_INT32 || + concat_dim_tensor.dtype() == DT_INT64), + errors::InvalidArgument(axis_attribute_name_, " tensor should be int32 or int64, but got ", - DataTypeString(concat_dim_tensor->dtype()))); + DataTypeString(concat_dim_tensor.dtype()))); } else { - OP_REQUIRES(c, (concat_dim_tensor->dtype() == DT_INT32), + OP_REQUIRES(c, (concat_dim_tensor.dtype() == DT_INT32), errors::InvalidArgument( - axis_attribute_name, " tensor should be int32, but got ", - DataTypeString(concat_dim_tensor->dtype()))); + axis_attribute_name_, " tensor should be int32, but got ", + DataTypeString(concat_dim_tensor.dtype()))); } - if (concat_dim_tensor->dtype() == DT_INT32) { + if (concat_dim_tensor.dtype() == DT_INT32) { concat_dim = - internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()()); + internal::SubtleMustCopy(concat_dim_tensor.scalar<int32>()()); } else { concat_dim = - internal::SubtleMustCopy(concat_dim_tensor->scalar<int64>()()); + internal::SubtleMustCopy(concat_dim_tensor.scalar<int64>()()); } - OpInputList values; - OP_REQUIRES_OK(c, c->input_list("values", &values)); - const int N = values.size(); - const int input_dims = values[0].dims(); - const TensorShape& input_shape = values[0].shape(); + const int N = values_input_end_index_ - values_input_start_index_; + const Tensor& first_input = c->input(values_input_start_index_); + const int input_dims = first_input.dims(); + const TensorShape& input_shape = first_input.shape(); int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim; // concat_dim==0 allows concatenating a list of scalars into a vector. @@ -116,7 +123,7 @@ class ConcatBaseOp : public OpKernel { } int64 output_concat_dim = 0; for (int i = 0; i < N; ++i) { - const auto& in = values[i]; + const auto& in = c->input(values_input_start_index_ + i); OP_REQUIRES( c, in.dims() == input_dims, errors::InvalidArgument( @@ -137,7 +144,7 @@ class ConcatBaseOp : public OpKernel { if (in.NumElements() > 0) { int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0; inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix( - in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1}))); + in.template shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1}))); } // TODO(rmlarsen): Remove check once !allow_legacy_scalars()? output_concat_dim += in.dims() > 0 ? 
in.dim_size(axis) : 1; @@ -170,6 +177,12 @@ class ConcatBaseOp : public OpKernel { ConcatCPU<T>(c->device(), inputs_flat, &output_flat); } } + + private: + const char* const axis_attribute_name_; + int axis_input_index_; + int values_input_start_index_; + int values_input_end_index_; }; template <typename Device, typename T> diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index ccdafdf91c9..4f26aed641e 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -48,7 +48,7 @@ namespace tensorflow { namespace { NodeDef StripTensorDataFromNodeDef(OpKernelConstruction* ctx) { -#ifndef __ANDROID__ +#ifndef TENSORFLOW_LITE_PROTOS DCHECK_EQ(NodeDef::descriptor()->field_count(), 6) << "The NodeDef format has changed, and the attr-stripping code may need " << "to be updated."; diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc index 911462c8eff..54bde28ad62 100644 --- a/tensorflow/core/kernels/cwise_op_select.cc +++ b/tensorflow/core/kernels/cwise_op_select.cc @@ -44,12 +44,9 @@ class SelectOp : public OpKernel { explicit SelectOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* ctx) override { - const Tensor* cond; - const Tensor* then; - const Tensor* else_; - OP_REQUIRES_OK(ctx, ctx->input("condition", &cond)); - OP_REQUIRES_OK(ctx, ctx->input("t", &then)); - OP_REQUIRES_OK(ctx, ctx->input("e", &else_)); + const Tensor* cond = &ctx->input(0); + const Tensor* then = &ctx->input(1); + const Tensor* else_ = &ctx->input(2); if (TensorShapeUtils::IsScalar(cond->shape())) { ComputeScalar(ctx, cond, then, else_); @@ -149,12 +146,9 @@ class SelectV2Op : public OpKernel { explicit SelectV2Op(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* ctx) override { - const Tensor* cond; - const Tensor* then; - const Tensor* else_; - OP_REQUIRES_OK(ctx, ctx->input("condition", &cond)); - OP_REQUIRES_OK(ctx, ctx->input("t", &then)); - OP_REQUIRES_OK(ctx, ctx->input("e", &else_)); + const Tensor* cond = &ctx->input(0); + const Tensor* then = &ctx->input(1); + const Tensor* else_ = &ctx->input(2); // The `cond`, `then`, and `else` are broadcastable (bcast.IsValid()), // This matches the behavior of numpy. 
@@ -260,7 +254,6 @@ class SelectV2Op : public OpKernel { ctx->input(1).shape().DebugString(), " is not supported yet.")); break; } - return; } private: diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index d041ab5ac6a..a68b3faeb37 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -434,6 +434,7 @@ tf_kernel_library( name = "snapshot_dataset_op", srcs = ["snapshot_dataset_op.cc"], deps = [ + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -441,6 +442,7 @@ tf_kernel_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:graph_view", "//tensorflow/core/kernels/data:dataset_utils", + "//tensorflow/core/platform:platform_port", "//tensorflow/core/profiler/lib:traceme", "@com_google_absl//absl/time", ], diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc index ae3015bc833..68ee3c4c134 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc @@ -15,6 +15,7 @@ limitations under the License. #include <random> #include "absl/time/clock.h" +#include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -32,7 +33,9 @@ limitations under the License. #include "tensorflow/core/lib/io/compression.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/random_inputstream.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/snappy.h" #if !defined(IS_SLIM_BUILD) #include "tensorflow/core/lib/io/snappy/snappy_inputbuffer.h" #include "tensorflow/core/lib/io/snappy/snappy_outputbuffer.h" @@ -63,9 +66,6 @@ enum SnapshotMode { READER = 0, WRITER = 1, PASSTHROUGH = 2 }; // Defaults to 10 GiB per shard. const int64 kDefaultShardSizeBytes = 10LL * 1024 * 1024 * 1024; -const int64 kSnappyWriterInputBufferSizeBytes = 16 << 20; // 16 MiB -const int64 kSnappyWriterOutputBufferSizeBytes = 16 << 20; // 16 MiB - // The reader input buffer size is deliberately large because the input reader // will throw an error if the compressed block length cannot fit in the input // buffer. 
@@ -75,6 +75,8 @@ const int64 kSnappyReaderOutputBufferSizeBytes = 32 << 20; // 32 MiB const size_t kHeaderSize = sizeof(uint64); +const int64 kCurrentVersion = 1; + constexpr char kModeAuto[] = "auto"; constexpr char kModeWrite[] = "write"; constexpr char kModeRead[] = "read"; @@ -95,6 +97,7 @@ constexpr char kState[] = "state"; constexpr char kHashDir[] = "hash_dir"; constexpr char kRunId[] = "run_id"; constexpr char kRunDir[] = "run_dir"; +constexpr char kVersionStr[] = "version"; constexpr char kFilenames[] = "filenames"; constexpr char kCurrentFilenames[] = "current_filenames"; constexpr char kElementsProduced[] = "elements_produced"; @@ -115,9 +118,9 @@ class SnapshotWriter { static constexpr const char* const kWriteStringPiece = "WriteStringPiece"; static constexpr const char* const kWriteCord = "WriteCord"; - explicit SnapshotWriter(WritableFile* dest, const string& compression_type = - io::compression::kNone) - : dest_(dest), compression_type_(compression_type) { + explicit SnapshotWriter(WritableFile* dest, const string& compression_type, + int version, const DataTypeVector& dtypes) + : dest_(dest), compression_type_(compression_type), version_(version) { #if defined(IS_SLIM_BUILD) if (compression_type != io::compression::kNone) { LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning " @@ -134,41 +137,100 @@ class SnapshotWriter { TF_CHECK_OK(zlib_output_buffer->Init()); dest_ = zlib_output_buffer; dest_is_owned_ = true; - } else if (compression_type == io::compression::kSnappy) { - io::SnappyOutputBuffer* snappy_output_buffer = new io::SnappyOutputBuffer( - dest, /*input_buffer_bytes=*/kSnappyWriterInputBufferSizeBytes, - /*output_buffer_bytes=*/kSnappyWriterOutputBufferSizeBytes); - dest_ = snappy_output_buffer; - dest_is_owned_ = true; } #endif // IS_SLIM_BUILD + simple_tensor_mask_.reserve(dtypes.size()); + for (const auto& dtype : dtypes) { + if (DataTypeCanUseMemcpy(dtype)) { + simple_tensor_mask_.push_back(true); + num_simple_++; + } else { + simple_tensor_mask_.push_back(false); + num_complex_++; + } + } } - Status WriteRecord(const StringPiece& data) { - profiler::TraceMe activity( - [&]() { - return absl::StrCat(kClassName, kSeparator, kWriteStringPiece); - }, - profiler::TraceMeLevel::kInfo); - char header[kHeaderSize]; - core::EncodeFixed64(header, data.size()); - TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); - return dest_->Append(data); - } - + Status WriteTensors(const std::vector<Tensor>& tensors) { + if (compression_type_ != io::compression::kSnappy) { + experimental::SnapshotRecord record; + for (const auto& tensor : tensors) { + TensorProto* t = record.add_tensor(); + tensor.AsProtoTensorContent(t); + } #if defined(PLATFORM_GOOGLE) - Status WriteRecord(const absl::Cord& data) { - profiler::TraceMe activity( - [&]() { return absl::StrCat(kClassName, kSeparator, kWriteCord); }, - profiler::TraceMeLevel::kInfo); - char header[kHeaderSize]; - core::EncodeFixed64(header, data.size()); - - TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); - - return dest_->Append(data); - } + return WriteRecord(record.SerializeAsCord()); +#else // PLATFORM_GOOGLE + return WriteRecord(record.SerializeAsString()); #endif // PLATFORM_GOOGLE + } + + if (version_ != 1) { + return errors::InvalidArgument("Version: ", version_, + " is not supported."); + } + if (compression_type_ != io::compression::kSnappy) { + return errors::InvalidArgument( + "Version 1 is only compatible with snappy compression"); + } + + std::vector<const 
TensorBuffer*> tensor_buffers; + tensor_buffers.reserve(num_simple_); + std::vector<TensorProto> tensor_protos; + tensor_protos.reserve(num_complex_); + SnapshotTensorMetadata metadata; + int64 total_size = 0; + for (int i = 0; i < tensors.size(); ++i) { + const Tensor& tensor = tensors[i]; + TensorMetadata* tensor_metadata = metadata.add_tensor_metadata(); + tensor.shape().AsProto(tensor_metadata->mutable_tensor_shape()); + int64 size = 0; + if (simple_tensor_mask_[i]) { + auto tensor_buffer = DMAHelper::buffer(&tensor); + tensor_buffers.push_back(tensor_buffer); + size = tensor_buffer->size(); + } else { + TensorProto proto; + tensor.AsProtoTensorContent(&proto); + size = proto.ByteSizeLong(); + tensor_protos.push_back(std::move(proto)); + } + tensor_metadata->set_tensor_size_bytes(size); + total_size += size; + } + + std::vector<char> uncompressed(total_size); + char* position = uncompressed.data(); + int buffer_index = 0; + int proto_index = 0; + for (int i = 0; i < tensors.size(); ++i) { + const auto& tensor_metadata = metadata.tensor_metadata(i); + if (simple_tensor_mask_[i]) { + memcpy(position, tensor_buffers[buffer_index]->data(), + tensor_metadata.tensor_size_bytes()); + buffer_index++; + } else { + tensor_protos[proto_index].SerializeToArray( + position, tensor_metadata.tensor_size_bytes()); + proto_index++; + } + position += tensor_metadata.tensor_size_bytes(); + } + DCHECK_EQ(position, uncompressed.data() + total_size); + + string output; + if (!port::Snappy_Compress(uncompressed.data(), total_size, &output)) { + return errors::Internal("Failed to compress using snappy."); + } +#if defined(PLATFORM_GOOGLE) + absl::Cord metadata_serialized = metadata.SerializeAsCord(); +#else // PLATFORM_GOOGLE + std::string metadata_serialized = metadata.SerializeAsString(); +#endif // PLATFORM_GOOGLE + TF_RETURN_IF_ERROR(WriteRecord(metadata_serialized)); + TF_RETURN_IF_ERROR(WriteRecord(output)); + return Status::OK(); + } Status Sync() { return dest_->Sync(); } @@ -192,9 +254,29 @@ class SnapshotWriter { } private: + Status WriteRecord(const StringPiece& data) { + char header[kHeaderSize]; + core::EncodeFixed64(header, data.size()); + TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); + return dest_->Append(data); + } + +#if defined(PLATFORM_GOOGLE) + Status WriteRecord(const absl::Cord& data) { + char header[kHeaderSize]; + core::EncodeFixed64(header, data.size()); + TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); + return dest_->Append(data); + } +#endif // PLATFORM_GOOGLE + WritableFile* dest_; bool dest_is_owned_ = false; const string compression_type_; + const int version_; + std::vector<bool> simple_tensor_mask_; // true for simple, false for complex. 
+ int num_simple_ = 0; + int num_complex_ = 0; }; class SnapshotReader { @@ -203,12 +285,14 @@ class SnapshotReader { static constexpr const char* const kReadString = "ReadString"; static constexpr const char* const kReadCord = "ReadCord"; - explicit SnapshotReader( - RandomAccessFile* file, - const string& compression_type = io::compression::kNone) + explicit SnapshotReader(RandomAccessFile* file, + const string& compression_type, int version, + const DataTypeVector& dtypes) : file_(file), input_stream_(new io::RandomAccessInputStream(file)), - compression_type_(compression_type) { + compression_type_(compression_type), + version_(version), + dtypes_(dtypes) { #if defined(IS_SLIM_BUILD) if (compression_type_ != io::compression::kNone) { LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning " @@ -223,17 +307,167 @@ class SnapshotReader { input_stream_.release(), zlib_options.input_buffer_size, zlib_options.output_buffer_size, zlib_options, true); } else if (compression_type_ == io::compression::kSnappy) { - input_stream_ = absl::make_unique<io::SnappyInputBuffer>( - file_, /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes, - /*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes); + if (version_ == 0) { + input_stream_ = absl::make_unique<io::SnappyInputBuffer>( + file_, /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes, + /*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes); + } else { + input_stream_ = + absl::make_unique<io::BufferedInputStream>(file_, 64 << 20); + } } #endif // IS_SLIM_BUILD + simple_tensor_mask_.reserve(dtypes.size()); + for (const auto& dtype : dtypes) { + if (DataTypeCanUseMemcpy(dtype)) { + simple_tensor_mask_.push_back(true); + num_simple_++; + } else { + simple_tensor_mask_.push_back(false); + num_complex_++; + } + } + } + + Status ReadTensors(std::vector<Tensor>* read_tensors) { + profiler::TraceMe activity( + [&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); }, + profiler::TraceMeLevel::kInfo); + if (version_ == 0 || compression_type_ != io::compression::kSnappy) { + return ReadTensorsV0(read_tensors); + } + if (version_ != 1) { + return errors::InvalidArgument("Version: ", version_, + " is not supported."); + } + if (compression_type_ != io::compression::kSnappy) { + return errors::InvalidArgument("Version 1 only supports snappy."); + } + + SnapshotTensorMetadata metadata; + tstring metadata_str; + TF_RETURN_IF_ERROR(ReadRecord(&metadata_str)); + if (!metadata.ParseFromArray(metadata_str.data(), metadata_str.size())) { + return errors::DataLoss("Could not parse SnapshotTensorMetadata"); + } + read_tensors->reserve(metadata.tensor_metadata_size()); + + std::vector<Tensor> simple_tensors; + simple_tensors.reserve(num_simple_); + std::vector<std::pair<std::unique_ptr<char[]>, size_t>> tensor_proto_strs; + tensor_proto_strs.reserve(num_complex_); + TF_RETURN_IF_ERROR( + SnappyUncompress(metadata, &simple_tensors, &tensor_proto_strs)); + + int simple_index = 0; + int complex_index = 0; + for (int i = 0; i < simple_tensor_mask_.size(); ++i) { + if (simple_tensor_mask_[i]) { + read_tensors->push_back(std::move(simple_tensors[simple_index])); + simple_index++; + } else { + auto tensor_proto_str = + std::move(tensor_proto_strs[complex_index].first); + size_t tensor_proto_size = tensor_proto_strs[complex_index].second; + TensorProto tp; +#if defined(PLATFORM_GOOGLE) + auto tensor_proto_ptr = tensor_proto_str.release(); + absl::Cord c; + c.AppendExternalMemory( + absl::string_view(tensor_proto_ptr, 
tensor_proto_size), + tensor_proto_ptr, + [](void* arg) { delete[] static_cast<char*>(arg); }); + if (!tp.ParseFromCord(c)) { + return errors::Internal("Could not parse TensorProto"); + } +#else // PLATFORM_GOOGLE + if (!tp.ParseFromArray(tensor_proto_str.get(), tensor_proto_size)) { + return errors::Internal("Could not parse TensorProto"); + } +#endif // PLATFORM_GOOGLE + Tensor t; + if (!t.FromProto(tp)) { + return errors::Internal("Could not parse Tensor"); + } + read_tensors->push_back(std::move(t)); + complex_index++; + } + } + return Status::OK(); + } + + private: + Status ReadTensorsV0(std::vector<Tensor>* read_tensors) { + experimental::SnapshotRecord record; +#if defined(PLATFORM_GOOGLE) + absl::Cord c; + TF_RETURN_IF_ERROR(ReadRecord(&c)); + record.ParseFromCord(c); +#else // PLATFORM_GOOGLE + tstring record_bytes; + TF_RETURN_IF_ERROR(ReadRecord(&record_bytes)); + record.ParseFromArray(record_bytes.data(), record_bytes.size()); +#endif // PLATFORM_GOOGLE + read_tensors->reserve(record.tensor_size()); + for (int i = 0; i < record.tensor_size(); ++i) { + read_tensors->emplace_back(); + if (!read_tensors->back().FromProto(record.tensor(i))) { + return errors::DataLoss("Unable to parse tensor from proto."); + } + } + return Status::OK(); + } + + Status SnappyUncompress( + const SnapshotTensorMetadata& metadata, + std::vector<Tensor>* simple_tensors, + std::vector<std::pair<std::unique_ptr<char[]>, size_t>>* + tensor_proto_strs) { + tstring compressed; + TF_RETURN_IF_ERROR(ReadRecord(&compressed)); + size_t size; + if (!port::Snappy_GetUncompressedLength(compressed.data(), + compressed.size(), &size)) { + return errors::Internal("Could not get snappy uncompressed length"); + } + + int num_tensors = metadata.tensor_metadata_size(); + std::vector<struct iovec> iov(num_tensors); + int index = 0; + int64 total_size = 0; + for (int i = 0; i < simple_tensor_mask_.size(); ++i) { + const auto& tensor_metadata = metadata.tensor_metadata(i); + if (simple_tensor_mask_[i]) { + TensorShape shape(tensor_metadata.tensor_shape()); + Tensor simple_tensor(dtypes_[i], shape); + TensorBuffer* buffer = DMAHelper::buffer(&simple_tensor); + iov[index].iov_base = buffer->data(); + iov[index].iov_len = buffer->size(); + simple_tensors->push_back(std::move(simple_tensor)); + } else { + auto tensor_proto_str = + absl::make_unique<char[]>(tensor_metadata.tensor_size_bytes()); + iov[index].iov_base = tensor_proto_str.get(); + iov[index].iov_len = tensor_metadata.tensor_size_bytes(); + tensor_proto_strs->push_back(std::make_pair( + std::move(tensor_proto_str), tensor_metadata.tensor_size_bytes())); + } + total_size += iov[index].iov_len; + index++; + } + if (size != total_size) { + return errors::Internal("Uncompressed size mismatch. 
Snappy expects ", + size, " whereas the tensor metadata suggests ", + total_size); + } + if (!port::Snappy_UncompressToIOVec(compressed.data(), compressed.size(), + iov.data(), num_tensors)) { + return errors::Internal("Failed to perform snappy decompression."); + } + return Status::OK(); } Status ReadRecord(tstring* record) { - profiler::TraceMe activity( - [&]() { return absl::StrCat(kClassName, kSeparator, kReadString); }, - profiler::TraceMeLevel::kInfo); tstring header; TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header)); uint64 length = core::DecodeFixed64(header.data()); @@ -245,13 +479,6 @@ class SnapshotReader { tstring header; TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header)); uint64 length = core::DecodeFixed64(header.data()); - profiler::TraceMe activity( - [&]() { - return absl::StrCat(kClassName, kSeparator, kReadCord, - "#length=", length, "#"); - }, - profiler::TraceMeLevel::kInfo); - if (compression_type_ == io::compression::kNone) { return input_stream_->ReadNBytes(length, record); } else { @@ -268,50 +495,31 @@ class SnapshotReader { } #endif - private: RandomAccessFile* file_; std::unique_ptr<io::InputStreamInterface> input_stream_; const string compression_type_; + const int version_; + const DataTypeVector dtypes_; + int num_simple_ = 0; + int num_complex_ = 0; + std::vector<bool> simple_tensor_mask_; // true for simple, false for complex. }; Status WriteMetadataFile(const string& hash_dir, const experimental::SnapshotMetadataRecord& metadata) { string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename); TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(hash_dir)); - std::string tmp_filename = absl::StrCat(metadata_filename, "-tmp-", random::New64()); - - std::unique_ptr<WritableFile> file; - TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(tmp_filename, &file)); - - auto writer = absl::make_unique<SnapshotWriter>(file.get()); - TF_RETURN_IF_ERROR(writer->WriteRecord(metadata.SerializeAsString())); - TF_RETURN_IF_ERROR(writer->Close()); - TF_RETURN_IF_ERROR(file->Sync()); - TF_RETURN_IF_ERROR(file->Close()); - - TF_RETURN_IF_ERROR( - Env::Default()->RenameFile(tmp_filename, metadata_filename)); - - return Status::OK(); + TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), tmp_filename, metadata)); + return Env::Default()->RenameFile(tmp_filename, metadata_filename); } Status ReadMetadataFile(const string& hash_dir, experimental::SnapshotMetadataRecord* metadata) { string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename); TF_RETURN_IF_ERROR(Env::Default()->FileExists(metadata_filename)); - - std::unique_ptr<RandomAccessFile> file; - TF_RETURN_IF_ERROR( - Env::Default()->NewRandomAccessFile(metadata_filename, &file)); - - tstring record_bytes; - SnapshotReader reader(file.get()); - TF_RETURN_IF_ERROR(reader.ReadRecord(&record_bytes)); - - metadata->ParseFromArray(record_bytes.data(), record_bytes.size()); - return Status::OK(); + return ReadBinaryProto(Env::Default(), metadata_filename, metadata); } Status DumpDatasetGraph(const std::string& path, uint64 hash, @@ -332,6 +540,10 @@ Status DetermineOpState(const std::string& mode_string, const uint64 pending_snapshot_expiry_seconds, SnapshotMode* mode) { if (mode_string == kModeRead) { + // In read mode, we should expect a metadata file is written. 
+ if (errors::IsNotFound(file_status)) { + return file_status; + } LOG(INFO) << "Overriding mode to reader."; *mode = READER; return Status::OK(); @@ -727,10 +939,25 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { if (run_id.empty()) { run_id = metadata.run_id(); } + // dtypes in metadata should be the same as dataset()->output_dtypes + if (metadata.dtype_size() != dataset()->output_dtypes().size()) { + return errors::Internal( + "Expected number of dtypes: ", + dataset()->output_dtypes().size(), + " but number in snapshot: ", metadata.dtype_size()); + } + for (int i = 0; i < metadata.dtype_size(); ++i) { + if (metadata.dtype(i) != dataset()->output_dtypes()[i]) { + return errors::Internal( + "Type: ", i, + " doesn't match. Snapshot: ", metadata.dtype(i), + "; dataset: ", dataset()->output_dtypes()[i]); + } + } iterator_ = absl::make_unique<SnapshotReaderIterator>( SnapshotReaderIterator::Params{ dataset(), absl::StrCat(prefix(), "ReaderImpl")}, - hash_dir_, run_id); + hash_dir_, run_id, metadata.version()); break; case PASSTHROUGH: iterator_ = absl::make_unique<SnapshotPassthroughIterator>( @@ -748,10 +975,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { explicit SnapshotReaderIterator(const Params& params, const string& hash_dir, - const string& run_id) + const string& run_id, int64 version) : DatasetIterator<Dataset>(params), hash_dir_(hash_dir), - run_id_(run_id) {} + run_id_(run_id), + version_(version) {} ~SnapshotReaderIterator() override { mutex_lock l(mu_); @@ -889,6 +1117,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { writer->WriteScalar(full_name(kHashDir), hash_dir_)); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunId), run_id_)); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunDir), run_dir_)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name(kVersionStr), version_)); TF_RETURN_IF_ERROR(writer->WriteScalar( full_name(strings::StrCat(kFilenames, kSizeSuffix)), filenames_.size())); @@ -932,6 +1162,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunId), &run_id_)); TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunDir), &run_dir_)); + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name(kVersionStr), &version_)); curr_filenames_.clear(); curr_filenames_.reserve(dataset()->num_reader_threads_); for (auto i = 0; i < dataset()->num_reader_threads_; ++i) { @@ -986,7 +1218,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<RandomAccessFile> file; TF_RETURN_IF_ERROR( Env::Default()->NewRandomAccessFile(filename, &file)); - SnapshotReader reader(file.get(), dataset()->compression_); + SnapshotReader reader(file.get(), dataset()->compression_, version_, + dataset()->output_dtypes()); while (true) { // Wait for a slot in the buffer. 
@@ -1003,30 +1236,14 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { "ReadFile"); } } -#if !defined(PLATFORM_GOOGLE) - tstring record_bytes; - Status s = reader.ReadRecord(&record_bytes); -#else - absl::Cord record_cord; - Status s = reader.ReadRecord(&record_cord); -#endif + std::vector<Tensor> read_tensors; + Status s = reader.ReadTensors(&read_tensors); if (s.ok()) { profiler::TraceMe activity( [&]() { return absl::StrCat(prefix(), kSeparator, kParse); }, profiler::TraceMeLevel::kInfo); - experimental::SnapshotRecord record; -#if !defined(PLATFORM_GOOGLE) - record.ParseFromArray(record_bytes.data(), record_bytes.size()); -#else - record.ParseFromCord(record_cord); -#endif BufferElement elem; - for (int i = 0; i < record.tensor_size(); ++i) { - elem.value.emplace_back(); - if (!elem.value.back().FromProto(record.tensor(i))) { - return errors::DataLoss("Unable to parse tensor from proto."); - } - } + elem.value = std::move(read_tensors); elem.status = Status::OK(); mutex_lock l(mu_); buffer_.push_back(std::move(elem)); @@ -1142,9 +1359,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { condition_variable cond_var_; const string hash_dir_; - const experimental::SnapshotMetadataRecord metadata_; tstring run_id_ GUARDED_BY(mu_); tstring run_dir_ GUARDED_BY(mu_); + int64 version_; std::vector<tstring> filenames_; uint64 elements_produced_ GUARDED_BY(mu_) = 0; @@ -1220,6 +1437,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { metadata.set_creation_timestamp(EnvTime::NowMicros()); metadata.set_graph_hash(dataset()->graph_hash_); metadata.set_run_id(run_id_.data(), run_id_.size()); + metadata.set_version(kCurrentVersion); + for (const auto& output_dtype : dataset()->output_dtypes()) { + metadata.add_dtype(output_dtype); + } metadata.set_finalized(false); TF_RETURN_IF_ERROR(WriteMetadataFile(hash_dir_, metadata)); } @@ -1564,11 +1785,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } if (produced_elem) { - experimental::SnapshotRecord record; for (const auto& out_tensor : elem.value) { *bytes_written += out_tensor.TotalBytes(); - TensorProto* t = record.add_tensor(); - out_tensor.AsProtoTensorContent(t); } bool should_close; @@ -1584,16 +1802,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile( *snapshot_data_filename, file)); *writer = absl::make_unique<SnapshotWriter>( - file->get(), dataset()->compression_); + file->get(), dataset()->compression_, kCurrentVersion, + dataset()->output_dtypes()); *bytes_written = 0; } -#if defined(PLATFORM_GOOGLE) - TF_RETURN_IF_ERROR( - (*writer)->WriteRecord(record.SerializeAsCord())); -#else // PLATFORM_GOOGLE - TF_RETURN_IF_ERROR( - (*writer)->WriteRecord(record.SerializeAsString())); -#endif // PLATFORM_GOOGLE + TF_RETURN_IF_ERROR((*writer)->WriteTensors(elem.value)); return Status::OK(); } @@ -1641,7 +1854,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { return; } std::unique_ptr<SnapshotWriter> writer( - new SnapshotWriter(file.get(), dataset()->compression_)); + new SnapshotWriter(file.get(), dataset()->compression_, + kCurrentVersion, dataset()->output_dtypes())); bool end_of_processing = false; while (!end_of_processing) { diff --git a/tensorflow/core/kernels/mfcc_mel_filterbank.cc b/tensorflow/core/kernels/mfcc_mel_filterbank.cc index 2c22fec2b11..8eb2d9d8309 100644 --- a/tensorflow/core/kernels/mfcc_mel_filterbank.cc +++ b/tensorflow/core/kernels/mfcc_mel_filterbank.cc @@ -100,8 +100,8 @@ bool MfccMelFilterbank::Initialize(int 
input_length, double input_sample_rate, if ((i < start_index_) || (i > end_index_)) { band_mapper_[i] = -2; // Indicate an unused Fourier coefficient. } else { - while ((center_frequencies_[channel] < melf) && - (channel < num_channels_)) { + while ((channel < num_channels_) && + (center_frequencies_[channel] < melf)) { ++channel; } band_mapper_[i] = channel - 1; // Can be == -1 diff --git a/tensorflow/core/kernels/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl_avgpooling_op.cc index e36e481ebbf..0751fff4c26 100644 --- a/tensorflow/core/kernels/mkl_avgpooling_op.cc +++ b/tensorflow/core/kernels/mkl_avgpooling_op.cc @@ -111,16 +111,18 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> { else pooling_prop_kind = prop_kind::forward_training; #ifdef ENABLE_MKLDNN_V1 + // TODO(DNNL): Find out what should we use input_md.data.format. MklPoolingParams fwdParams( src_dims, output_dims_mkl_order, filter_dims, strides, padding_left, padding_right, ALGORITHM::pooling_avg_exclude_padding, pooling_prop_kind, - static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_)); + static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), input_md); #else MklPoolingParams fwdParams( src_dims, output_dims_mkl_order, filter_dims, strides, padding_left, padding_right, ALGORITHM::pooling_avg_exclude_padding, - pooling_prop_kind, static_cast<MEMORY_FORMAT>(input_md.data.format)); + pooling_prop_kind, static_cast<MEMORY_FORMAT>(input_md.data.format), + input_md); #endif pooling_fwd = MklPoolingFwdPrimitiveFactory<T>::Get(fwdParams); @@ -234,17 +236,18 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> { // Pass prop_kind::forward_training to create a forward primitive // that is used in the backward pass. #ifdef ENABLE_MKLDNN_V1 + // TODO(DNNL): Find out what should we use src_md.data.format. MklPoolingParams bwdParams( orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, strides, padding_left, padding_right, ALGORITHM::pooling_avg_exclude_padding, prop_kind::forward_training, - static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_)); + static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), src_md); #else MklPoolingParams bwdParams( orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, strides, padding_left, padding_right, ALGORITHM::pooling_avg_exclude_padding, prop_kind::forward_training, - static_cast<MEMORY_FORMAT>(src_md.data.format)); + static_cast<MEMORY_FORMAT>(src_md.data.format), src_md); #endif MklPoolingBwdPrimitive<T>* pooling_bwd = MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams); diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc index 30023f360cc..2638d835f2c 100644 --- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc @@ -26,6 +26,9 @@ limitations under the License. 
#define GET_FLAG(bn_flag) static_cast<int>(BN_FLAGS::bn_flag) #define IS_SET(cflag) (context_.flags & GET_FLAG(cflag)) +#define GET_FLAG(bn_flag) static_cast<int>(BN_FLAGS::bn_flag) +#define IS_SET(cflag) (context_.flags & GET_FLAG(cflag)) + using mkldnn::batch_normalization_backward; using mkldnn::batch_normalization_forward; using mkldnn::prop_kind; @@ -51,18 +54,17 @@ struct MklBatchNormFwdParams { MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps, #ifndef ENABLE_MKLDNN_V1 bool training, MEMORY_FORMAT src_format) +#else + bool training, memory::desc src_md) +#endif // !ENABLE_MKLDNN_V1 : src_dims(src_dims), depth(depth), eps(eps), training(training), +#ifndef ENABLE_MKLDNN_V1 src_format(src_format) { } #else - bool training, memory::desc src_md) - : src_dims(src_dims), - depth(depth), - eps(eps), - training(training), src_md(src_md) { } #endif // !ENABLE_MKLDNN_V1 @@ -231,6 +233,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { } // BatchNorm forward primitive. + // TODO(intel-tf): Merge all the #ifdefs and simplify code if (!fwdParams.training && !(IS_SET(use_global_stats))) { #ifdef ENABLE_MKLDNN_V1 if ((IS_SET(use_scale_shift)) && mkldnn_use_scaleshift) { @@ -383,6 +386,7 @@ struct MklBatchNormBwdParams { int depth; float eps; bool training; + #ifndef ENABLE_MKLDNN_V1 MEMORY_FORMAT src_format; #else @@ -466,10 +470,8 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { #ifdef ENABLE_MKLDNN_V1 // Execute backward batch-normalization primitives. DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size()); - for (size_t i = 0; i < context_.bwd_primitives.size(); ++i) { - context_.bwd_primitives.at(i).execute(*context_.bwd_stream, - context_.net_args.at(i)); - } + execute_primitives(context_.bwd_primitives, context_.bwd_stream, + context_.net_args); #else context_.bwd_stream->submit(context_.bwd_primitives); #endif // ENABLE_MKLDNN_V1 @@ -841,7 +843,17 @@ class MklFusedBatchNormOp : public OpKernel { MklFusedBatchNormFwdPrimitive<T, U>* bn_fwd = MklFusedBatchNormFwdPrimitiveFactory<T, U>::Get(fwdParams); - const T* src_data = src_tensor.flat<T>().data(); + // Check if reorder is needed for src. 
+ const T* src_data = nullptr; + std::shared_ptr<BatchNormFwdPd> bn_fwd_pd = bn_fwd->GetBatchNormFwdPd(); + if (IS_SRC_REORDER_NEEDED(src_md, bn_fwd_pd, bn_fwd)) { + src.SetUsrMem(src_md, &src_tensor); + src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( + GET_SRC_DESC_FROM_OP_PD(bn_fwd_pd), cpu_engine_)); + src_data = static_cast<T*>(src.GetOpMem().get_data_handle()); + } else { + src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data())); + } // Allocate output (dst) tensor; always set it as MKL-DNN layout MklDnnShape dnn_shape_dst; diff --git a/tensorflow/core/kernels/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl_matmul_op_fused.cc index 18f6667fd1e..2ff24f5d2fc 100644 --- a/tensorflow/core/kernels/mkl_matmul_op_fused.cc +++ b/tensorflow/core/kernels/mkl_matmul_op_fused.cc @@ -86,11 +86,10 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> { const int k = src_tf_shape.dim_size(dim_pair[0]); const int channel = weight_tf_shape.dim_size(1 - dim_pair[1]); - OP_REQUIRES( - ctx, k == weight_tf_shape.dim_size(dim_pair[1]), - errors::InvalidArgument( - "Matrix size-incompatible: In[0]: ", src_tf_shape.DebugString(), - ", In[1]: ", weight_tf_shape.DebugString())); + OP_REQUIRES(ctx, k == weight_tf_shape.dim_size(dim_pair[1]), + errors::InvalidArgument("Matrix size-incompatible: In[0]: ", + src_tf_shape.DebugString(), ", In[1]: ", + weight_tf_shape.DebugString())); OP_REQUIRES(ctx, bias_tensor.shape().dim_size(0) == channel, errors::InvalidArgument( "Must provide as many biases as the channel size: ", @@ -201,9 +200,9 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> { // Execute fused matmul op. matmul_prim->Execute(src_data, weight_data, bias_data, dst_data); } catch (mkldnn::error& e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + ", in file " + - string(__FILE__) + ":" + std::to_string(__LINE__); + string error_msg = "Status: " + std::to_string(e.status) + ", message: " + + string(e.message) + ", in file " + string(__FILE__) + + ":" + std::to_string(__LINE__); OP_REQUIRES_OK( ctx, errors::Aborted("Operation received an exception:", error_msg)); } diff --git a/tensorflow/core/kernels/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl_matmul_ops_common.h index ca90a24a1cd..c5d27b92e00 100644 --- a/tensorflow/core/kernels/mkl_matmul_ops_common.h +++ b/tensorflow/core/kernels/mkl_matmul_ops_common.h @@ -443,16 +443,16 @@ class MklDnnMatMulOpBase : public OpKernel { Tensor* weight_tensor_ptr = nullptr; - size_t size = matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS.get_size(); + size_t weight_size = matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS.get_size(); TensorShape weight_tf_shape; - weight_tf_shape.AddDim(size / sizeof(Tweight)); + weight_tf_shape.AddDim(weight_size / sizeof(Tweight)); OP_REQUIRES_OK(context, context->allocate_persistent( DataTypeToEnum<Tweight>::value, weight_tf_shape, &weight_oi_, &weight_tensor_ptr)); void* weight_oi_t_data = weight.GetTensorBuffer(weight_tensor_ptr); - memcpy(weight_oi_t_data, weight_data, size); + memcpy(weight_oi_t_data, weight_data, weight_size); // cache the memory descriptor #ifdef ENABLE_MKLDNN_V1 diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc index de4974f659e..098ea049246 100644 --- a/tensorflow/core/kernels/mkl_maxpooling_op.cc +++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc @@ -139,15 +139,16 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> { else pooling_prop_kind = 
prop_kind::forward_training; #ifdef ENABLE_MKLDNN_V1 + // TODO(DNNL): Figure out what should be used for input_md.data.format MklPoolingParams fwdParams( src_dims, output_dims_mkl_order, filter_dims, strides, padding_left, padding_right, ALGORITHM::pooling_max, pooling_prop_kind, - static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_)); + static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), input_md); #else MklPoolingParams fwdParams( src_dims, output_dims_mkl_order, filter_dims, strides, padding_left, padding_right, ALGORITHM::pooling_max, pooling_prop_kind, - static_cast<MEMORY_FORMAT>(input_md.data.format)); + static_cast<MEMORY_FORMAT>(input_md.data.format), input_md); #endif pooling_fwd = MklPoolingFwdPrimitiveFactory<T>::Get(fwdParams); // Allocate output tensor. @@ -297,17 +298,18 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> { this->data_format_mkldnn_); #ifdef ENABLE_MKLDNN_V1 + // TODO(DNNL): Find out what should be used for src_md.data.format. MklPoolingParams bwdParams( orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, strides, padding_left, padding_right, ALGORITHM::pooling_max, prop_kind::forward_training, - static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_)); + static_cast<MEMORY_FORMAT>(this->data_format_mkldnn_), src_md); #else MklPoolingParams bwdParams( orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, strides, padding_left, padding_right, ALGORITHM::pooling_max, prop_kind::forward_training, - static_cast<MEMORY_FORMAT>(src_md.data.format)); + static_cast<MEMORY_FORMAT>(src_md.data.format), src_md); #endif MklPoolingBwdPrimitive<T>* pooling_bwd = MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams); diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc index 29bcaf5e67c..904866f8223 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc @@ -43,6 +43,7 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) { // so src format is currently hard-coded. // A utility function is used to do this, // which may be broken with future CPU architectures +#ifndef ENABLE_MKLDNN_V1 bool is_2d = (fwdParams.src_dims.size() == 4); if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) context_.src_fmt = is_2d ? 
MEMORY_FORMAT::nhwc : MEMORY_FORMAT::ndhwc; @@ -51,6 +52,9 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) { context_.src_md.reset(new memory::desc({fwdParams.src_dims}, MklDnnType<T>(), context_.src_fmt)); +#else + context_.src_md.reset(new memory::desc(fwdParams.src_md.data)); +#endif // !ENABLE_MKLDNN_V1 context_.dst_md.reset(new memory::desc({fwdParams.dst_dims}, MklDnnType<T>(), MEMORY_FORMAT::any)); diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h index 8d18c95a542..ff51282ecc6 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.h +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h @@ -48,12 +48,13 @@ struct MklPoolingParams { mkldnn::algorithm alg_kind; mkldnn::prop_kind prop_kind; MEMORY_FORMAT src_format; + memory::desc src_md; MklPoolingParams(memory::dims src_dims, memory::dims dst_dims, memory::dims filter_dims, memory::dims strides, memory::dims padding_left, memory::dims padding_right, mkldnn::algorithm alg_kind, mkldnn::prop_kind prop_kind, - MEMORY_FORMAT src_format) + MEMORY_FORMAT src_format, memory::desc src_md) : src_dims(src_dims), dst_dims(dst_dims), filter_dims(filter_dims), @@ -62,7 +63,8 @@ struct MklPoolingParams { padding_right(padding_right), alg_kind(alg_kind), prop_kind(prop_kind), - src_format(src_format) {} + src_format(src_format), + src_md(src_md) {} }; template <typename T> diff --git a/tensorflow/core/kernels/pack_op.cc b/tensorflow/core/kernels/pack_op.cc index 4b4705150a6..cf2b6bb1100 100644 --- a/tensorflow/core/kernels/pack_op.cc +++ b/tensorflow/core/kernels/pack_op.cc @@ -50,20 +50,10 @@ class PackOp : public OpKernel { } void Compute(OpKernelContext* c) override { - OpInputList values; - OP_REQUIRES_OK(c, c->input_list("values", &values)); - const int num = values.size(); + const int num = num_inputs(); + const Tensor& first_input = c->input(0); - // Verify that all input shapes match - for (int i = 1; i < num; i++) { - OP_REQUIRES(c, values[0].shape().IsSameSize(values[i].shape()), - errors::InvalidArgument( - "Shapes of all inputs must match: values[0].shape = ", - values[0].shape().DebugString(), " != values[", i, - "].shape = ", values[i].shape().DebugString())); - } - - int expanded_num_dims = values[0].dims() + 1; + int expanded_num_dims = first_input.dims() + 1; int axis = axis_; if (axis < 0) axis += expanded_num_dims; @@ -72,13 +62,13 @@ class PackOp : public OpKernel { -expanded_num_dims, ", ", expanded_num_dims, ")")); - TensorShape output_shape(values[0].shape()); + TensorShape output_shape(first_input.shape()); output_shape.InsertDim(axis, num); // In the num = 1 case, just reshape the input if (num == 1) { Tensor output; - CHECK(output.CopyFrom(values[0], output_shape)); + CHECK(output.CopyFrom(first_input, output_shape)); c->set_output(0, output); return; } @@ -109,8 +99,15 @@ class PackOp : public OpKernel { ConstMatrixVector inputs_flat; inputs_flat.reserve(num); for (int i = 0; i < num; ++i) { + const Tensor& input = c->input(i); + OP_REQUIRES(c, first_input.shape().IsSameSize(input.shape()), + errors::InvalidArgument( + "Shapes of all inputs must match: values[0].shape = ", + first_input.shape().DebugString(), " != values[", i, + "].shape = ", input.shape().DebugString())); + inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix( - values[i].shaped<T, 2>({before_dim, after_dim}))); + input.shaped<T, 2>({before_dim, after_dim}))); } #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (std::is_same<Device, GPUDevice>::value) { 
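The PackOp change above folds the per-input shape check into the same loop that flattens each input, so the up-front OpInputList pass and its eager validation of every input disappear. A minimal standalone sketch of that validate-as-you-consume pattern, using made-up types (FakeTensor, PackInputs) rather than the real kernel API:

// Sketch only: compare each input against the first at the moment it is
// consumed, instead of validating the whole list before doing any work.
#include <cstdio>
#include <vector>

struct FakeTensor {
  std::vector<int> shape;
  bool SameShape(const FakeTensor& other) const { return shape == other.shape; }
};

bool PackInputs(const std::vector<FakeTensor>& inputs) {
  const FakeTensor& first = inputs[0];
  for (size_t i = 1; i < inputs.size(); ++i) {
    if (!first.SameShape(inputs[i])) {
      std::fprintf(stderr, "Shapes of all inputs must match: values[%zu]\n", i);
      return false;  // The real kernel raises InvalidArgument here.
    }
    // ... flatten inputs[i] and append it to the packed output here ...
  }
  return true;
}

int main() {
  std::vector<FakeTensor> ok = {{{2, 3}}, {{2, 3}}};
  std::vector<FakeTensor> bad = {{{2, 3}}, {{4, 1}}};
  std::printf("ok=%d bad=%d\n", PackInputs(ok) ? 1 : 0, PackInputs(bad) ? 1 : 0);
  return 0;
}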
diff --git a/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc b/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc index a97c5cb47a2..8de93cf9b30 100644 --- a/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc +++ b/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc @@ -78,16 +78,23 @@ class SparseFillEmptyRowsOp : public OpKernel { const int64 N = indices_t.shape().dim_size(0); const int64 dense_rows = dense_shape(0); - Tensor* empty_row_indicator_t; - OP_REQUIRES_OK(context, context->allocate_output(kEmptyRowIndicatorOutput, - TensorShape({dense_rows}), - &empty_row_indicator_t)); - auto empty_row_indicator = empty_row_indicator_t->vec<bool>(); - Tensor* reverse_index_map_t; - OP_REQUIRES_OK(context, context->allocate_output(kReverseIndexMapOutput, - TensorShape({N}), - &reverse_index_map_t)); - auto reverse_index_map = reverse_index_map_t->vec<int64>(); + bool* empty_row_indicator = nullptr; + if (context->output_required(kEmptyRowIndicatorOutput)) { + Tensor* empty_row_indicator_t = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(kEmptyRowIndicatorOutput, + TensorShape({dense_rows}), + &empty_row_indicator_t)); + empty_row_indicator = empty_row_indicator_t->vec<bool>().data(); + } + int64* reverse_index_map = nullptr; + if (context->output_required(kReverseIndexMapOutput)) { + Tensor* reverse_index_map_t = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(kReverseIndexMapOutput, + TensorShape({N}), + &reverse_index_map_t)); + reverse_index_map = reverse_index_map_t->vec<int64>().data(); + } int rank = indices_t.shape().dim_size(1); @@ -122,8 +129,11 @@ class SparseFillEmptyRowsOp : public OpKernel { bool all_rows_full = true; for (int row = 0; row < dense_rows; ++row) { // csr_offset here describes the number of elements in this dense row - empty_row_indicator(row) = (csr_offset[row] == 0); - all_rows_full = all_rows_full & !empty_row_indicator(row); + bool row_empty = (csr_offset[row] == 0); + if (empty_row_indicator) { + empty_row_indicator[row] = row_empty; + } + all_rows_full = all_rows_full & !row_empty; // In filled version, each row has at least one element. csr_offset[row] = std::max(csr_offset[row], int64{1}); // Update csr_offset to represent the number of elements up to and @@ -140,8 +150,10 @@ class SparseFillEmptyRowsOp : public OpKernel { if (all_rows_full) { context->set_output(kOutputIndicesOutput, indices_t); context->set_output(kOutputValuesOutput, values_t); - for (int64 i = 0; i < N; ++i) { - reverse_index_map(i) = i; + if (reverse_index_map) { + for (int64 i = 0; i < N; ++i) { + reverse_index_map[i] = i; + } } } else { Tensor* output_indices_t; @@ -169,7 +181,9 @@ class SparseFillEmptyRowsOp : public OpKernel { std::copy_n(&indices(i, 0), rank, &output_indices(output_i, 0)); output_values(output_i) = values(i); // We'll need this reverse index map to backprop correctly. 
- reverse_index_map(i) = output_i; + if (reverse_index_map) { + reverse_index_map[i] = output_i; + } } // Fill in values for rows that are missing diff --git a/tensorflow/core/lib/io/BUILD b/tensorflow/core/lib/io/BUILD index 87b5090a59f..d03a895b429 100644 --- a/tensorflow/core/lib/io/BUILD +++ b/tensorflow/core/lib/io/BUILD @@ -208,6 +208,21 @@ cc_library( alwayslink = True, ) +cc_library( + name = "cache", + srcs = [ + "cache.cc", + ], + hdrs = [ + "cache.h", + ], + deps = [ + "//tensorflow/core/platform:coding", + "//tensorflow/core/platform:mutex", + "//tensorflow/core/platform:stringpiece", + ], +) + cc_library( name = "table", srcs = [ @@ -220,6 +235,7 @@ cc_library( ], deps = [ ":block", + ":cache", ":iterator", ":table_options", "//tensorflow/core/lib/core:coding", @@ -290,6 +306,8 @@ filegroup( "block_builder.h", "buffered_inputstream.cc", "buffered_inputstream.h", + "cache.cc", + "cache.h", "compression.cc", "compression.h", "format.cc", @@ -352,6 +370,7 @@ filegroup( name = "legacy_lib_io_all_tests", srcs = [ "buffered_inputstream_test.cc", + "cache_test.cc", "inputbuffer_test.cc", "inputstream_interface_test.cc", "path_test.cc", @@ -369,6 +388,7 @@ filegroup( name = "legacy_lib_io_headers", srcs = [ "buffered_inputstream.h", + "cache.h", "compression.h", "inputstream_interface.h", "path.h", diff --git a/tensorflow/core/lib/io/cache.cc b/tensorflow/core/lib/io/cache.cc new file mode 100644 index 00000000000..b5521b1752b --- /dev/null +++ b/tensorflow/core/lib/io/cache.cc @@ -0,0 +1,450 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/lib/io/cache.h" + +#include <assert.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +#include "tensorflow/core/platform/coding.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +namespace table { + +Cache::~Cache() {} + +namespace { + +// LRU cache implementation +// +// Cache entries have an "in_cache" boolean indicating whether the cache has a +// reference on the entry. The only ways that this can become false without the +// entry being passed to its "deleter" are via Erase(), via Insert() when +// an element with a duplicate key is inserted, or on destruction of the cache. +// +// The cache keeps two linked lists of items in the cache. All items in the +// cache are in one list or the other, and never both. Items still referenced +// by clients but erased from the cache are in neither list. The lists are: +// - in-use: contains the items currently referenced by clients, in no +// particular order. (This list is used for invariant checking. If we +// removed the check, elements that would otherwise be on this list could be +// left as disconnected singleton lists.) 
+// - LRU: contains the items not currently referenced by clients, in LRU order +// Elements are moved between these lists by the Ref() and Unref() methods, +// when they detect an element in the cache acquiring or losing its only +// external reference. + +// An entry is a variable length heap-allocated structure. Entries +// are kept in a circular doubly linked list ordered by access time. +struct LRUHandle { + void* value; + void (*deleter)(const Slice&, void* value); + LRUHandle* next_hash; + LRUHandle* next; + LRUHandle* prev; + size_t charge; // TODO(opt): Only allow uint32_t? + size_t key_length; + bool in_cache; // Whether entry is in the cache. + uint32_t refs; // References, including cache reference, if present. + uint32_t hash; // Hash of key(); used for fast sharding and comparisons + char key_data[1]; // Beginning of key + + Slice key() const { + // next_ is only equal to this if the LRU handle is the list head of an + // empty list. List heads never have meaningful keys. + assert(next != this); + + return Slice(key_data, key_length); + } +}; + +// We provide our own simple hash table since it removes a whole bunch +// of porting hacks and is also faster than some of the built-in hash +// table implementations in some of the compiler/runtime combinations +// we have tested. E.g., readrandom speeds up by ~5% over the g++ +// 4.4.3's builtin hashtable. +class HandleTable { + public: + HandleTable() : length_(0), elems_(0), list_(nullptr) { Resize(); } + ~HandleTable() { delete[] list_; } + + LRUHandle* Lookup(const Slice& key, uint32_t hash) { + return *FindPointer(key, hash); + } + + LRUHandle* Insert(LRUHandle* h) { + LRUHandle** ptr = FindPointer(h->key(), h->hash); + LRUHandle* old = *ptr; + h->next_hash = (old == nullptr ? nullptr : old->next_hash); + *ptr = h; + if (old == nullptr) { + ++elems_; + if (elems_ > length_) { + // Since each cache entry is fairly large, we aim for a small + // average linked list length (<= 1). + Resize(); + } + } + return old; + } + + LRUHandle* Remove(const Slice& key, uint32_t hash) { + LRUHandle** ptr = FindPointer(key, hash); + LRUHandle* result = *ptr; + if (result != nullptr) { + *ptr = result->next_hash; + --elems_; + } + return result; + } + + private: + // The table consists of an array of buckets where each bucket is + // a linked list of cache entries that hash into the bucket. + uint32_t length_; + uint32_t elems_; + LRUHandle** list_; + + // Return a pointer to slot that points to a cache entry that + // matches key/hash. If there is no such cache entry, return a + // pointer to the trailing slot in the corresponding linked list. + LRUHandle** FindPointer(const Slice& key, uint32_t hash) { + LRUHandle** ptr = &list_[hash & (length_ - 1)]; + while (*ptr != nullptr && ((*ptr)->hash != hash || key != (*ptr)->key())) { + ptr = &(*ptr)->next_hash; + } + return ptr; + } + + void Resize() { + uint32_t new_length = 4; + while (new_length < elems_) { + new_length *= 2; + } + LRUHandle** new_list = new LRUHandle*[new_length]; + memset(new_list, 0, sizeof(new_list[0]) * new_length); + uint32_t count = 0; + for (uint32_t i = 0; i < length_; i++) { + LRUHandle* h = list_[i]; + while (h != nullptr) { + LRUHandle* next = h->next_hash; + uint32_t hash = h->hash; + LRUHandle** ptr = &new_list[hash & (new_length - 1)]; + h->next_hash = *ptr; + *ptr = h; + h = next; + count++; + } + } + assert(elems_ == count); + delete[] list_; + list_ = new_list; + length_ = new_length; + } +}; + +// A single shard of sharded cache. 
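+// Each LRUCache instance below is one such shard, with its own mutex, hash
+// table, and LRU/in-use lists; the ShardedLRUCache defined near the end of
+// this file routes every key to one of 16 shards via the top four bits of
+// its hash, which keeps lock contention low under concurrent access.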
+class LRUCache { + public: + LRUCache(); + ~LRUCache(); + + // Separate from constructor so caller can easily make an array of LRUCache + void SetCapacity(size_t capacity) { capacity_ = capacity; } + + // Like Cache methods, but with an extra "hash" parameter. + Cache::Handle* Insert(const Slice& key, uint32_t hash, void* value, + size_t charge, + void (*deleter)(const Slice& key, void* value)); + Cache::Handle* Lookup(const Slice& key, uint32_t hash); + void Release(Cache::Handle* handle); + void Erase(const Slice& key, uint32_t hash); + void Prune(); + size_t TotalCharge() const { + mutex_lock l(mutex_); + return usage_; + } + + private: + void LRU_Remove(LRUHandle* e); + void LRU_Append(LRUHandle* list, LRUHandle* e); + void Ref(LRUHandle* e); + void Unref(LRUHandle* e); + bool FinishErase(LRUHandle* e) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Initialized before use. + size_t capacity_; + + // mutex_ protects the following state. + mutable mutex mutex_; + size_t usage_ GUARDED_BY(mutex_); + + // Dummy head of LRU list. + // lru.prev is newest entry, lru.next is oldest entry. + // Entries have refs==1 and in_cache==true. + LRUHandle lru_ GUARDED_BY(mutex_); + + // Dummy head of in-use list. + // Entries are in use by clients, and have refs >= 2 and in_cache==true. + LRUHandle in_use_ GUARDED_BY(mutex_); + + HandleTable table_ GUARDED_BY(mutex_); +}; + +LRUCache::LRUCache() : capacity_(0), usage_(0) { + // Make empty circular linked lists. + lru_.next = &lru_; + lru_.prev = &lru_; + in_use_.next = &in_use_; + in_use_.prev = &in_use_; +} + +LRUCache::~LRUCache() { + assert(in_use_.next == &in_use_); // Error if caller has an unreleased handle + for (LRUHandle* e = lru_.next; e != &lru_;) { + LRUHandle* next = e->next; + assert(e->in_cache); + e->in_cache = false; + assert(e->refs == 1); // Invariant of lru_ list. + Unref(e); + e = next; + } +} + +void LRUCache::Ref(LRUHandle* e) { + if (e->refs == 1 && e->in_cache) { // If on lru_ list, move to in_use_ list. + LRU_Remove(e); + LRU_Append(&in_use_, e); + } + e->refs++; +} + +void LRUCache::Unref(LRUHandle* e) { + assert(e->refs > 0); + e->refs--; + if (e->refs == 0) { // Deallocate. + assert(!e->in_cache); + (*e->deleter)(e->key(), e->value); + free(e); + } else if (e->in_cache && e->refs == 1) { + // No longer in use; move to lru_ list. + LRU_Remove(e); + LRU_Append(&lru_, e); + } +} + +void LRUCache::LRU_Remove(LRUHandle* e) { + e->next->prev = e->prev; + e->prev->next = e->next; +} + +void LRUCache::LRU_Append(LRUHandle* list, LRUHandle* e) { + // Make "e" newest entry by inserting just before *list + e->next = list; + e->prev = list->prev; + e->prev->next = e; + e->next->prev = e; +} + +Cache::Handle* LRUCache::Lookup(const Slice& key, uint32_t hash) { + mutex_lock l(mutex_); + LRUHandle* e = table_.Lookup(key, hash); + if (e != nullptr) { + Ref(e); + } + return reinterpret_cast<Cache::Handle*>(e); +} + +void LRUCache::Release(Cache::Handle* handle) { + mutex_lock l(mutex_); + Unref(reinterpret_cast<LRUHandle*>(handle)); +} + +Cache::Handle* LRUCache::Insert(const Slice& key, uint32_t hash, void* value, + size_t charge, + void (*deleter)(const Slice& key, + void* value)) { + mutex_lock l(mutex_); + + LRUHandle* e = + reinterpret_cast<LRUHandle*>(malloc(sizeof(LRUHandle) - 1 + key.size())); + e->value = value; + e->deleter = deleter; + e->charge = charge; + e->key_length = key.size(); + e->hash = hash; + e->in_cache = false; + e->refs = 1; // for the returned handle. 
+ memcpy(e->key_data, key.data(), key.size()); + + if (capacity_ > 0) { + e->refs++; // for the cache's reference. + e->in_cache = true; + LRU_Append(&in_use_, e); + usage_ += charge; + FinishErase(table_.Insert(e)); + } else { // don't cache. (capacity_==0 is supported and turns off caching.) + // next is read by key() in an assert, so it must be initialized + e->next = nullptr; + } + while (usage_ > capacity_ && lru_.next != &lru_) { + LRUHandle* old = lru_.next; + assert(old->refs == 1); + bool erased = FinishErase(table_.Remove(old->key(), old->hash)); + if (!erased) { // to avoid unused variable when compiled NDEBUG + assert(erased); + } + } + + return reinterpret_cast<Cache::Handle*>(e); +} + +// If e != nullptr, finish removing *e from the cache; it has already been +// removed from the hash table. Return whether e != nullptr. +bool LRUCache::FinishErase(LRUHandle* e) { + if (e != nullptr) { + assert(e->in_cache); + LRU_Remove(e); + e->in_cache = false; + usage_ -= e->charge; + Unref(e); + } + return e != nullptr; +} + +void LRUCache::Erase(const Slice& key, uint32_t hash) { + mutex_lock l(mutex_); + FinishErase(table_.Remove(key, hash)); +} + +void LRUCache::Prune() { + mutex_lock l(mutex_); + while (lru_.next != &lru_) { + LRUHandle* e = lru_.next; + assert(e->refs == 1); + bool erased = FinishErase(table_.Remove(e->key(), e->hash)); + if (!erased) { // to avoid unused variable when compiled NDEBUG + assert(erased); + } + } +} + +static const int kNumShardBits = 4; +static const int kNumShards = 1 << kNumShardBits; + +class ShardedLRUCache : public Cache { + private: + LRUCache shard_[kNumShards]; + mutex id_mutex_; + uint64_t last_id_; + + static inline uint32_t HashSlice(const Slice& s) { + return Hash(s.data(), s.size(), 0); + } + + static uint32_t Shard(uint32_t hash) { return hash >> (32 - kNumShardBits); } + + public: + explicit ShardedLRUCache(size_t capacity) : last_id_(0) { + const size_t per_shard = (capacity + (kNumShards - 1)) / kNumShards; + for (int s = 0; s < kNumShards; s++) { + shard_[s].SetCapacity(per_shard); + } + } + ~ShardedLRUCache() override {} + Handle* Insert(const Slice& key, void* value, size_t charge, + void (*deleter)(const Slice& key, void* value)) override { + const uint32_t hash = HashSlice(key); + return shard_[Shard(hash)].Insert(key, hash, value, charge, deleter); + } + Handle* Lookup(const Slice& key) override { + const uint32_t hash = HashSlice(key); + return shard_[Shard(hash)].Lookup(key, hash); + } + void Release(Handle* handle) override { + LRUHandle* h = reinterpret_cast<LRUHandle*>(handle); + shard_[Shard(h->hash)].Release(handle); + } + void Erase(const Slice& key) override { + const uint32_t hash = HashSlice(key); + shard_[Shard(hash)].Erase(key, hash); + } + void* Value(Handle* handle) override { + return reinterpret_cast<LRUHandle*>(handle)->value; + } + uint64_t NewId() override { + mutex_lock l(id_mutex_); + return ++(last_id_); + } + void Prune() override { + for (int s = 0; s < kNumShards; s++) { + shard_[s].Prune(); + } + } + size_t TotalCharge() const override { + size_t total = 0; + for (int s = 0; s < kNumShards; s++) { + total += shard_[s].TotalCharge(); + } + return total; + } + + private: + // TODO(byronyi): Figure out why Hash32 fails EvictionPolicy test. 
+ static uint32_t Hash(const char* data, size_t n, uint32_t seed) { + // Similar to murmur hash + const uint32_t m = 0xc6a4a793; + const uint32_t r = 24; + const char* limit = data + n; + uint32_t h = seed ^ (n * m); + + // Pick up four bytes at a time + while (data + 4 <= limit) { + uint32_t w = core::DecodeFixed32(data); + data += 4; + h += w; + h *= m; + h ^= (h >> 16); + } + + // Pick up remaining bytes + switch (limit - data) { + case 3: + h += static_cast<uint8_t>(data[2]) << 16; + ABSL_FALLTHROUGH_INTENDED; + case 2: + h += static_cast<uint8_t>(data[1]) << 8; + ABSL_FALLTHROUGH_INTENDED; + case 1: + h += static_cast<uint8_t>(data[0]); + h *= m; + h ^= (h >> r); + break; + } + return h; + } +}; + +} // end anonymous namespace + +Cache* NewLRUCache(size_t capacity) { return new ShardedLRUCache(capacity); } + +} // namespace table + +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/cache.h b/tensorflow/core/lib/io/cache.h new file mode 100644 index 00000000000..788a637077a --- /dev/null +++ b/tensorflow/core/lib/io/cache.h @@ -0,0 +1,125 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_CACHE_H_ +#define TENSORFLOW_CORE_LIB_IO_CACHE_H_ + +#include "tensorflow/core/platform/stringpiece.h" + +// A Cache is an interface that maps keys to values. It has internal +// synchronization and may be safely accessed concurrently from +// multiple threads. It may automatically evict entries to make room +// for new entries. Values have a specified charge against the cache +// capacity. For example, a cache where the values are variable +// length strings, may use the length of the string as the charge for +// the string. +// +// A builtin cache implementation with a least-recently-used eviction +// policy is provided. Clients may use their own implementations if +// they want something more sophisticated (like scan-resistance, a +// custom eviction policy, variable cache sizing, etc.) + +namespace tensorflow { + +using Slice = StringPiece; + +namespace table { + +class Cache; + +// Create a new cache with a fixed size capacity. This implementation +// of Cache uses a least-recently-used eviction policy. +Cache* NewLRUCache(size_t capacity); + +class Cache { + public: + Cache() = default; + + Cache(const Cache&) = delete; + Cache& operator=(const Cache&) = delete; + + // Destroys all existing entries by calling the "deleter" + // function that was passed to the constructor. + virtual ~Cache(); + + // Opaque handle to an entry stored in the cache. + struct Handle {}; + + // Insert a mapping from key->value into the cache and assign it + // the specified charge against the total cache capacity. + // + // Returns a handle that corresponds to the mapping. The caller + // must call this->Release(handle) when the returned mapping is no + // longer needed. 
+ // + // When the inserted entry is no longer needed, the key and + // value will be passed to "deleter". + virtual Handle* Insert(const Slice& key, void* value, size_t charge, + void (*deleter)(const Slice& key, void* value)) = 0; + + // If the cache has no mapping for "key", returns nullptr. + // + // Else return a handle that corresponds to the mapping. The caller + // must call this->Release(handle) when the returned mapping is no + // longer needed. + virtual Handle* Lookup(const Slice& key) = 0; + + // Release a mapping returned by a previous Lookup(). + // REQUIRES: handle must not have been released yet. + // REQUIRES: handle must have been returned by a method on *this. + virtual void Release(Handle* handle) = 0; + + // Return the value encapsulated in a handle returned by a + // successful Lookup(). + // REQUIRES: handle must not have been released yet. + // REQUIRES: handle must have been returned by a method on *this. + virtual void* Value(Handle* handle) = 0; + + // If the cache contains entry for key, erase it. Note that the + // underlying entry will be kept around until all existing handles + // to it have been released. + virtual void Erase(const Slice& key) = 0; + + // Return a new numeric id. May be used by multiple clients who are + // sharing the same cache to partition the key space. Typically the + // client will allocate a new id at startup and prepend the id to + // its cache keys. + virtual uint64_t NewId() = 0; + + // Remove all cache entries that are not actively in use. Memory-constrained + // applications may wish to call this method to reduce memory usage. + // Default implementation of Prune() does nothing. Subclasses are strongly + // encouraged to override the default implementation. A future release of + // leveldb may change Prune() to a pure abstract method. + virtual void Prune() {} + + // Return an estimate of the combined charges of all elements stored in the + // cache. + virtual size_t TotalCharge() const = 0; + + private: + void LRU_Remove(Handle* e); + void LRU_Append(Handle* e); + void Unref(Handle* e); + + struct Rep; + Rep* rep_; +}; + +} // namespace table + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_CACHE_H_ diff --git a/tensorflow/core/lib/io/cache_test.cc b/tensorflow/core/lib/io/cache_test.cc new file mode 100644 index 00000000000..38552d43b34 --- /dev/null +++ b/tensorflow/core/lib/io/cache_test.cc @@ -0,0 +1,238 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/lib/io/cache.h" + +#include <string> +#include <vector> + +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +namespace table { +// Conversions between numeric keys/values and the types expected by Cache. 
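+// Keys are ints serialized as 4-byte fixed-width strings, and values are ints
+// stored directly in the void* payload, so the test deleter only records what
+// was evicted and never has to free memory.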
+static std::string EncodeKey(int k) { + std::string result; + core::PutFixed32(&result, k); + return result; +} +static int DecodeKey(const Slice& k) { + assert(k.size() == 4); + return core::DecodeFixed32(k.data()); +} +static void* EncodeValue(uintptr_t v) { return reinterpret_cast<void*>(v); } +static int DecodeValue(void* v) { return reinterpret_cast<uintptr_t>(v); } + +class CacheTest : public ::testing::Test { + public: + static void Deleter(const Slice& key, void* v) { + current_->deleted_keys_.push_back(DecodeKey(key)); + current_->deleted_values_.push_back(DecodeValue(v)); + } + + static const int kCacheSize = 1000; + std::vector<int> deleted_keys_; + std::vector<int> deleted_values_; + Cache* cache_; + + CacheTest() : cache_(NewLRUCache(kCacheSize)) { current_ = this; } + + ~CacheTest() { delete cache_; } + + int Lookup(int key) { + Cache::Handle* handle = cache_->Lookup(EncodeKey(key)); + const int r = (handle == nullptr) ? -1 : DecodeValue(cache_->Value(handle)); + if (handle != nullptr) { + cache_->Release(handle); + } + return r; + } + + void Insert(int key, int value, int charge = 1) { + cache_->Release(cache_->Insert(EncodeKey(key), EncodeValue(value), charge, + &CacheTest::Deleter)); + } + + Cache::Handle* InsertAndReturnHandle(int key, int value, int charge = 1) { + return cache_->Insert(EncodeKey(key), EncodeValue(value), charge, + &CacheTest::Deleter); + } + + void Erase(int key) { cache_->Erase(EncodeKey(key)); } + static CacheTest* current_; +}; +CacheTest* CacheTest::current_; + +TEST_F(CacheTest, HitAndMiss) { + ASSERT_EQ(-1, Lookup(100)); + + Insert(100, 101); + ASSERT_EQ(101, Lookup(100)); + ASSERT_EQ(-1, Lookup(200)); + ASSERT_EQ(-1, Lookup(300)); + + Insert(200, 201); + ASSERT_EQ(101, Lookup(100)); + ASSERT_EQ(201, Lookup(200)); + ASSERT_EQ(-1, Lookup(300)); + + Insert(100, 102); + ASSERT_EQ(102, Lookup(100)); + ASSERT_EQ(201, Lookup(200)); + ASSERT_EQ(-1, Lookup(300)); + + ASSERT_EQ(1, deleted_keys_.size()); + ASSERT_EQ(100, deleted_keys_[0]); + ASSERT_EQ(101, deleted_values_[0]); +} + +TEST_F(CacheTest, Erase) { + Erase(200); + ASSERT_EQ(0, deleted_keys_.size()); + + Insert(100, 101); + Insert(200, 201); + Erase(100); + ASSERT_EQ(-1, Lookup(100)); + ASSERT_EQ(201, Lookup(200)); + ASSERT_EQ(1, deleted_keys_.size()); + ASSERT_EQ(100, deleted_keys_[0]); + ASSERT_EQ(101, deleted_values_[0]); + + Erase(100); + ASSERT_EQ(-1, Lookup(100)); + ASSERT_EQ(201, Lookup(200)); + ASSERT_EQ(1, deleted_keys_.size()); +} + +TEST_F(CacheTest, EntriesArePinned) { + Insert(100, 101); + Cache::Handle* h1 = cache_->Lookup(EncodeKey(100)); + ASSERT_EQ(101, DecodeValue(cache_->Value(h1))); + + Insert(100, 102); + Cache::Handle* h2 = cache_->Lookup(EncodeKey(100)); + ASSERT_EQ(102, DecodeValue(cache_->Value(h2))); + ASSERT_EQ(0, deleted_keys_.size()); + + cache_->Release(h1); + ASSERT_EQ(1, deleted_keys_.size()); + ASSERT_EQ(100, deleted_keys_[0]); + ASSERT_EQ(101, deleted_values_[0]); + + Erase(100); + ASSERT_EQ(-1, Lookup(100)); + ASSERT_EQ(1, deleted_keys_.size()); + + cache_->Release(h2); + ASSERT_EQ(2, deleted_keys_.size()); + ASSERT_EQ(100, deleted_keys_[1]); + ASSERT_EQ(102, deleted_values_[1]); +} + +TEST_F(CacheTest, EvictionPolicy) { + Insert(100, 101); + Insert(200, 201); + Insert(300, 301); + Cache::Handle* h = cache_->Lookup(EncodeKey(300)); + + // Frequently used entry must be kept around, + // as must things that are still in use. 
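+  // Flood the cache with kCacheSize + 100 fresh entries, touching key 100 on
+  // every iteration so it stays at the hot end of the LRU list; key 300 stays
+  // pinned through the handle h above, while the untouched key 200 is the one
+  // expected to be evicted.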
+ for (int i = 0; i < kCacheSize + 100; i++) { + Insert(1000 + i, 2000 + i); + ASSERT_EQ(2000 + i, Lookup(1000 + i)); + ASSERT_EQ(101, Lookup(100)); + } + ASSERT_EQ(101, Lookup(100)); + ASSERT_EQ(-1, Lookup(200)); + ASSERT_EQ(301, Lookup(300)); + cache_->Release(h); +} + +TEST_F(CacheTest, UseExceedsCacheSize) { + // Overfill the cache, keeping handles on all inserted entries. + std::vector<Cache::Handle*> h; + for (int i = 0; i < kCacheSize + 100; i++) { + h.push_back(InsertAndReturnHandle(1000 + i, 2000 + i)); + } + + // Check that all the entries can be found in the cache. + for (int i = 0; i < h.size(); i++) { + ASSERT_EQ(2000 + i, Lookup(1000 + i)); + } + + for (int i = 0; i < h.size(); i++) { + cache_->Release(h[i]); + } +} + +TEST_F(CacheTest, HeavyEntries) { + // Add a bunch of light and heavy entries and then count the combined + // size of items still in the cache, which must be approximately the + // same as the total capacity. + const int kLight = 1; + const int kHeavy = 10; + int added = 0; + int index = 0; + while (added < 2 * kCacheSize) { + const int weight = (index & 1) ? kLight : kHeavy; + Insert(index, 1000 + index, weight); + added += weight; + index++; + } + + int cached_weight = 0; + for (int i = 0; i < index; i++) { + const int weight = (i & 1 ? kLight : kHeavy); + int r = Lookup(i); + if (r >= 0) { + cached_weight += weight; + ASSERT_EQ(1000 + i, r); + } + } + ASSERT_LE(cached_weight, kCacheSize + kCacheSize / 10); +} + +TEST_F(CacheTest, NewId) { + uint64_t a = cache_->NewId(); + uint64_t b = cache_->NewId(); + ASSERT_NE(a, b); +} + +TEST_F(CacheTest, Prune) { + Insert(1, 100); + Insert(2, 200); + + Cache::Handle* handle = cache_->Lookup(EncodeKey(1)); + ASSERT_TRUE(handle); + cache_->Prune(); + cache_->Release(handle); + + ASSERT_EQ(100, Lookup(1)); + ASSERT_EQ(-1, Lookup(2)); +} + +TEST_F(CacheTest, ZeroSizeCache) { + delete cache_; + cache_ = NewLRUCache(0); + + Insert(1, 100); + ASSERT_EQ(-1, Lookup(1)); +} + +} // namespace table +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/table.cc b/tensorflow/core/lib/io/table.cc index 1e68493bfe9..6cd1b21c14d 100644 --- a/tensorflow/core/lib/io/table.cc +++ b/tensorflow/core/lib/io/table.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/block.h" +#include "tensorflow/core/lib/io/cache.h" #include "tensorflow/core/lib/io/format.h" #include "tensorflow/core/lib/io/table_options.h" #include "tensorflow/core/lib/io/two_level_iterator.h" @@ -32,7 +33,7 @@ struct Table::Rep { Options options; Status status; RandomAccessFile* file; - // XXX uint64 cache_id; + uint64 cache_id; BlockHandle metaindex_handle; // Handle to metaindex_block: saved from footer Block* index_block; @@ -60,21 +61,18 @@ Status Table::Open(const Options& options, RandomAccessFile* file, uint64 size, Block* index_block = nullptr; if (s.ok()) { s = ReadBlock(file, footer.index_handle(), &contents); - if (s.ok()) { - index_block = new Block(contents); - } } if (s.ok()) { // We've successfully read the footer and the index block: we're // ready to serve requests. + index_block = new Block(contents); Rep* rep = new Table::Rep; rep->options = options; rep->file = file; rep->metaindex_handle = footer.metaindex_handle(); rep->index_block = index_block; - // XXX rep->cache_id = (options.block_cache ? - // options.block_cache->NewId() : 0); + rep->cache_id = (options.block_cache ? 
options.block_cache->NewId() : 0); *table = new Table(rep); } else { if (index_block) delete index_block; @@ -89,13 +87,24 @@ static void DeleteBlock(void* arg, void* ignored) { delete reinterpret_cast<Block*>(arg); } +static void DeleteCachedBlock(const absl::string_view&, void* value) { + Block* block = reinterpret_cast<Block*>(value); + delete block; +} + +static void ReleaseBlock(void* arg, void* h) { + Cache* cache = reinterpret_cast<Cache*>(arg); + Cache::Handle* handle = reinterpret_cast<Cache::Handle*>(h); + cache->Release(handle); +} + // Convert an index iterator value (i.e., an encoded BlockHandle) // into an iterator over the contents of the corresponding block. Iterator* Table::BlockReader(void* arg, const StringPiece& index_value) { Table* table = reinterpret_cast<Table*>(arg); - // Cache* block_cache = table->rep_->options.block_cache; + Cache* block_cache = table->rep_->options.block_cache; Block* block = nullptr; - // Cache::Handle* cache_handle = NULL; + Cache::Handle* cache_handle = NULL; BlockHandle handle; StringPiece input = index_value; @@ -105,16 +114,38 @@ Iterator* Table::BlockReader(void* arg, const StringPiece& index_value) { if (s.ok()) { BlockContents contents; - s = ReadBlock(table->rep_->file, handle, &contents); - if (s.ok()) { - block = new Block(contents); + if (block_cache != nullptr) { + char cache_key_buffer[16]; + core::EncodeFixed64(cache_key_buffer, table->rep_->cache_id); + core::EncodeFixed64(cache_key_buffer + 8, handle.offset()); + absl::string_view key(cache_key_buffer, sizeof(cache_key_buffer)); + cache_handle = block_cache->Lookup(key); + if (cache_handle != nullptr) { + block = reinterpret_cast<Block*>(block_cache->Value(cache_handle)); + } else { + s = ReadBlock(table->rep_->file, handle, &contents); + if (s.ok()) { + block = new Block(contents); + cache_handle = block_cache->Insert(key, block, block->size(), + &DeleteCachedBlock); + } + } + } else { + s = ReadBlock(table->rep_->file, handle, &contents); + if (s.ok()) { + block = new Block(contents); + } } } Iterator* iter; if (block != nullptr) { iter = block->NewIterator(); - iter->RegisterCleanup(&DeleteBlock, block, nullptr); + if (cache_handle == nullptr) { + iter->RegisterCleanup(&DeleteBlock, block, nullptr); + } else { + iter->RegisterCleanup(&ReleaseBlock, block_cache, cache_handle); + } } else { iter = NewErrorIterator(s); } diff --git a/tensorflow/core/lib/io/table_options.h b/tensorflow/core/lib/io/table_options.h index 9a36bf16315..d1b43ae7d43 100644 --- a/tensorflow/core/lib/io/table_options.h +++ b/tensorflow/core/lib/io/table_options.h @@ -21,6 +21,8 @@ limitations under the License. namespace tensorflow { namespace table { +class Cache; + // DB contents are stored in a set of blocks, each of which holds a // sequence of key,value pairs. Each block may be compressed before // being stored in a file. The following enum describes which @@ -60,6 +62,12 @@ struct Options { // incompressible, the kSnappyCompression implementation will // efficiently detect that and will switch to uncompressed mode. CompressionType compression = kSnappyCompression; + + // Control over blocks (user data is stored in a set of blocks, and + // a block is the unit of reading from disk). + + // If non-null, use the specified cache for blocks. 
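+  // The table does not free this pointer, and block keys are prefixed with a
+  // per-table id from Cache::NewId() (see Table::BlockReader above), so one
+  // cache instance can safely be shared across many open tables.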
+ Cache* block_cache = nullptr; }; } // namespace table diff --git a/tensorflow/core/lib/monitoring/BUILD b/tensorflow/core/lib/monitoring/BUILD index fd74298eae0..bbed9a58452 100644 --- a/tensorflow/core/lib/monitoring/BUILD +++ b/tensorflow/core/lib/monitoring/BUILD @@ -70,11 +70,6 @@ cc_library( cc_library( name = "counter", hdrs = ["counter.h"], - visibility = [ - "//tensorflow/c/eager:__pkg__", - "//tensorflow/core:__pkg__", - "//tensorflow/core/platform:__subpackages__", - ], deps = [ ":collection_registry", ":metric_def", @@ -91,11 +86,6 @@ cc_library( cc_library( name = "gauge", hdrs = ["gauge.h"], - visibility = [ - "//tensorflow/c/eager:__pkg__", - "//tensorflow/core:__pkg__", - "//tensorflow/core/platform:__subpackages__", - ], deps = [ ":collection_registry", ":metric_def", @@ -156,11 +146,6 @@ cc_library( name = "sampler", srcs = ["sampler.cc"], hdrs = ["sampler.h"], - visibility = [ - "//tensorflow/c/eager:__pkg__", - "//tensorflow/core:__pkg__", - "//tensorflow/core/platform:__subpackages__", - ], deps = [ ":collection_registry", ":metric_def", @@ -192,11 +177,6 @@ cc_library( name = "percentile_sampler", srcs = ["percentile_sampler.cc"], hdrs = ["percentile_sampler.h"], - visibility = [ - "//tensorflow/c/eager:__pkg__", - "//tensorflow/core:__pkg__", - "//tensorflow/core/platform:__subpackages__", - ], deps = [ ":collection_registry", ":metric_def", diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h index 6b637c21d24..26f3ec02694 100644 --- a/tensorflow/core/lib/monitoring/collection_registry.h +++ b/tensorflow/core/lib/monitoring/collection_registry.h @@ -360,6 +360,37 @@ MetricCollector<metric_kind, Value, NumLabels> MetricCollectorGetter::Get( collector_); } +class Exporter { + public: + virtual ~Exporter() {} + virtual void PeriodicallyExportMetrics() = 0; + virtual void ExportMetrics() = 0; +}; + +namespace exporter_registration { + +class ExporterRegistration { + public: + explicit ExporterRegistration(Exporter* exporter) : exporter_(exporter) { + exporter_->PeriodicallyExportMetrics(); + } + + private: + Exporter* exporter_; +}; + +} // namespace exporter_registration + +#define REGISTER_TF_METRICS_EXPORTER(exporter) \ + REGISTER_TF_METRICS_EXPORTER_UNIQ_HELPER(__COUNTER__, exporter) + +#define REGISTER_TF_METRICS_EXPORTER_UNIQ_HELPER(ctr, exporter) \ + REGISTER_TF_METRICS_EXPORTER_UNIQ(ctr, exporter) + +#define REGISTER_TF_METRICS_EXPORTER_UNIQ(ctr, exporter) \ + static ::tensorflow::monitoring::exporter_registration::ExporterRegistration \ + exporter_registration_##ctr(new exporter()) + } // namespace monitoring } // namespace tensorflow diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD index fb40e56829d..eef097b0c75 100644 --- a/tensorflow/core/platform/BUILD +++ b/tensorflow/core/platform/BUILD @@ -16,7 +16,6 @@ load( "//tensorflow/core/platform:build_config.bzl", "tf_additional_env_hdrs", "tf_additional_lib_hdrs", - "tf_additional_monitoring_hdrs", "tf_additional_tensor_coding_deps", "tf_additional_test_srcs", "tf_fingerprint_deps", @@ -24,7 +23,6 @@ load( "tf_google_mobile_srcs_only_runtime", "tf_kernel_tests_linkstatic", "tf_logging_deps", - "tf_monitoring_deps", "tf_platform_alias", "tf_platform_deps", "tf_protobuf_compiler_deps", @@ -83,7 +81,6 @@ exports_files( "load_library.h", "logging.h", "mem.h", - "monitoring.h", "mutex.h", "net.h", "numa.h", @@ -333,12 +330,6 @@ cc_library( deps = tf_logging_deps(), ) -cc_library( - name = "monitoring", - textual_hdrs = 
["monitoring.h"], - deps = tf_monitoring_deps(), -) - cc_library( name = "macros", hdrs = ["macros.h"], @@ -760,6 +751,62 @@ cc_library( ] + tf_platform_deps("unbounded_work_queue"), ) +cc_library( + name = "retrying_utils", + srcs = [ + "retrying_utils.cc", + ], + hdrs = [ + "retrying_utils.h", + ], + copts = tf_copts(), + deps = [ + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( + name = "retrying_file_system", + hdrs = [ + "retrying_file_system.h", + ], + copts = tf_copts(), + deps = [ + ":retrying_utils", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_cc_test( + name = "retrying_file_system_test", + size = "small", + srcs = ["retrying_file_system_test.cc"], + deps = [ + ":retrying_file_system", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/platform:str_util", + ], +) + +tf_cc_test( + name = "retrying_utils_test", + size = "small", + srcs = ["retrying_utils_test.cc"], + deps = [ + ":retrying_utils", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/platform:str_util", + ], +) + # This is a hacky, do-nothing, binary that makes it easy to verify ability to # build, link, and in some cases run all of the libraries under platform. # Realistically, most of this would be covered by tests but at this point @@ -791,6 +838,8 @@ cc_binary( ":png", ":prefetch", ":protobuf", + ":retrying_utils", + ":retrying_file_system", ":scanner", ":setround", ":stacktrace", @@ -1178,7 +1227,6 @@ filegroup( "init_main.h", "logger.h", "mem.h", - "monitoring.h", "mutex.h", "net.h", "notification.h", @@ -1203,7 +1251,7 @@ filegroup( "subprocess.h", "thread_annotations.h", ":base_hdrs", - ] + tf_additional_monitoring_hdrs() + tf_additional_env_hdrs(), + ] + tf_additional_env_hdrs(), visibility = ["//tensorflow/core:__pkg__"], ) @@ -1280,7 +1328,6 @@ filegroup( "demangle.h", "denormal.h", "host_info.h", - "monitoring.h", "platform.h", "protobuf_internal.h", "refcount.h", @@ -1452,7 +1499,6 @@ filegroup( "cpu_feature_guard.cc", "cpu_feature_guard.h", "fingerprint.h", - "monitoring.h", "notification.h", "platform_strings.cc", "platform_strings.h", diff --git a/tensorflow/core/platform/build_config.bzl b/tensorflow/core/platform/build_config.bzl index 4a1ba38fbd8..d1fb5f829ea 100644 --- a/tensorflow/core/platform/build_config.bzl +++ b/tensorflow/core/platform/build_config.bzl @@ -12,7 +12,6 @@ load( _tf_additional_env_hdrs = "tf_additional_env_hdrs", _tf_additional_lib_deps = "tf_additional_lib_deps", _tf_additional_lib_hdrs = "tf_additional_lib_hdrs", - _tf_additional_monitoring_hdrs = "tf_additional_monitoring_hdrs", _tf_additional_rpc_deps = "tf_additional_rpc_deps", _tf_additional_tensor_coding_deps = "tf_additional_tensor_coding_deps", _tf_additional_test_deps = "tf_additional_test_deps", @@ -24,7 +23,6 @@ load( _tf_kernel_tests_linkstatic = "tf_kernel_tests_linkstatic", _tf_lib_proto_parsing_deps = "tf_lib_proto_parsing_deps", _tf_logging_deps = "tf_logging_deps", - _tf_monitoring_deps = "tf_monitoring_deps", _tf_platform_alias = "tf_platform_alias", _tf_platform_deps = "tf_platform_deps", _tf_portable_deps_no_runtime = "tf_portable_deps_no_runtime", @@ -56,7 +54,6 @@ tf_additional_device_tracer_srcs = _tf_additional_device_tracer_srcs tf_additional_env_hdrs = _tf_additional_env_hdrs tf_additional_lib_deps = 
_tf_additional_lib_deps tf_additional_lib_hdrs = _tf_additional_lib_hdrs -tf_additional_monitoring_hdrs = _tf_additional_monitoring_hdrs tf_additional_rpc_deps = _tf_additional_rpc_deps tf_additional_tensor_coding_deps = _tf_additional_tensor_coding_deps tf_additional_test_deps = _tf_additional_test_deps @@ -68,7 +65,6 @@ tf_jspb_proto_library = _tf_jspb_proto_library tf_kernel_tests_linkstatic = _tf_kernel_tests_linkstatic tf_lib_proto_parsing_deps = _tf_lib_proto_parsing_deps tf_logging_deps = _tf_logging_deps -tf_monitoring_deps = _tf_monitoring_deps tf_platform_alias = _tf_platform_alias tf_platform_deps = _tf_platform_deps tf_portable_deps_no_runtime = _tf_portable_deps_no_runtime diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD index 53c4f6cda1f..fe08edceae9 100644 --- a/tensorflow/core/platform/cloud/BUILD +++ b/tensorflow/core/platform/cloud/BUILD @@ -92,14 +92,14 @@ cc_library( ":google_auth_provider", ":http_request", ":ram_file_block_cache", - ":retrying_file_system", - ":retrying_utils", ":time_util", "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/platform:numbers", "//tensorflow/core/platform:path", + "//tensorflow/core/platform:retrying_file_system", + "//tensorflow/core/platform:retrying_utils", "//tensorflow/core/platform:str_util", "//tensorflow/core/platform:stringprintf", "@jsoncpp_git//:jsoncpp", @@ -128,14 +128,14 @@ cc_library( ":google_auth_provider", ":http_request", ":ram_file_block_cache", - ":retrying_file_system", - ":retrying_utils", ":time_util", "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/platform:numbers", "//tensorflow/core/platform:path", + "//tensorflow/core/platform:retrying_file_system", + "//tensorflow/core/platform:retrying_utils", "//tensorflow/core/platform:str_util", "//tensorflow/core/platform:stringprintf", "@jsoncpp_git//:jsoncpp", @@ -200,12 +200,12 @@ cc_library( deps = [ ":compute_engine_metadata_client", ":oauth_client", - ":retrying_utils", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/platform:base64", "//tensorflow/core/platform:errors", "//tensorflow/core/platform:path", + "//tensorflow/core/platform:retrying_utils", "//tensorflow/core/platform:status", "@com_google_absl//absl/strings", "@jsoncpp_git//:jsoncpp", @@ -224,9 +224,9 @@ cc_library( deps = [ ":curl_http_request", ":http_request", - ":retrying_utils", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/platform:retrying_utils", ], ) @@ -283,34 +283,6 @@ cc_library( ], ) -cc_library( - name = "retrying_utils", - srcs = [ - "retrying_utils.cc", - ], - hdrs = [ - "retrying_utils.h", - ], - copts = tf_copts(), - deps = [ - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:lib_internal", - ], -) - -cc_library( - name = "retrying_file_system", - hdrs = [ - "retrying_file_system.h", - ], - copts = tf_copts(), - deps = [ - ":retrying_utils", - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:lib_internal", - ], -) - cc_library( name = "time_util", srcs = [ @@ -482,20 +454,6 @@ tf_cc_test( ], ) -tf_cc_test( - name = "retrying_file_system_test", - size = "small", - srcs = ["retrying_file_system_test.cc"], - deps = [ - ":retrying_file_system", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/platform:str_util", - ], -) 
- tf_cc_test( name = "time_util_test", size = "small", @@ -506,17 +464,3 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) - -tf_cc_test( - name = "retrying_utils_test", - size = "small", - srcs = ["retrying_utils_test.cc"], - deps = [ - ":retrying_utils", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/platform:str_util", - ], -) diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client.h b/tensorflow/core/platform/cloud/compute_engine_metadata_client.h index d7611615606..164380b4141 100644 --- a/tensorflow/core/platform/cloud/compute_engine_metadata_client.h +++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_METADATA_CLIENT_H_ #include "tensorflow/core/platform/cloud/http_request.h" -#include "tensorflow/core/platform/cloud/retrying_utils.h" +#include "tensorflow/core/platform/retrying_utils.h" #include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 054ad242692..57847d2ea38 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/core/platform/cloud/file_block_cache.h" #include "tensorflow/core/platform/cloud/google_auth_provider.h" #include "tensorflow/core/platform/cloud/ram_file_block_cache.h" -#include "tensorflow/core/platform/cloud/retrying_utils.h" #include "tensorflow/core/platform/cloud/time_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" @@ -40,6 +39,7 @@ limitations under the License. #include "tensorflow/core/platform/numbers.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/retrying_utils.h" #include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/stringprintf.h" #include "tensorflow/core/platform/thread_annotations.h" diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h index b075cbe9828..98933532b17 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.h +++ b/tensorflow/core/platform/cloud/gcs_file_system.h @@ -29,8 +29,8 @@ limitations under the License. #include "tensorflow/core/platform/cloud/gcs_dns_cache.h" #include "tensorflow/core/platform/cloud/gcs_throttle.h" #include "tensorflow/core/platform/cloud/http_request.h" -#include "tensorflow/core/platform/cloud/retrying_file_system.h" #include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/retrying_file_system.h" #include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/google_auth_provider.cc b/tensorflow/core/platform/cloud/google_auth_provider.cc index 2c74a3e2330..e8546ca022f 100644 --- a/tensorflow/core/platform/cloud/google_auth_provider.cc +++ b/tensorflow/core/platform/cloud/google_auth_provider.cc @@ -26,10 +26,10 @@ limitations under the License. 
#include "absl/strings/match.h" #include "include/json/json.h" #include "tensorflow/core/platform/base64.h" -#include "tensorflow/core/platform/cloud/retrying_utils.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/retrying_utils.h" namespace tensorflow { diff --git a/tensorflow/core/platform/default/BUILD b/tensorflow/core/platform/default/BUILD index 07a057718cb..7a96f134a9d 100644 --- a/tensorflow/core/platform/default/BUILD +++ b/tensorflow/core/platform/default/BUILD @@ -193,17 +193,6 @@ cc_library( ], ) -cc_library( - name = "monitoring", - srcs = ["monitoring.cc"], - hdrs = ["//tensorflow/core/platform:monitoring.h"], - tags = [ - "manual", - "no_oss", - "nobuilder", - ], -) - cc_library( name = "mutex", srcs = [ @@ -529,7 +518,6 @@ filegroup( srcs = [ "casts.h", "cord.h", - "monitoring.cc", "mutex.h", "mutex_data.h", "notification.h", diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 4b9b5a896d4..dbb6ffdbc6e 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -569,9 +569,6 @@ def tf_additional_lib_hdrs(): ], }) -def tf_additional_monitoring_hdrs(): - return [] - def tf_additional_env_hdrs(): return [] @@ -756,9 +753,6 @@ def tf_platform_alias(name): def tf_logging_deps(): return ["//tensorflow/core/platform/default:logging"] -def tf_monitoring_deps(): - return ["//tensorflow/core/platform/default:monitoring"] - def tf_resource_deps(): return ["//tensorflow/core/platform/default:resource"] diff --git a/tensorflow/core/platform/default/monitoring.cc b/tensorflow/core/platform/default/monitoring.cc deleted file mode 100644 index 71ece3e3c14..00000000000 --- a/tensorflow/core/platform/default/monitoring.cc +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/platform/monitoring.h" - -namespace tensorflow { -namespace monitoring { - -void StartExporter() {} - -void ExportMetrics() {} - -} // namespace monitoring -} // namespace tensorflow diff --git a/tensorflow/core/platform/default/port.cc b/tensorflow/core/platform/default/port.cc index 47f4abae3bb..756e7e8a93a 100644 --- a/tensorflow/core/platform/default/port.cc +++ b/tensorflow/core/platform/default/port.cc @@ -332,6 +332,16 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output) { #endif } +bool Snappy_UncompressToIOVec(const char* compressed, size_t compressed_length, + const struct iovec* iov, size_t iov_cnt) { +#ifdef TF_USE_SNAPPY + return snappy::RawUncompressToIOVec(compressed, compressed_length, iov, + iov_cnt); +#else + return false; +#endif +} + string Demangle(const char* mangled) { return mangled; } double NominalCPUFrequency() { diff --git a/tensorflow/core/platform/monitoring.h b/tensorflow/core/platform/monitoring.h deleted file mode 100644 index f01233933c3..00000000000 --- a/tensorflow/core/platform/monitoring.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PLATFORM_MONITORING_H_ -#define TENSORFLOW_CORE_PLATFORM_MONITORING_H_ - -namespace tensorflow { -namespace monitoring { - -// Starts exporting metrics through a platform-specific monitoring API (if -// provided). For builds using "tensorflow/core/platform/default", this is -// currently a no-op. This function is idempotent. -// -// The TensorFlow runtime will call this the first time a new session is created -// using the NewSession() method or an Eager Context is created. -void StartExporter(); - -// Manually invokes a one time metrics export through a platform-specific -// monitoring API (if provided). For builds using -// "tensorflow/core/platform/default", this is currently a no-op. -void ExportMetrics(); - -} // namespace monitoring -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PLATFORM_MONITORING_H_ diff --git a/tensorflow/core/platform/cloud/retrying_file_system.h b/tensorflow/core/platform/retrying_file_system.h similarity index 99% rename from tensorflow/core/platform/cloud/retrying_file_system.h rename to tensorflow/core/platform/retrying_file_system.h index 12bbc7d6abb..df8850ace93 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system.h +++ b/tensorflow/core/platform/retrying_file_system.h @@ -21,10 +21,10 @@ limitations under the License. 
#include <vector> #include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/platform/cloud/retrying_utils.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/retrying_utils.h" #include "tensorflow/core/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/retrying_file_system_test.cc similarity index 99% rename from tensorflow/core/platform/cloud/retrying_file_system_test.cc rename to tensorflow/core/platform/retrying_file_system_test.cc index b48831ab238..b43c3375265 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc +++ b/tensorflow/core/platform/retrying_file_system_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/platform/cloud/retrying_file_system.h" +#include "tensorflow/core/platform/retrying_file_system.h" #include <fstream> diff --git a/tensorflow/core/platform/cloud/retrying_utils.cc b/tensorflow/core/platform/retrying_utils.cc similarity index 98% rename from tensorflow/core/platform/cloud/retrying_utils.cc rename to tensorflow/core/platform/retrying_utils.cc index 1f0c41824bf..1b6fa68c31c 100644 --- a/tensorflow/core/platform/cloud/retrying_utils.cc +++ b/tensorflow/core/platform/retrying_utils.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/platform/cloud/retrying_utils.h" +#include "tensorflow/core/platform/retrying_utils.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/env.h" diff --git a/tensorflow/core/platform/cloud/retrying_utils.h b/tensorflow/core/platform/retrying_utils.h similarity index 100% rename from tensorflow/core/platform/cloud/retrying_utils.h rename to tensorflow/core/platform/retrying_utils.h diff --git a/tensorflow/core/platform/cloud/retrying_utils_test.cc b/tensorflow/core/platform/retrying_utils_test.cc similarity index 98% rename from tensorflow/core/platform/cloud/retrying_utils_test.cc rename to tensorflow/core/platform/retrying_utils_test.cc index 7a2dbacacc8..5b162571067 100644 --- a/tensorflow/core/platform/cloud/retrying_utils_test.cc +++ b/tensorflow/core/platform/retrying_utils_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/core/platform/cloud/retrying_utils.h" +#include "tensorflow/core/platform/retrying_utils.h" #include <fstream> diff --git a/tensorflow/core/platform/s3/BUILD b/tensorflow/core/platform/s3/BUILD index a5494d5c318..d174b108279 100644 --- a/tensorflow/core/platform/s3/BUILD +++ b/tensorflow/core/platform/s3/BUILD @@ -33,6 +33,8 @@ tf_cc_binary( linkshared = 1, deps = [ "//tensorflow/core:framework_headers_lib", + "//tensorflow/core/platform:retrying_file_system", + "//tensorflow/core/platform:retrying_utils", "@aws", "@com_google_protobuf//:protobuf_headers", "@curl", diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc index acdacc306d5..3be563d2c94 100644 --- a/tensorflow/core/platform/s3/s3_file_system.cc +++ b/tensorflow/core/platform/s3/s3_file_system.cc @@ -196,6 +196,23 @@ Status ParseS3Path(const string& fname, bool empty_object_ok, string* bucket, return Status::OK(); } +static Status CheckForbiddenError( + const Aws::Client::AWSError<Aws::S3::S3Errors>& error) { + if (error.GetResponseCode() == Aws::Http::HttpResponseCode::FORBIDDEN) { + return errors::FailedPrecondition( + "AWS Credentials have not been set properly. " + "Unable to access the specified S3 location"); + } else { + return Status::OK(); + } +} + +static Status CreateStatusFromAwsError( + const Aws::Client::AWSError<Aws::S3::S3Errors>& error) { + TF_RETURN_IF_ERROR(CheckForbiddenError(error)); + return errors::Unknown(error.GetExceptionName(), ": ", error.GetMessage()); +} + class S3RandomAccessFile : public RandomAccessFile { public: S3RandomAccessFile(const string& bucket, const string& object, @@ -217,9 +234,14 @@ class S3RandomAccessFile : public RandomAccessFile { }); auto getObjectOutcome = this->s3_client_->GetObject(getObjectRequest); if (!getObjectOutcome.IsSuccess()) { - n = 0; - *result = StringPiece(scratch, n); - return Status(error::OUT_OF_RANGE, "Read less bytes than requested"); + auto error = getObjectOutcome.GetError(); + if (error.GetResponseCode() == + Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE) { + n = 0; + *result = StringPiece(scratch, n); + return Status(error::OUT_OF_RANGE, "Read less bytes than requested"); + } + return CreateStatusFromAwsError(error); } n = getObjectOutcome.GetResult().GetContentLength(); getObjectOutcome.GetResult().GetBody().read(scratch, n); @@ -305,15 +327,10 @@ class S3WritableFile : public WritableFile { if (handle->GetStatus() != Aws::Transfer::TransferStatus::COMPLETED) { auto error = handle->GetLastError(); - if (error.GetResponseCode() == Aws::Http::HttpResponseCode::FORBIDDEN) { - return errors::FailedPrecondition( - "AWS Credentials have not been set properly. " - "Unable to access the specified S3 location"); - } else { - return errors::Unknown( - error.GetExceptionName(), ": ", handle->GetFailedParts().size(), - " failed parts. ", handle->GetLastError().GetMessage()); - } + TF_RETURN_IF_ERROR(CheckForbiddenError(error)); + return errors::Unknown(error.GetExceptionName(), ": ", + handle->GetFailedParts().size(), " failed parts. 
", + handle->GetLastError().GetMessage()); } outfile_->clear(); outfile_->seekp(offset); @@ -514,8 +531,7 @@ Status S3FileSystem::GetChildren(const string& dir, auto listObjectsOutcome = this->GetS3Client()->ListObjects(listObjectsRequest); if (!listObjectsOutcome.IsSuccess()) { - return errors::Unknown(listObjectsOutcome.GetError().GetExceptionName(), - ": ", listObjectsOutcome.GetError().GetMessage()); + return CreateStatusFromAwsError(listObjectsOutcome.GetError()); } listObjectsResult = listObjectsOutcome.GetResult(); @@ -549,8 +565,7 @@ Status S3FileSystem::Stat(const string& fname, FileStatistics* stats) { headBucketRequest.WithBucket(bucket.c_str()); auto headBucketOutcome = this->GetS3Client()->HeadBucket(headBucketRequest); if (!headBucketOutcome.IsSuccess()) { - return errors::Unknown(headBucketOutcome.GetError().GetExceptionName(), - ": ", headBucketOutcome.GetError().GetMessage()); + return CreateStatusFromAwsError(headBucketOutcome.GetError()); } stats->length = 0; stats->is_directory = 1; @@ -570,6 +585,8 @@ Status S3FileSystem::Stat(const string& fname, FileStatistics* stats) { stats->mtime_nsec = headObjectOutcome.GetResult().GetLastModified().Millis() * 1e6; found = true; + } else { + TF_RETURN_IF_ERROR(CheckForbiddenError(headObjectOutcome.GetError())); } string prefix = object; if (prefix.back() != '/') { @@ -584,11 +601,15 @@ Status S3FileSystem::Stat(const string& fname, FileStatistics* stats) { auto listObjectsOutcome = this->GetS3Client()->ListObjects(listObjectsRequest); if (listObjectsOutcome.IsSuccess()) { - if (listObjectsOutcome.GetResult().GetContents().size() > 0) { + auto listObjects = listObjectsOutcome.GetResult().GetContents(); + if (listObjects.size() > 0) { stats->length = 0; stats->is_directory = 1; + stats->mtime_nsec = listObjects[0].GetLastModified().Millis() * 1e6; found = true; } + } else { + TF_RETURN_IF_ERROR(CheckForbiddenError(listObjectsOutcome.GetError())); } if (!found) { return errors::NotFound("Object ", fname, " does not exist"); @@ -611,8 +632,7 @@ Status S3FileSystem::DeleteFile(const string& fname) { auto deleteObjectOutcome = this->GetS3Client()->DeleteObject(deleteObjectRequest); if (!deleteObjectOutcome.IsSuccess()) { - return errors::Unknown(deleteObjectOutcome.GetError().GetExceptionName(), - ": ", deleteObjectOutcome.GetError().GetMessage()); + return CreateStatusFromAwsError(deleteObjectOutcome.GetError()); } return Status::OK(); } @@ -626,6 +646,7 @@ Status S3FileSystem::CreateDir(const string& dirname) { headBucketRequest.WithBucket(bucket.c_str()); auto headBucketOutcome = this->GetS3Client()->HeadBucket(headBucketRequest); if (!headBucketOutcome.IsSuccess()) { + TF_RETURN_IF_ERROR(CheckForbiddenError(headBucketOutcome.GetError())); return errors::NotFound("The bucket ", bucket, " was not found."); } return Status::OK(); @@ -634,9 +655,11 @@ Status S3FileSystem::CreateDir(const string& dirname) { if (filename.back() != '/') { filename.push_back('/'); } - std::unique_ptr<WritableFile> file; - TF_RETURN_IF_ERROR(NewWritableFile(filename, &file)); - TF_RETURN_IF_ERROR(file->Close()); + if (!this->FileExists(filename).ok()) { + std::unique_ptr<WritableFile> file; + TF_RETURN_IF_ERROR(NewWritableFile(filename, &file)); + TF_RETURN_IF_ERROR(file->Close()); + } return Status::OK(); } @@ -660,7 +683,10 @@ Status S3FileSystem::DeleteDir(const string& dirname) { auto contents = listObjectsOutcome.GetResult().GetContents(); if (contents.size() > 1 || (contents.size() == 1 && contents[0].GetKey() != prefix.c_str())) { - return 
errors::FailedPrecondition("Cannot delete a non-empty directory."); + return errors::Unknown( + "Cannot delete a non-empty directory. " + "This operation will be retried in case this " + "is due to S3's eventual consistency."); } if (contents.size() == 1 && contents[0].GetKey() == prefix.c_str()) { string filename = dirname; @@ -669,6 +695,8 @@ Status S3FileSystem::DeleteDir(const string& dirname) { } return DeleteFile(filename); } + } else { + TF_RETURN_IF_ERROR(CheckForbiddenError(listObjectsOutcome.GetError())); } return Status::OK(); } @@ -762,8 +790,7 @@ Status S3FileSystem::SimpleCopy(const Aws::String& source, copyObjectRequest.SetCopySource(source); auto copyObjectOutcome = this->GetS3Client()->CopyObject(copyObjectRequest); if (!copyObjectOutcome.IsSuccess()) { - return errors::Unknown(copyObjectOutcome.GetError().GetExceptionName(), - ": ", copyObjectOutcome.GetError().GetMessage()); + return CreateStatusFromAwsError(copyObjectOutcome.GetError()); } return Status::OK(); } @@ -782,9 +809,7 @@ Status S3FileSystem::MultiPartCopy(const Aws::String& source, auto multipartUploadOutcome = this->GetS3Client()->CreateMultipartUpload(multipartUploadRequest); if (!multipartUploadOutcome.IsSuccess()) { - return errors::Unknown(multipartUploadOutcome.GetError().GetExceptionName(), - ": ", - multipartUploadOutcome.GetError().GetMessage()); + return CreateStatusFromAwsError(multipartUploadOutcome.GetError()); } Aws::String uploadID = multipartUploadOutcome.GetResult().GetUploadId(); @@ -916,8 +941,7 @@ Status S3FileSystem::AbortMultiPartCopy(Aws::String target_bucket, .WithUploadId(uploadID); auto abortOutcome = this->GetS3Client()->AbortMultipartUpload(abortRequest); if (!abortOutcome.IsSuccess()) { - return errors::Unknown(abortOutcome.GetError().GetExceptionName(), ": ", - abortOutcome.GetError().GetMessage()); + return CreateStatusFromAwsError(abortOutcome.GetError()); } return Status::OK(); } @@ -933,8 +957,7 @@ Status S3FileSystem::CompleteMultiPartCopy( auto completeOutcome = this->GetS3Client()->CompleteMultipartUpload(completeRequest); if (!completeOutcome.IsSuccess()) { - return errors::Unknown(completeOutcome.GetError().GetExceptionName(), ": ", - completeOutcome.GetError().GetMessage()); + return CreateStatusFromAwsError(completeOutcome.GetError()); } return Status::OK(); } @@ -969,8 +992,7 @@ Status S3FileSystem::RenameFile(const string& src, const string& target) { auto listObjectsOutcome = this->GetS3Client()->ListObjects(listObjectsRequest); if (!listObjectsOutcome.IsSuccess()) { - return errors::Unknown(listObjectsOutcome.GetError().GetExceptionName(), - ": ", listObjectsOutcome.GetError().GetMessage()); + return CreateStatusFromAwsError(listObjectsOutcome.GetError()); } listObjectsResult = listObjectsOutcome.GetResult(); @@ -989,9 +1011,7 @@ Status S3FileSystem::RenameFile(const string& src, const string& target) { auto deleteObjectOutcome = this->GetS3Client()->DeleteObject(deleteObjectRequest); if (!deleteObjectOutcome.IsSuccess()) { - return errors::Unknown( - deleteObjectOutcome.GetError().GetExceptionName(), ": ", - deleteObjectOutcome.GetError().GetMessage()); + return CreateStatusFromAwsError(deleteObjectOutcome.GetError()); } } listObjectsRequest.SetMarker(listObjectsResult.GetNextMarker()); @@ -1000,6 +1020,6 @@ Status S3FileSystem::RenameFile(const string& src, const string& target) { return Status::OK(); } -REGISTER_FILE_SYSTEM("s3", S3FileSystem); +REGISTER_FILE_SYSTEM("s3", RetryingS3FileSystem); } // namespace tensorflow diff --git 
a/tensorflow/core/platform/s3/s3_file_system.h b/tensorflow/core/platform/s3/s3_file_system.h index 8d4dfd6b6bd..7ea01b24cf9 100644 --- a/tensorflow/core/platform/s3/s3_file_system.h +++ b/tensorflow/core/platform/s3/s3_file_system.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/retrying_file_system.h" namespace tensorflow { @@ -133,6 +134,17 @@ class S3FileSystem : public FileSystem { uint64 multi_part_copy_part_size_; }; +/// S3 implementation of a file system with retry on failures. +class RetryingS3FileSystem : public RetryingFileSystem<S3FileSystem> { + public: + RetryingS3FileSystem() + : RetryingFileSystem( + std::unique_ptr<S3FileSystem>(new S3FileSystem), + RetryConfig(100000 /* init_delay_time_us */, + 32000000 /* max_delay_time_us */, 10 /* max_retries */ + )) {} +}; + } // namespace tensorflow #endif // TENSORFLOW_CONTRIB_S3_S3_FILE_SYSTEM_H_ diff --git a/tensorflow/core/platform/snappy.h b/tensorflow/core/platform/snappy.h index 5477b097ef0..df06f3dcc1e 100644 --- a/tensorflow/core/platform/snappy.h +++ b/tensorflow/core/platform/snappy.h @@ -18,6 +18,17 @@ limitations under the License. #include "tensorflow/core/platform/types.h" +#if !defined(PLATFORM_WINDOWS) +#include <sys/uio.h> +#else +namespace tensorflow { +struct iovec { + void* iov_base; + size_t iov_len; +}; +} // namespace tensorflow +#endif + namespace tensorflow { namespace port { @@ -28,6 +39,9 @@ bool Snappy_GetUncompressedLength(const char* input, size_t length, size_t* result); bool Snappy_Uncompress(const char* input, size_t length, char* output); +bool Snappy_UncompressToIOVec(const char* compressed, size_t compressed_length, + const struct iovec* iov, size_t iov_cnt); + } // namespace port } // namespace tensorflow diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc index 2303b587ce6..547af76bdf6 100644 --- a/tensorflow/core/platform/windows/port.cc +++ b/tensorflow/core/platform/windows/port.cc @@ -157,6 +157,17 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output) { #endif } +bool Snappy_UncompressToIOVec(const char* compressed, size_t compressed_length, + const struct iovec* iov, size_t iov_cnt) { +#ifdef TF_USE_SNAPPY + const snappy::iovec* snappy_iov = reinterpret_cast<const snappy::iovec*>(iov); + return snappy::RawUncompressToIOVec(compressed, compressed_length, snappy_iov, + iov_cnt); +#else + return false; +#endif +} + string Demangle(const char* mangled) { return mangled; } double NominalCPUFrequency() { diff --git a/tensorflow/core/profiler/rpc/client/capture_profile.cc b/tensorflow/core/profiler/rpc/client/capture_profile.cc index 87a71bb9ff2..e344bf6e1ba 100644 --- a/tensorflow/core/profiler/rpc/client/capture_profile.cc +++ b/tensorflow/core/profiler/rpc/client/capture_profile.cc @@ -60,6 +60,7 @@ ProfileRequest PopulateProfileRequest(int duration_ms, } request.add_tools("op_profile"); request.add_tools("input_pipeline"); + request.add_tools("kernel_stats"); request.add_tools("memory_viewer"); request.add_tools("overview_page"); request.add_tools("pod_viewer"); diff --git a/tensorflow/core/protobuf/data/experimental/snapshot.proto b/tensorflow/core/protobuf/data/experimental/snapshot.proto index 422602d3760..e013deb2ee1 100644 --- a/tensorflow/core/protobuf/data/experimental/snapshot.proto +++ b/tensorflow/core/protobuf/data/experimental/snapshot.proto @@ -3,6 +3,8 @@ syntax = "proto3"; package 
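The `RetryingS3FileSystem` above wraps `S3FileSystem` in the relocated `RetryingFileSystem` with a 100 ms initial delay, a 32 s maximum delay, and at most 10 retries; together with the `REGISTER_FILE_SYSTEM("s3", RetryingS3FileSystem)` change at the end of s3_file_system.cc, every Env-routed access to an `s3://` path now goes through that wrapper. A sketch of what that means for a caller (the bucket and object names are placeholders; which AWS errors count as retriable, and the exact backoff curve between the two delay bounds, is decided by the retrying utilities rather than shown in this diff):

```cpp
#include <string>

#include "tensorflow/core/platform/env.h"

// Reads a small object through Env's file-system registry. Because "s3" now
// resolves to RetryingS3FileSystem, a transient S3 failure inside this call
// is retried (up to 10 retries, with delays between 0.1 s and 32 s) before
// the final Status is returned to the caller.
tensorflow::Status ReadSmallS3Object(std::string* contents) {
  return tensorflow::ReadFileToString(tensorflow::Env::Default(),
                                      "s3://my-bucket/path/to/object",
                                      contents);
}
```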
tensorflow.data.experimental; import "tensorflow/core/framework/tensor.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; // Each SnapshotRecord represents one batch of pre-processed input data. A batch // consists of a list of tensors that we encode as TensorProtos. This message @@ -13,9 +15,29 @@ message SnapshotRecord { // This stores the metadata information present in each snapshot record. message SnapshotMetadataRecord { + // Stores the fingerprint of the graph that describes the dataset that is + // snapshotted. string graph_hash = 1; + // Run ID that this snapshot corresponds to. string run_id = 2; + // Time when we started creating this snapshot. int64 creation_timestamp = 3; + // Version of the snapshot data file format. + int64 version = 4; + // A list of tensor dtype corresponding to each element of the snapshot. + repeated .tensorflow.DataType dtype = 5; bool finalized = 1000; } + +// Metadata for a single tensor in the Snapshot Record. +message TensorMetadata { + .tensorflow.TensorShapeProto tensor_shape = 2; + // Number of uncompressed bytes used to store the tensor representation. + int64 tensor_size_bytes = 3; +} + +// Metadata for all the tensors in a Snapshot Record. +message SnapshotTensorMetadata { + repeated TensorMetadata tensor_metadata = 1; +} diff --git a/tensorflow/core/util/mkl_types.h b/tensorflow/core/util/mkl_types.h index cdf313d585e..8edbddac8a1 100644 --- a/tensorflow/core/util/mkl_types.h +++ b/tensorflow/core/util/mkl_types.h @@ -48,6 +48,7 @@ namespace tensorflow { #define GET_WORKSPACE_DESC_FROM_OP_PD(op_pd) op_pd->workspace_desc() #define GET_TENSOR_FORMAT(fmt) MklTensorFormatToMklDnnDataFormat(fmt) #define GET_TF_DATA_FORMAT(shape, mem_desc) shape.GetTfDataFormat() +#define GET_USR_MEM_PRIM_DESC(src) src.GetUsrMemDesc() #define GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) op_pd->weights_desc() #define GET_WEIGHTS_FORMAT_FROM_OP_PD(op_pd, op) \ GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) @@ -114,14 +115,13 @@ namespace tensorflow { #define TENSOR_FORMAT MKL_TENSOR_FORMAT #define TENSOR_FORMAT_NHWC MKL_TENSOR_FORMAT_NHWC #define TENSOR_MAX_DIMS MKLDNN_MAX_NDIMS -#define GET_USR_MEM_PRIM_DESC(src) src.GetUsrMemDesc() -#define BN_FLAGS mkldnn::normalization_flags #else #define ADD_MD add_pd #define ALGORITHM mkldnn #define ALGORITHM_UNDEF ALGORITHM::algorithm_undef +#define BN_FLAGS mkldnn #define CPU_STREAM(engine) stream(stream::kind::eager_nostore) #define DATA_WITH_ENGINE(data, engine) data #define DST_MD dst_pd @@ -148,6 +148,7 @@ namespace tensorflow { op_pd.get()->workspace_primitive_desc() #define GET_TENSOR_FORMAT(fmt) fmt #define GET_TF_DATA_FORMAT(shape, mem_desc) mem_desc.data.format +#define GET_USR_MEM_PRIM_DESC(src) src.GetUsrMemPrimDesc() #define GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) op_pd.get()->weights_primitive_desc() #define GET_WEIGHTS_FORMAT_FROM_OP_PD(op_pd, op) op->GetFilterMemoryFormat() #define IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, op_pd, op) \ @@ -215,8 +216,6 @@ namespace tensorflow { #define SUMMAND_MD summand_pd #define TENSOR_FORMAT TensorFormat #define TENSOR_FORMAT_NHWC FORMAT_NHWC -#define GET_USR_MEM_PRIM_DESC(src) src.GetUsrMemPrimDesc() -#define BN_FLAGS mkldnn #endif // ENABLE_MKLDNN_V1 } // namespace tensorflow diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index 3ebb2ea5dc8..339c28dfb66 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ 
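Returning to the snapshot.proto additions above: a record's tensors are now described by `SnapshotTensorMetadata`, one `TensorMetadata` (shape plus uncompressed byte count) per tensor, and `SnapshotMetadataRecord` gains a format `version` and a per-element `dtype` list. A small sketch of filling these in with the standard protoc-generated C++ accessors (the generated header path is assumed from the proto's location; the shape and sizes are made up):

```cpp
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/protobuf/data/experimental/snapshot.pb.h"  // assumed generated path

namespace snapshot = tensorflow::data::experimental;

// Describe a single float tensor of shape [2, 3] (24 uncompressed bytes) and
// a snapshot header for format version 1 with one float element per record.
void FillExampleMetadata(snapshot::SnapshotTensorMetadata* tensors,
                         snapshot::SnapshotMetadataRecord* header) {
  auto* tensor = tensors->add_tensor_metadata();
  tensor->mutable_tensor_shape()->add_dim()->set_size(2);
  tensor->mutable_tensor_shape()->add_dim()->set_size(3);
  tensor->set_tensor_size_bytes(2 * 3 * sizeof(float));

  header->set_version(1);
  header->add_dtype(tensorflow::DT_FLOAT);
}
```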
b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/lib/io/table_builder.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/util/env_var.h" #include "tensorflow/core/util/saved_tensor_slice_util.h" #include "tensorflow/core/util/tensor_bundle/byte_swap.h" #include "tensorflow/core/util/tensor_slice_util.h" @@ -729,6 +730,7 @@ BundleReader::BundleReader(Env* env, StringPiece prefix) prefix_(prefix), metadata_(nullptr), table_(nullptr), + index_cache_(nullptr), iter_(nullptr), need_to_swap_bytes_(false) { const string filename = MetaFilename(prefix_); @@ -741,7 +743,17 @@ BundleReader::BundleReader(Env* env, StringPiece prefix) status_ = env_->NewRandomAccessFile(filename, &wrapper); if (!status_.ok()) return; metadata_ = wrapper.release(); - status_ = table::Table::Open(table::Options(), metadata_, file_size, &table_); + + table::Options o; + int64 cache_size; + Status s = + ReadInt64FromEnvVar("TF_TABLE_INDEX_CACHE_SIZE_IN_MB", 0, &cache_size); + if (s.ok() && cache_size > 0) { + index_cache_ = table::NewLRUCache(cache_size << 20); + o.block_cache = index_cache_; + } + + status_ = table::Table::Open(o, metadata_, file_size, &table_); if (!status_.ok()) return; iter_ = table_->NewIterator(); @@ -772,6 +784,9 @@ BundleReader::~BundleReader() { delete metadata_; delete iter_; delete table_; + if (index_cache_) { + delete index_cache_; + } // InputBuffer does not own the underlying RandomAccessFile. for (auto pair : data_) { if (pair.second != nullptr && pair.second->file() != nullptr) { diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.h b/tensorflow/core/util/tensor_bundle/tensor_bundle.h index e1f39eccd17..c362dd41151 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.h +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.h @@ -61,8 +61,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_ #define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_ -#include "tensorflow/core/protobuf/tensor_bundle.pb.h" - #include <map> #include <string> #include <unordered_map> @@ -72,12 +70,14 @@ limitations under the License. #include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/io/cache.h" #include "tensorflow/core/lib/io/inputbuffer.h" #include "tensorflow/core/lib/io/table.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/tensor_bundle.pb.h" #include "tensorflow/core/util/tensor_bundle/naming.h" #include "tensorflow/core/util/tensor_slice_set.h" @@ -288,6 +288,7 @@ class BundleReader { Status status_; RandomAccessFile* metadata_; // Owned. table::Table* table_; + table::Cache* index_cache_; table::Iterator* iter_; // Owned the InputBuffer objects and their underlying RandomAccessFile's. std::unordered_map<int32, io::InputBuffer*> data_; diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 449a95765a5..ecdce1e627b 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -11611,7 +11611,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d // element on that dimension. 
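On the `BundleReader` change above: the index-block cache is off by default and is enabled through the `TF_TABLE_INDEX_CACHE_SIZE_IN_MB` environment variable, which the constructor reads and shifts by 20 to get a byte budget for `table::NewLRUCache`. A sketch of opting in on a POSIX system (the 64 MB budget and the checkpoint prefix are placeholders):

```cpp
#include <cstdlib>

#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

// Enable a 64 MiB LRU cache for the bundle's metadata-table index blocks,
// then open the reader as usual. The variable is read when the reader is
// constructed, so it must be set beforehand.
void OpenBundleWithIndexCache() {
  setenv("TF_TABLE_INDEX_CACHE_SIZE_IN_MB", "64", /*overwrite=*/1);
  tensorflow::BundleReader reader(tensorflow::Env::Default(),
                                  "/tmp/my_model/ckpt");  // checkpoint prefix
  if (!reader.status().ok()) {
    // Handle the error, e.g. log reader.status().ToString().
    return;
  }
  // ... reader.Lookup(...) / iteration as before.
}
```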
The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -11868,7 +11868,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2 // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -11879,7 +11879,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { return func(m optionalAttr) { m["area_range"] = value @@ -12085,7 +12085,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo // // value: The cropped area of the image must have an aspect ratio = // width / height within this range. -// If not specified, defaults to {f:0.75 f:1.33} +// If not specified, defaults to {f:0.75 f:1.33} func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["aspect_ratio_range"] = value @@ -12096,7 +12096,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted // // value: The cropped area of the image must contain a fraction of the // supplied image within this range. -// If not specified, defaults to {f:0.05 f:1} +// If not specified, defaults to {f:0.05 f:1} func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { m["area_range"] = value @@ -18937,7 +18937,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr { // ImageSummaryBadColor sets the optional bad_color attribute to value. // // value: Color to use for pixels with non-finite values. -// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} +// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255} func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { m["bad_color"] = value @@ -20077,7 +20077,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -21345,7 +21345,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { // element on that dimension. 
The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22053,7 +22053,7 @@ func Conv2DDataFormat(value string) Conv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22249,7 +22249,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22318,7 +22318,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22433,7 +22433,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22492,7 +22492,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. // // value: List of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22666,7 +22666,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. // // value: list of dilation values. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { return func(m optionalAttr) { m["dilations"] = value @@ -22857,7 +22857,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { // filter element on that dimension. 
The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { m["dilations"] = value @@ -25297,7 +25297,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi type Conv3DBackpropFilterAttr func(optionalAttr) // Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25629,7 +25629,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25679,7 +25679,7 @@ func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, fil type Conv3DBackpropInputAttr func(optionalAttr) // Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value @@ -25929,7 +25929,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { // element on that dimension. The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { m["dilations"] = value @@ -26559,7 +26559,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -27624,7 +27624,7 @@ func Conv3DDataFormat(value string) Conv3DAttr { // filter element on that dimension. The dimension order is determined by the // value of `data_format`, see above for details. Dilations in the batch and // depth dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1} func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value @@ -45536,7 +45536,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // element on that dimension. 
The dimension order is determined by the value of // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. -// If not specified, defaults to {i:1 i:1 i:1 i:1} +// If not specified, defaults to {i:1 i:1 i:1 i:1} func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { m["dilations"] = value diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD index d09e75e052d..2aeff13b3be 100644 --- a/tensorflow/lite/delegates/gpu/cl/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/BUILD @@ -301,6 +301,7 @@ cc_library( ":model_hints", ":opencl_wrapper", ":precision", + ":storage_type_util", ":tensor_type", "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/cl/selectors:operation_selector", @@ -387,6 +388,19 @@ cc_library( ], ) +cc_library( + name = "storage_type_util", + srcs = ["storage_type_util.cc"], + hdrs = ["storage_type_util.h"], + deps = [ + ":cl_context", + ":cl_device", + ":tensor_type", + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:shape", + ], +) + cc_library( name = "tensor", srcs = ["tensor.cc"], diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.cc b/tensorflow/lite/delegates/gpu/cl/inference_context.cc index a2a66cae0c9..6b0511fb267 100644 --- a/tensorflow/lite/delegates/gpu/cl/inference_context.cc +++ b/tensorflow/lite/delegates/gpu/cl/inference_context.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/model_hints.h" #include "tensorflow/lite/delegates/gpu/cl/precision.h" #include "tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h" +#include "tensorflow/lite/delegates/gpu/cl/storage_type_util.h" #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/memory_management.h" @@ -109,64 +110,6 @@ void AddUsage(ValueId id, int task_index, } } -TensorStorageType SelectBestStorageType(const CLContext& context, - const CLDevice& device, - const BHWC& shape, - const TensorStorageType& desired, - const DataType& data_type, - const Layout& layout) { - if (CanCreateTensorWithShape(context, device, shape, - TensorDescriptor{data_type, desired, layout})) { - return desired; - } - auto GetBestTypeAfterTextureArray = [&]() { - if (device.SupportsImageBuffer() && - CanCreateTensorWithShape( - context, device, shape, - TensorDescriptor{data_type, TensorStorageType::IMAGE_BUFFER, - layout})) { - return TensorStorageType::IMAGE_BUFFER; - } else { - return TensorStorageType::BUFFER; - } - }; - auto GetBestTypeAfterTexture2D = [&]() { - if (device.SupportsTextureArray() && - CanCreateTensorWithShape( - context, device, shape, - TensorDescriptor{data_type, TensorStorageType::TEXTURE_ARRAY, - layout})) { - return TensorStorageType::TEXTURE_ARRAY; - } else { - return GetBestTypeAfterTextureArray(); - } - }; - auto GetBestTypeAfterTexture3D = [&]() { - if (CanCreateTensorWithShape( - context, device, shape, - TensorDescriptor{data_type, TensorStorageType::TEXTURE_2D, - layout})) { - return TensorStorageType::TEXTURE_2D; - } else { - return GetBestTypeAfterTexture2D(); - } - }; - switch (desired) { - case TensorStorageType::TEXTURE_2D: - case TensorStorageType::SINGLE_TEXTURE_2D: - return GetBestTypeAfterTexture2D(); - case TensorStorageType::TEXTURE_ARRAY: - return GetBestTypeAfterTextureArray(); - case 
TensorStorageType::TEXTURE_3D: - return GetBestTypeAfterTexture3D(); - case TensorStorageType::IMAGE_BUFFER: - case TensorStorageType::BUFFER: - return TensorStorageType::BUFFER; - default: - return TensorStorageType::BUFFER; - } -} - // returns true if actual memory for this storage type will be allocated with // clCreateBuffer. bool IsBufferBased(const TensorStorageType& type) { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc index aa071910658..78e9795bc63 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc @@ -69,6 +69,63 @@ std::string GenerateAsyncUpload(const std::string& local_ptr_name, offset + ", " + std::to_string(elements_to_upload) + ", 0);\n"; return c; } + +std::string GenerateBlockCoords(const int3& block_size, + const int3& work_group_launch_order, + bool linear_hw) { + std::string c; + int3 launch_remap; + launch_remap[work_group_launch_order.x] = 0; + launch_remap[work_group_launch_order.y] = 1; + launch_remap[work_group_launch_order.z] = 2; + if (linear_hw) { + if (work_group_launch_order[0] == 0) { + c += " int linear_hw = get_global_id(0);\n"; + } else { + c += " int linear_hw = get_group_id(" + std::to_string(launch_remap[0]) + + ") * get_local_size(0) + get_local_id(0);\n"; + } + c += " int Y = (linear_hw / task_size_x) * " + + std::to_string(block_size.y) + ";\n"; + c += " int X = (linear_hw % task_size_x) * " + + std::to_string(block_size.x) + ";\n"; + if (work_group_launch_order[1] == 1) { + c += " int Z = get_global_id(1) * " + std::to_string(block_size.z) + + ";\n"; + } else { + c += " int Z = (get_group_id(" + std::to_string(launch_remap[1]) + + ") * get_local_size(1) + get_local_id(1)) * " + + std::to_string(block_size.z) + ";\n"; + } + } else { + if (work_group_launch_order[0] == 0) { + c += " int X = get_global_id(0) * " + std::to_string(block_size.x) + + ";\n"; + } else { + c += " int X = (get_group_id(" + std::to_string(launch_remap[0]) + + ") * get_local_size(0) + get_local_id(0)) * " + + std::to_string(block_size.x) + ";\n"; + } + if (work_group_launch_order[1] == 1) { + c += " int Y = get_global_id(1) * " + std::to_string(block_size.y) + + ";\n"; + } else { + c += " int Y = (get_group_id(" + std::to_string(launch_remap[1]) + + ") * get_local_size(1) + get_local_id(1)) * " + + std::to_string(block_size.y) + ";\n"; + } + if (work_group_launch_order[2] == 2) { + c += " int Z = get_global_id(2) * " + std::to_string(block_size.z) + + ";\n"; + } else { + c += " int Z = (get_group_id(" + std::to_string(launch_remap[2]) + + ") * get_local_size(2) + get_local_id(2)) * " + + std::to_string(block_size.z) + ";\n"; + } + } + + return c; +} } // namespace ConvPowerVR::ConvPowerVR(const OperationDef& definition, @@ -146,6 +203,11 @@ Status ConvPowerVR::BindArguments() { int4(kernel_dilation_.x, kernel_dilation_.y, kernel_dilation_.z * src_[0]->Batch(), kernel_dilation_.w))); } + if (conv_params_.linear_hw) { + const int grid_x = IntegralDivideRoundUp( + dst_[0]->Width() * dst_[0]->Batch(), conv_params_.block_size.x); + RETURN_IF_ERROR(kernel_.SetBytesAuto(grid_x)); + } RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB())); RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB())); return OkStatus(); @@ -159,15 +221,27 @@ int3 ConvPowerVR::GetGridSize() const { const int grid_z = IntegralDivideRoundUp(dst_[0]->Slices(), conv_params_.block_size.z); int3 wg; - wg.x = 
IntegralDivideRoundUp(grid_x, conv_params_.work_group_size.x); - wg.y = IntegralDivideRoundUp(grid_y, conv_params_.work_group_size.y); - wg.z = IntegralDivideRoundUp(grid_z, conv_params_.work_group_size.z); - return int3(wg[conv_params_.work_group_launch_order[0]] * - conv_params_.work_group_size.x, - wg[conv_params_.work_group_launch_order[1]] * - conv_params_.work_group_size.y, - wg[conv_params_.work_group_launch_order[2]] * - conv_params_.work_group_size.z); + + if (conv_params_.linear_hw) { + wg.x = + IntegralDivideRoundUp(grid_x * grid_y, conv_params_.work_group_size.x); + wg.y = IntegralDivideRoundUp(grid_z, conv_params_.work_group_size.y); + return int3(wg[conv_params_.work_group_launch_order[0]] * + conv_params_.work_group_size.x, + wg[conv_params_.work_group_launch_order[1]] * + conv_params_.work_group_size.y, + 1); + } else { + wg.x = IntegralDivideRoundUp(grid_x, conv_params_.work_group_size.x); + wg.y = IntegralDivideRoundUp(grid_y, conv_params_.work_group_size.y); + wg.z = IntegralDivideRoundUp(grid_z, conv_params_.work_group_size.z); + return int3(wg[conv_params_.work_group_launch_order[0]] * + conv_params_.work_group_size.x, + wg[conv_params_.work_group_launch_order[1]] * + conv_params_.work_group_size.y, + wg[conv_params_.work_group_launch_order[2]] * + conv_params_.work_group_size.z); + } } Status ConvPowerVR::Tune(const TuningParameters& params) { @@ -248,33 +322,22 @@ std::string GenerateConvPowerVR1x1( c += " int4 stride_padding, \n"; c += " int4 kernel_dilation, \n"; } + if (conv_params.linear_hw) { + c += " int task_size_x, \n"; + } c += " int4 src_size, \n"; c += " int4 dst_size \n"; c += ") {\n"; - int3 launch_remap; - launch_remap[conv_params.work_group_launch_order.x] = 0; - launch_remap[conv_params.work_group_launch_order.y] = 1; - launch_remap[conv_params.work_group_launch_order.z] = 2; - if (conv_params.work_group_launch_order[0] == 0) { - c += " int X = get_global_id(0) * " + std::to_string(block_size.x) + ";\n"; - } else { - c += " int X = (get_group_id(" + std::to_string(launch_remap[0]) + - ") * get_local_size(0) + get_local_id(0)) * " + - std::to_string(block_size.x) + ";\n"; + c += GenerateBlockCoords(conv_params.block_size, + conv_params.work_group_launch_order, + conv_params.linear_hw); + std::vector<std::string> dst_x(conv_params.block_size.x); + for (int x = 0; x < conv_params.block_size.x; ++x) { + dst_x[x] = "(X + " + std::to_string(x) + ")"; } - if (conv_params.work_group_launch_order[1] == 1) { - c += " int Y = get_global_id(1) * " + std::to_string(block_size.y) + ";\n"; - } else { - c += " int Y = (get_group_id(" + std::to_string(launch_remap[1]) + - ") * get_local_size(1) + get_local_id(1)) * " + - std::to_string(block_size.y) + ";\n"; - } - if (conv_params.work_group_launch_order[2] == 2) { - c += " int Z = get_global_id(2) * " + std::to_string(block_size.z) + ";\n"; - } else { - c += " int Z = (get_group_id(" + std::to_string(launch_remap[2]) + - ") * get_local_size(2) + get_local_id(2)) * " + - std::to_string(block_size.z) + ";\n"; + std::vector<std::string> dst_y(conv_params.block_size.y); + for (int y = 0; y < conv_params.block_size.y; ++y) { + dst_y[y] = "(Y + " + std::to_string(y) + ")"; } if (!need_local_mem) { c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) {\n"; @@ -283,8 +346,12 @@ std::string GenerateConvPowerVR1x1( } if (conv_params.weights_upload_type == ConvPowerVR::WeightsUploadType::LOCAL_MEM_BY_THREADS) { - c += " int lid = get_local_id(1) * " + std::to_string(work_group_size.x) + - " + get_local_id(0);\n"; + 
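The `linear_hw` path above flattens the X/Y block grid into a single dimension: `BindArguments()` passes the number of X blocks as `task_size_x`, `GetGridSize()` folds `grid_x * grid_y` into `wg.x`, and each work item recovers its tile coordinates with one divide and one modulo. A host-side C++ sketch of the same index arithmetic (tensor and block sizes are made up):

```cpp
#include <cstdio>

// Reproduces the generated kernel's mapping from a single linearized id back
// to per-block X/Y coordinates.
int main() {
  const int block_size_x = 2, block_size_y = 1;
  const int dst_width = 100, dst_height = 60;
  // task_size_x mirrors what BindArguments() uploads: the number of X blocks.
  const int task_size_x = (dst_width + block_size_x - 1) / block_size_x;   // 50
  const int task_size_y = (dst_height + block_size_y - 1) / block_size_y;  // 60
  for (int linear_hw = 0; linear_hw < task_size_x * task_size_y; ++linear_hw) {
    const int Y = (linear_hw / task_size_x) * block_size_y;
    const int X = (linear_hw % task_size_x) * block_size_x;
    if (linear_hw < 3) {
      std::printf("linear_hw=%d -> X=%d Y=%d\n", linear_hw, X, Y);
    }
  }
  return 0;
}
```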
if (conv_params.linear_hw) { + c += " int lid = get_local_id(0);\n"; + } else { + c += " int lid = get_local_id(1) * " + + std::to_string(work_group_size.x) + " + get_local_id(0);\n"; + } } for (int z = 0; z < block_size.z; ++z) { for (int y = 0; y < block_size.y; ++y) { @@ -296,20 +363,18 @@ std::string GenerateConvPowerVR1x1( } if (!is1x1) { for (int x = 0; x < block_size.x; ++x) { - const std::string xc = "(X + " + std::to_string(x) + ")"; if (stride_correction) { c += " int xc" + std::to_string(x) + " = " + - GetXStrideCorrected(xc, "src_size.w", "stride_padding.x", + GetXStrideCorrected(dst_x[x], "src_size.w", "stride_padding.x", "stride_padding.z") + ";\n"; } else { - c += " int xc" + std::to_string(x) + " = " + xc + + c += " int xc" + std::to_string(x) + " = " + dst_x[x] + " * stride_padding.x + stride_padding.z;\n"; } } for (int y = 0; y < block_size.y; ++y) { - const std::string yc = "(Y + " + std::to_string(y) + ")"; - c += " int yc" + std::to_string(y) + " = " + yc + + c += " int yc" + std::to_string(y) + " = " + dst_y[y] + " * stride_padding.y + stride_padding.w;\n"; } } @@ -373,10 +438,8 @@ std::string GenerateConvPowerVR1x1( const std::string yck = "yck" + std::to_string(y); for (int x = 0; x < block_size.x; ++x) { const std::string xck = "xck" + std::to_string(x); - std::string xc = - is1x1 ? "min(X + " + std::to_string(x) + ", src_size.x - 1)" : xck; - std::string yc = - is1x1 ? "min(Y + " + std::to_string(y) + ", src_size.y - 1)" : yck; + std::string xc = is1x1 ? "min(" + dst_x[x] + ", src_size.x - 1)" : xck; + std::string yc = is1x1 ? "min(" + dst_y[y] + ", src_size.y - 1)" : yck; std::string id = std::to_string(y) + std::to_string(x); c += " int src_a_" + id + " = " + yc + " * src_size.x + " + xc + ";\n"; } @@ -408,10 +471,8 @@ std::string GenerateConvPowerVR1x1( c += " src_a_" + id + " += src_layer_offset;\n"; } else { std::string id = std::to_string(y) + std::to_string(x); - const std::string xc = - is1x1 ? "X + " + std::to_string(x) : "xck" + std::to_string(x); - const std::string yc = - is1x1 ? "Y + " + std::to_string(y) : "yck" + std::to_string(y); + const std::string xc = is1x1 ? dst_x[x] : "xck" + std::to_string(x); + const std::string yc = is1x1 ? 
dst_y[y] : "yck" + std::to_string(y); c += " src" + id + " = " + src_tensor.ReadAsTypeWHS(conv_params.weights_data_type, xc, yc, "s", mode) + @@ -522,8 +583,8 @@ std::string GenerateConvPowerVR1x1( c += " FLT4 bias_val = TO_FLT4(weights_cache[" + sz + "]);\n"; for (int y = 0; y < block_size.y; ++y) { for (int x = 0; x < block_size.x; ++x) { - const std::string xs = "X + " + std::to_string(x); - const std::string ys = "Y + " + std::to_string(y); + const std::string xs = dst_x[x]; + const std::string ys = dst_y[y]; const std::string zs = "Z + " + sz; const std::string r_id = sz + std::to_string(y) + std::to_string(x); bool need_x_check = x != 0; @@ -552,18 +613,27 @@ std::string GenerateConvPowerVR1x1( ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( const CLDevice& device, const OperationDef& definition, int src_depth, - int dst_depth, bool x_kernel_is_1, bool y_kernel_is_1) const { + int dst_depth, bool x_kernel_is_1, bool y_kernel_is_1, + bool different_weights_for_height) const { ConvParams conv_params; + conv_params.linear_hw = false; conv_params.weights_data_type = DeduceDataTypeFromPrecision(definition.precision); conv_params.x_kernel_is_1 = x_kernel_is_1; conv_params.y_kernel_is_1 = y_kernel_is_1; - conv_params.different_weights_for_height = false; + conv_params.different_weights_for_height = different_weights_for_height; if (device.IsNvidia()) { + if (different_weights_for_height) { + conv_params.work_group_size = int3(32, 1, 1); + conv_params.work_group_launch_order = int3(2, 0, 1); + conv_params.fixed_work_group_size = true; + } else { + conv_params.linear_hw = true; + conv_params.work_group_size = int3(32, 1, 1); + conv_params.work_group_launch_order = int3(1, 0, 2); + conv_params.fixed_work_group_size = true; + } conv_params.block_size = int3(1, 1, 4); - conv_params.work_group_size = int3(8, 4, 1); - conv_params.work_group_launch_order = int3(2, 0, 1); - conv_params.fixed_work_group_size = true; conv_params.src_depth_loop_size = 1; conv_params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS; if (dst_depth % 4 == 0 || dst_depth >= 8) { @@ -580,13 +650,20 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( conv_params.src_depth_loop_size = 4; } } else if (device.IsPowerVR()) { + if (different_weights_for_height) { + conv_params.work_group_size = int3(32, 1, 1); + conv_params.work_group_launch_order = int3(2, 0, 1); + conv_params.fixed_work_group_size = true; + } else { + conv_params.linear_hw = true; + conv_params.work_group_size = int3(32, 1, 1); + conv_params.work_group_launch_order = int3(1, 0, 2); + conv_params.fixed_work_group_size = true; + } conv_params.weights_data_type = definition.precision == CalculationsPrecision::F16 ? 
DataType::FLOAT16 : DataType::FLOAT32; conv_params.block_size = int3(1, 1, 4); - conv_params.work_group_size = int3(8, 4, 1); - conv_params.work_group_launch_order = int3(2, 0, 1); - conv_params.fixed_work_group_size = true; conv_params.src_depth_loop_size = 1; conv_params.weights_upload_type = WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP; @@ -619,16 +696,22 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( } } conv_params.block_size.x = 2; - conv_params.work_group_size = int3(4, 8, 1); } } else if (device.IsAMD()) { + if (different_weights_for_height) { + conv_params.work_group_size = int3(32, 1, 1); + conv_params.work_group_launch_order = int3(2, 0, 1); + conv_params.fixed_work_group_size = true; + } else { + conv_params.work_group_size = int3(8, 4, 1); + conv_params.work_group_launch_order = int3(2, 0, 1); + conv_params.fixed_work_group_size = true; + } + conv_params.block_size = int3(2, 1, 1); if (x_kernel_is_1 && y_kernel_is_1) { conv_params.block_size.y = 2; } - conv_params.work_group_size = int3(8, 4, 1); - conv_params.work_group_launch_order = int3(2, 0, 1); - conv_params.fixed_work_group_size = true; conv_params.src_depth_loop_size = 1; conv_params.weights_upload_type = WeightsUploadType::CONSTANT_MEM; if (dst_depth % 8 == 0 || dst_depth >= 32) { @@ -694,7 +777,7 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( attr.padding.prepended.h == 0 && attr.padding.appended.h == 0; return GuessBestParams(device, definition, src_depth, dst_depth, - x_kernel_is_1, y_kernel_is_1); + x_kernel_is_1, y_kernel_is_1, false); } ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( @@ -702,8 +785,8 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( const FullyConnectedAttributes& attr) const { const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); - ConvPowerVR::ConvParams params = - GuessBestParams(device, definition, src_depth, dst_depth, true, true); + ConvPowerVR::ConvParams params = GuessBestParams( + device, definition, src_depth, dst_depth, true, true, false); params.work_group_size.x *= params.work_group_size.y; params.work_group_size.y = 1; params.block_size.x *= params.block_size.y; @@ -716,13 +799,10 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParamsWinograd( const Convolution2DAttributes& attr) const { const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); - ConvPowerVR::ConvParams params = - GuessBestParams(device, definition, src_depth, dst_depth, true, true); - params.work_group_size.x *= params.work_group_size.y; - params.work_group_size.y = 1; + ConvPowerVR::ConvParams params = GuessBestParams( + device, definition, src_depth, dst_depth, true, true, true); params.block_size.x *= params.block_size.y; params.block_size.y = 1; - params.different_weights_for_height = true; return params; } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h index 0832fcf91f0..110b983940a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h @@ -70,6 +70,7 @@ class ConvPowerVR : public GPUOperation { int3 work_group_size; int3 work_group_launch_order; bool fixed_work_group_size; + bool linear_hw; bool different_weights_for_height; int src_depth_loop_size; WeightsUploadType weights_upload_type; @@ -127,7 +128,8 @@ class ConvPowerVR : public GPUOperation { 
ConvParams GuessBestParams(const CLDevice& device, const OperationDef& definition, int src_depth, int dst_depth, bool x_kernel_is_1, - bool y_kernel_is_1) const; + bool y_kernel_is_1, + bool different_weights_for_height) const; Status BindArguments(); int3 GetGridSize() const; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc b/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc index a09236f77fc..7a2e54840b9 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.cc @@ -248,13 +248,8 @@ Status GetBestWorkGroup(const TuningParameters& params, const CLKernel& kernel, const int3& grid, int3* best_work_group) { switch (params.tuning_type) { case TuningType::FAST: - if (params.info->vendor != Vendor::QUALCOMM) { - *best_work_group = int3(8, 4, 1); - return OkStatus(); - } else { - *best_work_group = GetWorkGroup(grid, kernel.GetMaxWorkGroupSize()); - return OkStatus(); - } + *best_work_group = GetWorkGroup(grid, kernel.GetMaxWorkGroupSize()); + return OkStatus(); case TuningType::EXHAUSTIVE: return GetBestWorkGroupAlignedToGrid(params, kernel, grid, best_work_group); @@ -268,16 +263,16 @@ Status GetBestWorkGroupConv(const TuningParameters& params, const CLKernel& kernel, const int3& grid, int3* best_work_group) { switch (params.tuning_type) { - case TuningType::FAST: - if (params.info->vendor != Vendor::QUALCOMM) { - *best_work_group = int3(8, 4, 1); - return OkStatus(); - } else { - int max_z_size = params.info->adreno_info.gpu_version < 400 ? 16 : 64; - *best_work_group = - GetWorkGroupConv(grid, kernel.GetMaxWorkGroupSize(), max_z_size); - return OkStatus(); + case TuningType::FAST: { + int max_z_size = 16; + if (params.info->vendor == Vendor::QUALCOMM) { + max_z_size = params.info->adreno_info.gpu_version < 400 ? 
16 : 64; } + max_z_size = std::min(max_z_size, params.info->max_work_group_sizes.z); + *best_work_group = + GetWorkGroupConv(grid, kernel.GetMaxWorkGroupSize(), max_z_size); + return OkStatus(); + } case TuningType::EXHAUSTIVE: return GetBestWorkGroupAlignedToGrid(params, kernel, grid, best_work_group); diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD b/tensorflow/lite/delegates/gpu/cl/selectors/BUILD index 695077ba8a8..908c1b91583 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/selectors/BUILD @@ -83,7 +83,9 @@ cc_library( ":dw_convolution_selector", ":fully_connected_selector", ":simple_selectors", + "//tensorflow/lite/delegates/gpu/cl:cl_device", "//tensorflow/lite/delegates/gpu/cl:model_hints", + "//tensorflow/lite/delegates/gpu/cl:storage_type_util", "//tensorflow/lite/delegates/gpu/cl:tensor_type", "//tensorflow/lite/delegates/gpu/cl/kernels:elementwise", "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", @@ -119,8 +121,10 @@ cc_library( "//tensorflow/lite/delegates/gpu/cl/kernels:resize", "//tensorflow/lite/delegates/gpu/cl/kernels:softmax", "//tensorflow/lite/delegates/gpu/cl/kernels:softmax1x1", + "//tensorflow/lite/delegates/gpu/cl/kernels:space_to_depth", "//tensorflow/lite/delegates/gpu/cl/kernels:strided_slice", "//tensorflow/lite/delegates/gpu/cl/kernels:transpose", + "//tensorflow/lite/delegates/gpu/cl/kernels:winograd", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc index 9e7876cee80..0103ca08b90 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc @@ -47,6 +47,20 @@ Status SelectConvolutionAdreno(const Convolution2DAttributes& attr, return OkStatus(); } +Status SelectConvolutionWinogradAdreno(const Convolution2DAttributes& attr, + const BHWC& dst_shape, + const CreationContext& creation_context, + const OperationDef& op_def, + ModelHints hints, + std::unique_ptr<GPUOperation>* ptr) { + ConvTexture conv; + RETURN_IF_ERROR( + CreateConvTextureWino4x4To6x6(creation_context, op_def, attr, &conv)); + *ptr = absl::make_unique<ConvTexture>(std::move(conv)); + + return OkStatus(); +} + Status SelectConvolutionNVidia(const Convolution2DAttributes& attr, const CreationContext& creation_context, const OperationDef& op_def, @@ -89,6 +103,25 @@ Status SelectConvolutionMali(const Convolution2DAttributes& attr, } return OkStatus(); } + +Status SelectConvolutionWinogradMali(const Convolution2DAttributes& attr, + const CreationContext& creation_context, + const OperationDef& op_def, + std::unique_ptr<GPUOperation>* ptr) { + if (op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER) { + ConvBuffer1x1 conv; + RETURN_IF_ERROR( + CreateConvBuffer1x1Wino4x4To6x6(creation_context, op_def, attr, &conv)); + *ptr = absl::make_unique<ConvBuffer1x1>(std::move(conv)); + } else { + ConvPowerVR conv; + RETURN_IF_ERROR( + CreateConvPowerVRWino4x4To6x6(creation_context, op_def, attr, &conv)); + *ptr = absl::make_unique<ConvPowerVR>(std::move(conv)); + } + + return OkStatus(); +} } // namespace Status SelectConvolution(const Convolution2DAttributes& attr, @@ -113,6 +146,33 @@ Status SelectConvolution(const Convolution2DAttributes& attr, } } +Status 
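For the FAST tuning path above, the fixed `int3(8, 4, 1)` fallback for non-Qualcomm devices is gone; every vendor now goes through `GetWorkGroupConv`, with a z cap of 16 (64 on Adreno with `gpu_version >= 400`) that is additionally clamped to the device's reported maximum z work-group size. A tiny sketch of that cap, with hypothetical device values:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  const bool is_qualcomm = true;
  const int adreno_gpu_version = 630;       // e.g. an Adreno 630
  const int device_max_work_group_z = 256;  // params.info->max_work_group_sizes.z
  int max_z_size = 16;
  if (is_qualcomm) {
    max_z_size = adreno_gpu_version < 400 ? 16 : 64;
  }
  max_z_size = std::min(max_z_size, device_max_work_group_z);
  std::printf("max_z_size = %d\n", max_z_size);  // 64 for this device
  return 0;
}
```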
SelectConvolutionForWinograd(const Convolution2DAttributes& attr, + const BHWC& dst_shape, + const CreationContext& creation_context, + const OperationDef& op_def, + ModelHints hints, + std::unique_ptr<GPUOperation>* ptr) { + switch (creation_context.device->vendor()) { + case Vendor::QUALCOMM: + return SelectConvolutionWinogradAdreno(attr, dst_shape, creation_context, + op_def, hints, ptr); + case Vendor::POWERVR: + case Vendor::AMD: + case Vendor::NVIDIA: { + ConvPowerVR conv; + RETURN_IF_ERROR( + CreateConvPowerVRWino4x4To6x6(creation_context, op_def, attr, &conv)); + *ptr = absl::make_unique<ConvPowerVR>(std::move(conv)); + return OkStatus(); + } + case Vendor::MALI: + return SelectConvolutionWinogradMali(attr, creation_context, op_def, ptr); + default: + return SelectConvolutionWinogradAdreno(attr, dst_shape, creation_context, + op_def, hints, ptr); + } +} + } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h index 7dd6c79eea0..dc0657ec47c 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h @@ -34,6 +34,13 @@ Status SelectConvolution(const Convolution2DAttributes& attr, const OperationDef& op_def, ModelHints hints, std::unique_ptr<GPUOperation>* ptr); +Status SelectConvolutionForWinograd(const Convolution2DAttributes& attr, + const BHWC& dst_shape, + const CreationContext& creation_context, + const OperationDef& op_def, + ModelHints hints, + std::unique_ptr<GPUOperation>* ptr); + } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc index 527c94b1fe2..29c246a2744 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc @@ -17,12 +17,14 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/cl/cl_device.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h" #include "tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h" #include "tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h" #include "tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.h" #include "tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.h" #include "tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h" +#include "tensorflow/lite/delegates/gpu/cl/storage_type_util.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" @@ -51,6 +53,111 @@ bool IsChannelsBroadcastedForSecondInput( inputs[0]->tensor.shape.c != inputs[1]->tensor.shape.c && inputs[1]->tensor.shape.c == 1; } + +bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr, + const CLDevice& device, + const BHWC& dst_shape) { + const int tiles_x = IntegralDivideRoundUp(dst_shape.w, 4); + const int tiles_y = IntegralDivideRoundUp(dst_shape.h, 4); + const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); + const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); + const bool suitable_attributes = + attr.weights.shape.w == 3 && attr.weights.shape.h == 3 && + attr.dilations == HW(1, 1) && attr.strides == HW(1, 1); + const int min_depth = 32; + const bool recommended_channels = + dst_depth % 4 == 0 && src_depth >= min_depth && dst_depth >= min_depth; + const bool recommended_hw = tiles_x * tiles_y >= 128; + return suitable_attributes && recommended_channels && recommended_hw; +} + +Status WinogradFromNode(const CreationContext& creation_context, + const OperationDef& op_def, ModelHints hints, + const BHWC& input_shape, const BHWC& output_shape, + const Convolution2DAttributes& attr, + GPUOperationsSubgraph* gpu_subgraph) { + if (!IsSuitableForWinograd4x4To6x6(attr, *creation_context.device, + output_shape)) { + return UnimplementedError("No implementation for this case."); + } + + const int tiles_x = IntegralDivideRoundUp(output_shape.w, 4); + const int tiles_y = IntegralDivideRoundUp(output_shape.h, 4); + const BHWC shape_0{input_shape.b, 36, tiles_x * tiles_y, input_shape.c}; + const BHWC shape_1{input_shape.b, 36, tiles_x * tiles_y, output_shape.c}; + TensorDescriptor td_0; + td_0.storage_type = SelectBestStorageType( + *creation_context.context, *creation_context.device, shape_0, + op_def.src_tensors[0].storage_type, op_def.src_tensors[0].data_type, + op_def.src_tensors[0].layout); + td_0.data_type = op_def.src_tensors[0].data_type; + td_0.layout = op_def.src_tensors[0].layout; + TensorDescriptor td_1; + td_1.storage_type = SelectBestStorageType( + *creation_context.context, *creation_context.device, shape_1, + op_def.src_tensors[0].storage_type, op_def.src_tensors[0].data_type, + op_def.src_tensors[0].layout); + td_1.data_type = op_def.src_tensors[0].data_type; + td_1.layout = op_def.src_tensors[0].layout; + gpu_subgraph->new_tensors = {{shape_0, td_0}, {shape_1, td_1}}; + gpu_subgraph->operations.clear(); + gpu_subgraph->operations.resize(3); + + OperationDef winograd_up_def; + winograd_up_def.precision = op_def.precision; + winograd_up_def.src_tensors.push_back(op_def.src_tensors[0]); + winograd_up_def.dst_tensors.push_back(td_0); + auto& winograd_up = gpu_subgraph->operations[0]; + 
RETURN_IF_ERROR(SelectWinograd4x4To36( + creation_context, attr.padding, winograd_up_def, &winograd_up.operation)); + winograd_up.input_ids = {0}; + winograd_up.output_ids = {-1}; + + OperationDef conv_def; + conv_def.precision = op_def.precision; + conv_def.src_tensors.push_back(td_0); + conv_def.dst_tensors.push_back(td_1); + auto& conv = gpu_subgraph->operations[1]; + conv.input_ids = {-1}; + conv.output_ids = {-2}; + RETURN_IF_ERROR(SelectConvolutionForWinograd( + attr, input_shape, creation_context, conv_def, hints, &conv.operation)); + + OperationDef winograd_down_def; + winograd_down_def.precision = op_def.precision; + winograd_down_def.src_tensors.push_back(td_1); + winograd_down_def.dst_tensors.push_back(op_def.dst_tensors[0]); + auto& winograd_down = gpu_subgraph->operations[2]; + winograd_down.input_ids = {-2}; + winograd_down.output_ids = {0}; + auto bias_copy = attr.bias; + if (bias_copy.shape.v < attr.weights.shape.o) { + bias_copy.shape = Linear(attr.weights.shape.o); + bias_copy.data.resize(attr.weights.shape.o); + } + RETURN_IF_ERROR(SelectWinograd36To4x4(creation_context, winograd_down_def, + bias_copy, &winograd_down.operation)); + + return OkStatus(); +} + +std::unique_ptr<GPUOperation>* InitSingleOpSubgraph( + const std::vector<Value<TensorRef<BHWC>>*>& inputs, + const std::vector<Value<TensorRef<BHWC>>*>& outputs, + GPUOperationsSubgraph* gpu_subgraph) { + gpu_subgraph->operations.clear(); + gpu_subgraph->new_tensors.clear(); + gpu_subgraph->operations.push_back({}); + for (int i = 0; i < inputs.size(); ++i) { + gpu_subgraph->operations[0].input_ids.push_back(i); + } + for (int i = 0; i < outputs.size(); ++i) { + gpu_subgraph->operations[0].output_ids.push_back(i); + } + + return &gpu_subgraph->operations[0].operation; +} + } // namespace Status GPUOperationFromNode(const CreationContext& creation_context, @@ -59,15 +166,8 @@ Status GPUOperationFromNode(const CreationContext& creation_context, const std::vector<Value<TensorRef<BHWC>>*>& outputs, const Node& node, GPUOperationsSubgraph* gpu_subgraph) { - gpu_subgraph->operations.push_back({}); std::unique_ptr<GPUOperation>* gpu_op = - &gpu_subgraph->operations[0].operation; - for (int i = 0; i < inputs.size(); ++i) { - gpu_subgraph->operations[0].input_ids.push_back(i); - } - for (int i = 0; i < outputs.size(); ++i) { - gpu_subgraph->operations[0].output_ids.push_back(i); - } + InitSingleOpSubgraph(inputs, outputs, gpu_subgraph); auto op_type = OperationTypeFromString(node.operation.type); switch (op_type) { case OperationType::ADD: { @@ -111,9 +211,17 @@ Status GPUOperationFromNode(const CreationContext& creation_context, case OperationType::CONVOLUTION_2D: { auto attr = absl::any_cast<Convolution2DAttributes>(node.operation.attributes); - auto input = inputs[0]; - return SelectConvolution(attr, input->tensor.shape, creation_context, - op_def, hints, gpu_op); + auto input_shape = inputs[0]->tensor.shape; + auto output_shape = outputs[0]->tensor.shape; + if (WinogradFromNode(creation_context, op_def, hints, input_shape, + output_shape, attr, gpu_subgraph) + .ok()) { + return OkStatus(); + } else { + gpu_op = InitSingleOpSubgraph(inputs, outputs, gpu_subgraph); + return SelectConvolution(attr, input_shape, creation_context, op_def, + hints, gpu_op); + } } case OperationType::CONVOLUTION_TRANSPOSED: { auto attr = absl::any_cast<ConvolutionTransposedAttributes>( diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc index 
92ef85b1779..22244351bd7 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/kernels/space_to_depth.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/transpose.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/winograd.h" #include "tensorflow/lite/delegates/gpu/common/status.h" namespace tflite { @@ -195,6 +196,28 @@ void SelectTranspose(const TransposeAttributes& attr, *ptr = absl::make_unique<Transpose>(std::move(operation)); } +Status SelectWinograd4x4To36(const CreationContext& creation_context, + const Padding2D& padding, + const OperationDef& op_def, + std::unique_ptr<GPUOperation>* ptr) { + Winograd4x4To36 operation; + RETURN_IF_ERROR( + CreateWinograd4x4To36(creation_context, op_def, padding, &operation)); + *ptr = absl::make_unique<Winograd4x4To36>(std::move(operation)); + return OkStatus(); +} + +Status SelectWinograd36To4x4( + const CreationContext& creation_context, const OperationDef& op_def, + const ::tflite::gpu::Tensor<Linear, DataType::FLOAT32>& biases, + std::unique_ptr<GPUOperation>* ptr) { + Winograd36To4x4 operation; + RETURN_IF_ERROR( + CreateWinograd36To4x4(creation_context, op_def, biases, &operation)); + *ptr = absl::make_unique<Winograd36To4x4>(std::move(operation)); + return OkStatus(); +} + } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h index 5a830994f75..fd29ebc0e91 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h @@ -90,6 +90,16 @@ void SelectTranspose(const TransposeAttributes& attr, const OperationDef& op_def, std::unique_ptr<GPUOperation>* ptr); +Status SelectWinograd4x4To36(const CreationContext& creation_context, + const Padding2D& padding, + const OperationDef& op_def, + std::unique_ptr<GPUOperation>* ptr); + +Status SelectWinograd36To4x4( + const CreationContext& creation_context, const OperationDef& op_def, + const ::tflite::gpu::Tensor<Linear, DataType::FLOAT32>& biases, + std::unique_ptr<GPUOperation>* ptr); + } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc b/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc new file mode 100644 index 00000000000..26eb3ad3538 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/storage_type_util.cc @@ -0,0 +1,141 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/cl/storage_type_util.h" + +#include "tensorflow/lite/delegates/gpu/cl/cl_context.h" +#include "tensorflow/lite/delegates/gpu/cl/cl_device.h" +#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" + +namespace tflite { +namespace gpu { +namespace cl { +bool CanCreateTensorWithShape(const CLContext& context, const CLDevice& device, + const BHWDC& shape, + const TensorDescriptor& descriptor) { + const int slices = IntegralDivideRoundUp(shape.c, 4); + switch (descriptor.storage_type) { + case TensorStorageType::BUFFER: { + const int flt4_size = + 4 * (descriptor.data_type == DataType::FLOAT32 ? 4 : 2); + const int buffer_size = + shape.b * shape.w * shape.h * shape.d * slices * flt4_size; + return buffer_size <= device.GetInfo().buffer_max_size; + } + case TensorStorageType::IMAGE_BUFFER: + return shape.b * shape.w * shape.h * shape.d * slices <= + device.GetInfo().image_buffer_max_size; + case TensorStorageType::TEXTURE_3D: + if (device.cl_version() < OpenCLVersion::CL_1_2 && slices == 1) { + // clCreateImage3D (that used in CL 1.0/1.1) can not create image with + // depth = 1 by specification; + return false; + } + return shape.w * shape.b <= device.GetInfo().image3d_max_width && + shape.h <= device.GetInfo().image3d_max_height && + slices * shape.d <= device.GetInfo().image3d_max_depth; + case TensorStorageType::TEXTURE_ARRAY: + // Bug on some Adreno. b/131099086 + if (slices == 1 && !device.SupportsOneLayerTextureArray()) { + return false; + } + return shape.w * shape.b <= device.GetInfo().image2d_max_width && + shape.h <= device.GetInfo().image2d_max_height && + slices * shape.d <= device.GetInfo().image_array_max_layers; + case TensorStorageType::TEXTURE_2D: + return shape.w * shape.b * shape.d <= + device.GetInfo().image2d_max_width && + shape.h * slices <= device.GetInfo().image2d_max_height; + case TensorStorageType::SINGLE_TEXTURE_2D: + return shape.c <= 4 && + context.IsFloatTexture2DSupported(shape.c, descriptor.data_type) && + shape.w * shape.b * shape.d <= + device.GetInfo().image2d_max_width && + shape.h <= device.GetInfo().image2d_max_height; + default: + return false; + } +} + +bool CanCreateTensorWithShape(const CLContext& context, const CLDevice& device, + const BHWC& shape, + const TensorDescriptor& descriptor) { + const BHWDC shape5D(shape.b, shape.h, shape.w, 1, shape.c); + return CanCreateTensorWithShape(context, device, shape5D, descriptor); +} + +TensorStorageType SelectBestStorageType(const CLContext& context, + const CLDevice& device, + const BHWC& shape, + const TensorStorageType& desired, + const DataType& data_type, + const Layout& layout) { + if (CanCreateTensorWithShape(context, device, shape, + TensorDescriptor{data_type, desired, layout})) { + return desired; + } + auto GetBestTypeAfterTextureArray = [&]() { + if (device.SupportsImageBuffer() && + CanCreateTensorWithShape( + context, device, shape, + TensorDescriptor{data_type, TensorStorageType::IMAGE_BUFFER, + layout})) { + return TensorStorageType::IMAGE_BUFFER; + } else { + return TensorStorageType::BUFFER; + } + }; + auto GetBestTypeAfterTexture2D = [&]() { + if (device.SupportsTextureArray() && + CanCreateTensorWithShape( + context, device, shape, + TensorDescriptor{data_type, TensorStorageType::TEXTURE_ARRAY, + layout})) { + return 
TensorStorageType::TEXTURE_ARRAY; + } else { + return GetBestTypeAfterTextureArray(); + } + }; + auto GetBestTypeAfterTexture3D = [&]() { + if (CanCreateTensorWithShape( + context, device, shape, + TensorDescriptor{data_type, TensorStorageType::TEXTURE_2D, + layout})) { + return TensorStorageType::TEXTURE_2D; + } else { + return GetBestTypeAfterTexture2D(); + } + }; + switch (desired) { + case TensorStorageType::TEXTURE_2D: + case TensorStorageType::SINGLE_TEXTURE_2D: + return GetBestTypeAfterTexture2D(); + case TensorStorageType::TEXTURE_ARRAY: + return GetBestTypeAfterTextureArray(); + case TensorStorageType::TEXTURE_3D: + return GetBestTypeAfterTexture3D(); + case TensorStorageType::IMAGE_BUFFER: + case TensorStorageType::BUFFER: + return TensorStorageType::BUFFER; + default: + return TensorStorageType::BUFFER; + } +} + +} // namespace cl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/storage_type_util.h b/tensorflow/lite/delegates/gpu/cl/storage_type_util.h new file mode 100644 index 00000000000..87fc2206e81 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/storage_type_util.h @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_STORAGE_TYPE_UTIL_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_CL_STORAGE_TYPE_UTIL_H_ + +#include "tensorflow/lite/delegates/gpu/cl/cl_context.h" +#include "tensorflow/lite/delegates/gpu/cl/cl_device.h" +#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" + +namespace tflite { +namespace gpu { +namespace cl { + +bool CanCreateTensorWithShape(const CLContext& context, const CLDevice& device, + const BHWDC& shape, + const TensorDescriptor& descriptor); + +bool CanCreateTensorWithShape(const CLContext& context, const CLDevice& device, + const BHWC& shape, + const TensorDescriptor& descriptor); + +TensorStorageType SelectBestStorageType(const CLContext& context, + const CLDevice& device, + const BHWC& shape, + const TensorStorageType& desired, + const DataType& data_type, + const Layout& layout); + +} // namespace cl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_STORAGE_TYPE_UTIL_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.cc b/tensorflow/lite/delegates/gpu/cl/tensor.cc index 8423613440e..610ba407eb9 100644 --- a/tensorflow/lite/delegates/gpu/cl/tensor.cc +++ b/tensorflow/lite/delegates/gpu/cl/tensor.cc @@ -331,60 +331,6 @@ Status Tensor::ReadData(CLCommandQueue* queue, Tensor5DFloat32* dst) const { return ReadDataBHWDC(absl::MakeSpan(dst->data), queue); } -bool CanCreateTensorWithShape(const CLContext& context, const CLDevice& device, - const BHWC& shape, - const TensorDescriptor& descriptor) { - const BHWDC shape5D(shape.b, shape.h, shape.w, 1, shape.c); - return 
CanCreateTensorWithShape(context, device, shape5D, descriptor); -} - -bool CanCreateTensorWithShape(const CLContext& context, const CLDevice& device, - const BHWDC& shape, - const TensorDescriptor& descriptor) { - const int slices = IntegralDivideRoundUp(shape.c, 4); - switch (descriptor.storage_type) { - case TensorStorageType::BUFFER: { - const int flt4_size = - 4 * (descriptor.data_type == DataType::FLOAT32 ? 4 : 2); - const int buffer_size = - shape.b * shape.w * shape.h * shape.d * slices * flt4_size; - return buffer_size <= device.GetInfo().buffer_max_size; - } - case TensorStorageType::IMAGE_BUFFER: - return shape.b * shape.w * shape.h * shape.d * slices <= - device.GetInfo().image_buffer_max_size; - case TensorStorageType::TEXTURE_3D: - if (device.cl_version() < OpenCLVersion::CL_1_2 && slices == 1) { - // clCreateImage3D (that used in CL 1.0/1.1) can not create image with - // depth = 1 by specification; - return false; - } - return shape.w * shape.b <= device.GetInfo().image3d_max_width && - shape.h <= device.GetInfo().image3d_max_height && - slices * shape.d <= device.GetInfo().image3d_max_depth; - case TensorStorageType::TEXTURE_ARRAY: - // Bug on some Adreno. b/131099086 - if (slices == 1 && !device.SupportsOneLayerTextureArray()) { - return false; - } - return shape.w * shape.b <= device.GetInfo().image2d_max_width && - shape.h <= device.GetInfo().image2d_max_height && - slices * shape.d <= device.GetInfo().image_array_max_layers; - case TensorStorageType::TEXTURE_2D: - return shape.w * shape.b * shape.d <= - device.GetInfo().image2d_max_width && - shape.h * slices <= device.GetInfo().image2d_max_height; - case TensorStorageType::SINGLE_TEXTURE_2D: - return shape.c <= 4 && - context.IsFloatTexture2DSupported(shape.c, descriptor.data_type) && - shape.w * shape.b * shape.d <= - device.GetInfo().image2d_max_width && - shape.h <= device.GetInfo().image2d_max_height; - default: - return false; - } -} - Status CreateTensor(const CLContext& context, const CLDevice& device, const BHWC& shape, const TensorDescriptor& descriptor, Tensor* result) { diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.h b/tensorflow/lite/delegates/gpu/cl/tensor.h index efc09480a39..34a45436386 100644 --- a/tensorflow/lite/delegates/gpu/cl/tensor.h +++ b/tensorflow/lite/delegates/gpu/cl/tensor.h @@ -145,14 +145,6 @@ class Tensor { using TensorPtr = std::shared_ptr<Tensor>; -bool CanCreateTensorWithShape(const CLContext& context, const CLDevice& device, - const BHWC& shape, - const TensorDescriptor& descriptor); - -bool CanCreateTensorWithShape(const CLContext& context, const CLDevice& device, - const BHWDC& shape, - const TensorDescriptor& descriptor); - Status AllocateTensorMemory(const CLContext& context, const CLDevice& device, const BHWC& shape, const TensorDescriptor& descriptor, diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/concat_builder.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/concat_builder.cc index 05f30329215..430aae2443a 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/concat_builder.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/concat_builder.cc @@ -41,6 +41,10 @@ TfLiteStatus ConcatOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, int tensor_id; // Input data tensors. + // input_bound_minimum & input_bound_maximum track the minimum & maximum + // min/max bounds across all inputs. 
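The two running bounds added above let the Hexagon concat builder detect when the output range is identical to the widest input range; in that case every input already uses the output scale, and the trailing OP_Requantize_8to8 (which, per the comment in the change below, causes errors in InceptionV3) can be dropped. A rough, self-contained restatement of that decision, not the actual builder API:

    // Rough restatement of the decision (not the builder API): requantize only
    // when the concat output range differs from the common range of the inputs.
    #include <algorithm>
    #include <vector>

    bool ConcatNeedsRequantize(const std::vector<float>& input_minima,
                               const std::vector<float>& input_maxima,
                               float output_min, float output_max) {
      const float bound_min =
          *std::min_element(input_minima.begin(), input_minima.end());
      const float bound_max =
          *std::max_element(input_maxima.begin(), input_maxima.end());
      // Equal ranges mean every input already shares the output scale, so an
      // extra Requantize_8to8 node adds nothing but rounding error.
      return output_min != bound_min || output_max != bound_max;
    }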
+ float input_bound_minimum = std::numeric_limits<float>::max(); + float input_bound_maximum = std::numeric_limits<float>::min(); input_minima_.reserve(inputs->size); input_maxima_.reserve(inputs->size); for (int i = 0; i < inputs->size; ++i) { @@ -53,6 +57,8 @@ TfLiteStatus ConcatOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, std::numeric_limits<uint8_t>::max())); input_minima_.push_back(data_min); input_maxima_.push_back(data_max); + if (data_min < input_bound_minimum) input_bound_minimum = data_min; + if (data_max > input_bound_maximum) input_bound_maximum = data_max; } // Minima tensors. @@ -96,19 +102,27 @@ TfLiteStatus ConcatOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, auto* output_max_const = graph_builder_->AddConstNodeWithData( quant_bound_shape, (char*)&output_max_, sizeof(output_max_)); - auto* requantize_op = graph_builder_->AddNode(GetTFLiteNodeID()); - requantize_op->SetOpType(OP_Requantize_8to8); - requantize_op->AddInput(concat_out); - requantize_op->AddInput(concat_out_min); - requantize_op->AddInput(concat_out_max); - requantize_op->AddInput(TensorID(output_min_const->GetID(), 0)); - requantize_op->AddInput(TensorID(output_max_const->GetID(), 0)); - node_output_ = - requantize_op->AddOutput(sizeof(uint8_t), 4, - {output_batch_size, output_height_size, - output_width_size, output_depth_size}); - requantize_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); - requantize_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + if (output_min_ == input_bound_minimum && + output_max_ == input_bound_maximum) { + // If the input min/max (across all tensors) is same as the output min/max, + // Hexagon's Requantize causes errors in InceptionV3. + // TODO(b/150137234): Figure out why this is. + node_output_ = concat_out; + } else { + auto* requantize_op = graph_builder_->AddNode(GetTFLiteNodeID()); + requantize_op->SetOpType(OP_Requantize_8to8); + requantize_op->AddInput(concat_out); + requantize_op->AddInput(concat_out_min); + requantize_op->AddInput(concat_out_max); + requantize_op->AddInput(TensorID(output_min_const->GetID(), 0)); + requantize_op->AddInput(TensorID(output_max_const->GetID(), 0)); + node_output_ = + requantize_op->AddOutput(sizeof(uint8_t), 4, + {output_batch_size, output_height_size, + output_width_size, output_depth_size}); + requantize_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + requantize_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + } return kTfLiteOk; } diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/concat_test.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/concat_test.cc index bc4026d795b..c66ae17a71a 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/concat_test.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/concat_test.cc @@ -12,12 +12,36 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <random> + #include <gtest/gtest.h> #include "tensorflow/lite/experimental/delegates/hexagon/builders/tests/hexagon_delegate_op_model.h" namespace tflite { using testing::ElementsAreArray; +void GenerateUniformRandomVector(int size, float min, float max, + std::minstd_rand* random_engine, + std::vector<float>* result) { + // Never use std::uniform_*_distribution in tests, it's + // implementation-defined. 
Likewise, don't use std::default_random_engine, + // implementation-defined. Implementation-defined is bad because it means that + // any toolchain update or new platform may run into test failures. + // std::minstd_rand is a standard instantiation of + // std::linear_congruential_engine, the cheapest generator in c++11 stdlib, + // it's good enough here. + result->resize(size); + for (int i = 0; i < size; i++) { + // We don't care whether the `max` value may ever be produced exactly. + // It may actually be thanks to rounding, as std::minstd_rand::modulus + // is 2^31 - 1 is greater than the inverse float epsilon. + float random_value_scaled_0_1 = + (*random_engine)() * + (1.0f / static_cast<float>(std::minstd_rand::modulus)); + (*result)[i] = min + (max - min) * random_value_scaled_0_1; + } +} + class QuantizedConcatenationOpModel : public SingleOpModelWithHexagon { public: QuantizedConcatenationOpModel(const std::vector<TensorData>& input_template, @@ -37,7 +61,7 @@ class QuantizedConcatenationOpModel : public SingleOpModelWithHexagon { } template <typename T> - void SetInput(int index, std::initializer_list<float> data) { + void SetInput(int index, std::vector<float> data) { QuantizeAndPopulate<T>(index, data); } @@ -95,4 +119,50 @@ TEST(QuantizedConcatenationOpModel, FourInputsQuantizedMixedRange) { /*max_abs_error=*/0.2))); } +// If the input min/max (across all tensors) is same as the output min/max, +// Hexagon's Requantize causes errors in InceptionV3. +// So, we diable it for that case in the builder. +// This unit test ensures that the math still works. +TEST(QuantizedConcatenationOpModel, FourInputsQuantizedMixedRange_LargeData) { + // Problem specification. + // Adapted from CONCAT node at #15 in Inceptionv3 quantized. + std::vector<float> params1 = {0, 11.30514f}; + std::vector<float> params2 = {0, 10.38416f}; + std::vector<float> params3 = {0, 13.52495f}; + std::vector<float> params4 = {0, 5.883808f}; + std::vector<float> params_output = {0, 13.52495f}; + QuantizedConcatenationOpModel m0( + {{TensorType_UINT8, {1, 35, 35, 64}, params1[0], params1[1]}, + {TensorType_UINT8, {1, 35, 35, 64}, params2[0], params2[1]}, + {TensorType_UINT8, {1, 35, 35, 96}, params3[0], params3[1]}, + {TensorType_UINT8, {1, 35, 35, 32}, params4[0], params4[1]}}, + /*axis=*/3, {TensorType_UINT8, {}, params_output[0], params_output[1]}); + + // Generate random data. + std::minstd_rand random_engine; + std::vector<float> data1, data2, data3, data4; + int num_elements_multiplier = 1 * 35 * 35; + GenerateUniformRandomVector(num_elements_multiplier * 64, params1[0], + params1[1], &random_engine, &data1); + GenerateUniformRandomVector(num_elements_multiplier * 64, params2[0], + params2[1], &random_engine, &data2); + GenerateUniformRandomVector(num_elements_multiplier * 96, params3[0], + params3[1], &random_engine, &data3); + GenerateUniformRandomVector(num_elements_multiplier * 32, params4[0], + params4[1], &random_engine, &data4); + m0.SetInput<uint8_t>(0, data1); + m0.SetInput<uint8_t>(1, data2); + m0.SetInput<uint8_t>(2, data3); + m0.SetInput<uint8_t>(3, data4); + + // Reference output. 
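The 0.1 tolerance used at the end of this test is roughly two quantization steps: assuming the standard asymmetric uint8 scale of (max - min) / 255, the output scale here is 13.52495 / 255, about 0.053. A quick check of that arithmetic:

    // Back-of-the-envelope check of the test tolerance, assuming the standard
    // asymmetric uint8 quantization scale (max - min) / 255.
    #include <cstdio>

    int main() {
      const float out_min = 0.0f, out_max = 13.52495f;
      const float scale = (out_max - out_min) / 255.0f;  // ~0.053 per step
      std::printf("scale=%f, 0.1 tolerance ~ %.1f steps\n", scale, 0.1f / scale);
      return 0;
    }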
+ m0.Invoke(); + std::vector<float> reference_output = m0.GetDequantizedOutput<uint8_t>(); + + m0.ApplyDelegateAndInvoke(); + EXPECT_THAT(m0.GetDequantizedOutput<uint8_t>(), + ElementsAreArray(ArrayFloatNear(reference_output, + /*max_abs_error=*/0.1))); +} + } // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/hexagon/hexagon_delegate.cc b/tensorflow/lite/experimental/delegates/hexagon/hexagon_delegate.cc index cba74b3df4f..b0507693e35 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/hexagon_delegate.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/hexagon_delegate.cc @@ -132,6 +132,30 @@ class HexagonDelegate : public TfLiteDelegate { if (hexagon_nn == nullptr) { return false; } + if (hexagon_nn->hexagon_nn_version != nullptr && + hexagon_nn->hexagon_nn_hexagon_interface_version) { + int hexagon_nn_version = -1; + int hexagon_interface_version = + hexagon_nn->hexagon_nn_hexagon_interface_version(); + if (hexagon_nn->hexagon_nn_version(&hexagon_nn_version) != 0) { + TFLITE_LOG_PROD(tflite::TFLITE_LOG_WARNING, + "Failed to fetch Hexagon NN version. This might be " + "because you're using incompatible versions of " + "libhexagon_interface and libhexagon_nn_skel. " + "You must use compatible versions. " + "Refer to Tensorflow Lite Hexagon Delegate Guide."); + return false; + } + if (hexagon_nn_version != hexagon_interface_version) { + TFLITE_LOG_PROD( + tflite::TFLITE_LOG_WARNING, + "Incompatible versions between interface library and " + "libhexagon_skel %d vs %d. You must use compatible versions. " + "Refer to Tensorflow Lite Hexagon Delegate Guide.", + hexagon_interface_version, hexagon_nn_version); + return false; + } + } return hexagon_nn->hexagon_nn_is_device_supported && hexagon_nn->hexagon_nn_is_device_supported(); } diff --git a/tensorflow/lite/experimental/delegates/hexagon/hexagon_implementation.cc b/tensorflow/lite/experimental/delegates/hexagon/hexagon_implementation.cc index cd86a95c86f..3cf64496e00 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/hexagon_implementation.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/hexagon_implementation.cc @@ -76,6 +76,8 @@ HexagonNN CreateNewHexagonInterface() { LOAD_FUNCTION(libhexagon_interface, hexagon_nn_is_device_supported, hexagon_nn); LOAD_FUNCTION(libhexagon_interface, hexagon_nn_version, hexagon_nn); + LOAD_FUNCTION(libhexagon_interface, hexagon_nn_hexagon_interface_version, + hexagon_nn); hexagon_nn.interface_loaded = successfully_loaded; return hexagon_nn; } diff --git a/tensorflow/lite/experimental/delegates/hexagon/hexagon_implementation.h b/tensorflow/lite/experimental/delegates/hexagon/hexagon_implementation.h index b41f55b0a42..fb81a650205 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/hexagon_implementation.h +++ b/tensorflow/lite/experimental/delegates/hexagon/hexagon_implementation.h @@ -126,6 +126,9 @@ struct HexagonNN { // Otherwise. hexagon_nn_is_device_supported_fn* hexagon_nn_is_device_supported; + // Returns the version number of the interface library. 
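The delegate-side check added above guards against mixing an old libhexagon_nn_skel with a newer libhexagon_interface (or the reverse). Both version getters are optional symbols resolved at runtime, so the check is skipped when either pointer is null. In isolation the logic looks like this sketch, assuming the HexagonNN function-pointer fields introduced in this change:

    // Sketch of the handshake, assuming the HexagonNN fields added above; the
    // real code lives in HexagonDelegate and also logs a warning on mismatch.
    #include "tensorflow/lite/experimental/delegates/hexagon/hexagon_implementation.h"

    bool HexagonVersionsCompatible(const tflite::HexagonNN* nn) {
      if (nn->hexagon_nn_version == nullptr ||
          nn->hexagon_nn_hexagon_interface_version == nullptr) {
        return true;  // Old interface library without the getters: skip check.
      }
      int skel_version = -1;
      if (nn->hexagon_nn_version(&skel_version) != 0) {
        return false;  // Could not query the skel library's version.
      }
      return skel_version == nn->hexagon_nn_hexagon_interface_version();
    }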
+ hexagon_nn_hexagon_interface_version_fn* hexagon_nn_hexagon_interface_version; + hexagon_nn_version_fn* hexagon_nn_version = nullptr; bool interface_loaded = false; diff --git a/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn_init.h b/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn_init.h index 812eb792a5c..d142edd4d03 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn_init.h +++ b/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn_init.h @@ -21,6 +21,7 @@ extern "C" { void hexagon_nn_global_teardown(void); void hexagon_nn_global_init(void); bool hexagon_nn_is_device_supported(); +int hexagon_nn_hexagon_interface_version(void); #ifdef __cplusplus } #endif diff --git a/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/version_scripts.lds b/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/version_scripts.lds index 1e254e7eb0e..7b003afc770 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/version_scripts.lds +++ b/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/version_scripts.lds @@ -19,6 +19,7 @@ VERS_1.0 { hexagon_nn_global_init; hexagon_nn_is_device_supported; hexagon_nn_version; + hexagon_nn_hexagon_interface_version; # Hide everything else. local: diff --git a/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn_interface.h b/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn_interface.h index 74f6ee11de0..cfb61a59182 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn_interface.h +++ b/tensorflow/lite/experimental/delegates/hexagon/hexagon_nn_interface.h @@ -56,4 +56,7 @@ using hexagon_nn_is_device_supported_fn = using hexagon_nn_version_fn = decltype(hexagon_nn_version); +using hexagon_nn_hexagon_interface_version_fn = + decltype(hexagon_nn_hexagon_interface_version); + #endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_HEXAGON_NN_INTERFACE_H_ diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java index c7662d149e9..4a1e7d4a65e 100644 --- a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java +++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java @@ -24,6 +24,7 @@ import java.io.InputStream; import java.io.InputStreamReader; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; +import java.nio.charset.Charset; import java.util.ArrayList; import java.util.List; import org.checkerframework.checker.nullness.qual.NonNull; @@ -46,10 +47,28 @@ public class FileUtil { @NonNull public static List<String> loadLabels(@NonNull Context context, @NonNull String filePath) throws IOException { + return loadLabels(context, filePath, Charset.defaultCharset()); + } + + /** + * Loads labels from the label file into a list of strings. + * + * <p>A legal label file is the plain text file whose contents are split into lines, and each line + * is an individual value. The file should be in assets of the context. + * + * @param context The context holds assets. + * @param filePath The path of the label file, relative with assets directory. + * @param cs {@code Charset} to use when decoding content of label file. + * @return a list of labels. + * @throws IOException if error occurs to open or read the file. 
+ */ + @NonNull + public static List<String> loadLabels( + @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException { SupportPreconditions.checkNotNull(context, "Context cannot be null."); SupportPreconditions.checkNotNull(filePath, "File path cannot be null."); InputStream inputStream = context.getAssets().open(filePath); - return loadLabels(inputStream); + return loadLabels(inputStream, cs); } /** @@ -62,8 +81,23 @@ public class FileUtil { */ @NonNull public static List<String> loadLabels(@NonNull InputStream inputStream) throws IOException { + return loadLabels(inputStream, Charset.defaultCharset()); + } + + /** + * Loads labels from an input stream of an opened label file. See details for label files in + * {@link FileUtil#loadLabels(Context, String)}. + * + * @param inputStream the input stream of an opened label file. + * @param cs {@code Charset} to use when decoding content of label file. + * @return a list of labels. + * @throws IOException if error occurs to open or read the file. + */ + @NonNull + public static List<String> loadLabels(@NonNull InputStream inputStream, Charset cs) + throws IOException { List<String> labels = new ArrayList<>(); - BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream)); + BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, cs)); String line; while ((line = reader.readLine()) != null) { labels.add(line); @@ -72,6 +106,38 @@ public class FileUtil { return labels; } + /** + * Loads a vocabulary file (a single-column text file) into a list of strings. + * + * <p>A vocabulary file is a single-column plain text file whose contents are split into lines, + * and each line is an individual value. The file should be in assets of the context. + * + * @param context The context holds assets. + * @param filePath The path of the vocabulary file, relative with assets directory. + * @return a list of vocabulary words. + * @throws IOException if error occurs to open or read the file. + */ + @NonNull + public static List<String> loadSingleColumnTextFile( + @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException { + return loadLabels(context, filePath, cs); + } + + /** + * Loads vocabulary from an input stream of an opened vocabulary file (which is a single-column + * text file). See details for vocabulary files in {@link FileUtil#loadVocabularyFile(Context, + * String)}. + * + * @param inputStream the input stream of an opened vocabulary file. + * @return a list of vocabulary words. + * @throws IOException if error occurs to open or read the file. + */ + @NonNull + public static List<String> loadSingleColumnTextFile(@NonNull InputStream inputStream, Charset cs) + throws IOException { + return loadLabels(inputStream, cs); + } + /** * Loads a file from the asset folder through memory mapping. 
* diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java index 7840fbfac6a..1881747870b 100644 --- a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java +++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java @@ -24,6 +24,12 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; * <p>Note: The data type of output tensor is always {@code FLOAT32} except when the DequantizeOp is * created effectively as an identity Op such as setting {@code zeroPoint} to 0 and {@code scale} to * 1 (in this case, the output tensor is the same instance as input). + * + * <p>If both {@code zeroPoint} and {@code scale} are 0, the {@link DequantizeOp} will be bypassed, + * which is equivalent to setting {@code zeroPoint} to 0 and {@code scale} to 1. This can be useful + * when passing in the quantization parameters that are extracted directly from the TFLite model + * flatbuffer. If the tensor is not quantized, both {@code zeroPoint} and {@code scale} will be read + * as 0. */ public class DequantizeOp extends NormalizeOp implements TensorOperator { diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java index 25db461ede1..8ac57eed286 100644 --- a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java +++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java @@ -41,6 +41,11 @@ public class NormalizeOp implements TensorOperator { * output = (input - mean) / stddev * </pre> * + * <p>In the following two cases, reset {@code mean} to 0 and {@code stddev} to 1 to bypass the + * normalization. <br> + * 1. Both {@code mean} and {code stddev} are 0. <br> + * 2. {@code mean} is 0 and {stddev} is Infinity. + * * <p>Note: If {@code mean} is set to 0 and {@code stddev} is set to 1, no computation will * happen, and original input will be directly returned in execution. * @@ -53,7 +58,28 @@ public class NormalizeOp implements TensorOperator { * @throws IllegalArgumentException if {@code stddev} is zero. */ public NormalizeOp(float mean, float stddev) { - this(new float[] {mean}, new float[] {stddev}); + // Make exceptions to the cases that + // 1. Both mean and stddev are 0.0f. This may happen when reading the normalization parameters + // from a tensor which does not have the values populated in the metadata. The same situation + // may also happen to the quantization parameters. + // 2. mean is 0.0f and stddev is Infinity. This may happen when reading the quantization + // parameters from a tensor which does not have the values populated in the metadata, and then + // passing the parameters into the DequantizeOp. + // Bypass both of the two cases, by reseting stddev to 1.0f. 
+ if (mean == 0.0f && (stddev == 0.0f || Float.isInfinite(stddev))) { + stddev = 1.0f; + } + + SupportPreconditions.checkArgument(stddev != 0.0f, "Stddev cannot be zero."); + boolean meansIsZeroAndDevsIs1 = false; + if (mean == 0.0f && stddev == 1.0f) { + meansIsZeroAndDevsIs1 = true; + } + + this.isIdentityOp = meansIsZeroAndDevsIs1; + this.mean = new float[] {mean}; + this.stddev = new float[] {stddev}; + this.numChannels = 1; } /** diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java index 77a6559cb00..8b3e82aee13 100644 --- a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java +++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java @@ -25,6 +25,12 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; * math on top of input. The data type of output tensor is always {@code FLOAT32} except that the Op * is effectively an identity Op (in this case, the output tensor is the same instance as the * input). To connect with quantized model, a {@link CastOp} is probably needed. + * + * <p>If both {@code zeroPoint} and {@code scale} are 0, the {@link QuantizeOp} will be bypassed, + * which is equivalent to setting {@code zeroPoint} to 0 and {@code scale} to 1. This can be useful + * when passing in the quantization parameters that are extracted directly from the TFLite model + * flatbuffer. If the tensor is not quantized, both {@code zeroPoint} and {@code scale} will be read + * as 0. */ public class QuantizeOp extends NormalizeOp implements TensorOperator { diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD index cf8e6d40f9f..bfd610b759a 100644 --- a/tensorflow/lite/java/BUILD +++ b/tensorflow/lite/java/BUILD @@ -201,7 +201,6 @@ java_test( "src/testdata/int32.bin", "src/testdata/int64.bin", "src/testdata/invalid_model.bin", - "src/testdata/quantized.bin", "src/testdata/string.bin", "src/testdata/uint8.bin", "src/testdata/with_custom_op.lite", @@ -304,6 +303,7 @@ java_test( "src/testdata/add.bin", "src/testdata/int32.bin", "src/testdata/int64.bin", + "src/testdata/quantized.bin", ], javacopts = JAVACOPTS, tags = [ diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index a6285895e4f..ca21ec5c7ea 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -254,24 +254,6 @@ final class NativeInterpreterWrapper implements AutoCloseable { return (inferenceDurationNanoseconds < 0) ? null : inferenceDurationNanoseconds; } - /** - * Gets the quantization zero point of an output. - * - * @throws IllegalArgumentException if the output index is invalid. - */ - int getOutputQuantizationZeroPoint(int index) { - return getOutputQuantizationZeroPoint(interpreterHandle, index); - } - - /** - * Gets the quantization scale of an output. - * - * @throws IllegalArgumentException if the output index is invalid. - */ - float getOutputQuantizationScale(int index) { - return getOutputQuantizationScale(interpreterHandle, index); - } - /** Gets the number of input tensors. 
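The two bypass cases in the NormalizeOp constructor above line up with how unquantized tensors report their parameters: zeroPoint and scale both read as 0. Assuming DequantizeOp forwards (zeroPoint, 1/scale) and QuantizeOp forwards (-zeroPoint * scale, scale) as (mean, stddev), which is consistent with the comments above but stated here as an assumption, the degenerate metadata maps onto case 2 and case 1 respectively. A small numeric check (C++ arithmetic only; the ops themselves are Java classes):

    // Numeric check of the degenerate cases; the (mean, stddev) mappings are
    // the assumptions stated in the text, not quotes of the Java implementation.
    #include <cstdio>
    #include <limits>

    int main() {
      const float zero_point = 0.0f, scale = 0.0f;  // unquantized tensor metadata
      // DequantizeOp -> NormalizeOp(mean = zero_point, stddev = 1 / scale)
      const float deq_mean = zero_point;
      const float deq_stddev = scale == 0.0f
                                   ? std::numeric_limits<float>::infinity()
                                   : 1.0f / scale;
      // QuantizeOp -> NormalizeOp(mean = -zero_point * scale, stddev = scale)
      const float q_mean = -zero_point * scale;
      const float q_stddev = scale;
      std::printf("dequantize: mean=%f stddev=%f (case 2)\n", deq_mean, deq_stddev);
      std::printf("quantize:   mean=%f stddev=%f (case 1)\n", q_mean, q_stddev);
      // Both are reset to mean=0, stddev=1, i.e. the identity transform.
      return 0;
    }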
*/ int getInputTensorCount() { return inputTensors.length; @@ -374,10 +356,6 @@ final class NativeInterpreterWrapper implements AutoCloseable { private static native int getOutputDataType(long interpreterHandle, int outputIdx); - private static native int getOutputQuantizationZeroPoint(long interpreterHandle, int outputIdx); - - private static native float getOutputQuantizationScale(long interpreterHandle, int outputIdx); - private static final int ERROR_BUFFER_SIZE = 512; private long errorHandle; diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java index 5d15b2c9a7e..95b9a41bf32 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java @@ -36,13 +36,55 @@ public final class Tensor { /** * Creates a Tensor wrapper from the provided interpreter instance and tensor index. * - * <p>The caller is responsible for closing the created wrapper, and ensuring the provided - * native interpreter is valid until the tensor is closed. + * <p>The caller is responsible for closing the created wrapper, and ensuring the provided native + * interpreter is valid until the tensor is closed. */ static Tensor fromIndex(long nativeInterpreterHandle, int tensorIndex) { return new Tensor(create(nativeInterpreterHandle, tensorIndex)); } + /** + * Quantization parameters that corresponds to the table, {@code QuantizationParameters}, in the + * <a + * href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs">TFLite + * Model schema file.</a> + * + * <p>Since per-channel quantization does not apply to input and output tensors, {@code scale} and + * {@code zero_point} are both single values instead of arrays. + * + * <p>For tensor that are not quantized, the values of scale and zero_point are both 0. + * + * <p>Given a quantized value q, the corresponding float value f should be: <br> + * f = scale * (q - zero_point) <br> + */ + public static class QuantizationParams { + /** The scale value used in quantization. */ + private final float scale; + /** The zero point value used in quantization. */ + private final int zeroPoint; + + /** + * Creates a {@link QuantizationParams} with {@code scale} and {@code zero_point}. + * + * @param scale The scale value used in quantization. + * @param zeroPoint The zero point value used in quantization. + */ + public QuantizationParams(final float scale, final int zeroPoint) { + this.scale = scale; + this.zeroPoint = zeroPoint; + } + + /** Returns the scale value. */ + public float getScale() { + return scale; + } + + /** Returns the zero point value. */ + public int getZeroPoint() { + return zeroPoint; + } + } + /** Disposes of any resources used by the Tensor wrapper. */ void close() { delete(nativeHandle); @@ -114,6 +156,16 @@ public final class Tensor { return name(nativeHandle); } + /** + * Returns the quantization parameters of the tensor within the owning {@link Interpreter}. + * + * <p>Only quantized tensors have valid {@code QuantizationParameters}. For tensor that are not + * quantized, the values of scale and zero_point are both 0. + */ + public QuantizationParams quantizationParams() { + return quantizationParamsCopy; + } + /** * Copies the contents of the provided {@code src} object to the Tensor. 
* @@ -376,12 +428,14 @@ public final class Tensor { private final DataType dtype; private int[] shapeCopy; private final int[] shapeSignatureCopy; + private final QuantizationParams quantizationParamsCopy; private Tensor(long nativeHandle) { this.nativeHandle = nativeHandle; this.dtype = DataType.fromC(dtype(nativeHandle)); this.shapeCopy = shape(nativeHandle); this.shapeSignatureCopy = shapeSignature(nativeHandle); + this.quantizationParamsCopy = quantizationParameters(nativeHandle); } private ByteBuffer buffer() { @@ -413,4 +467,6 @@ public final class Tensor { private static native int index(long handle); private static native String name(long handle); + + private static native QuantizationParams quantizationParameters(long handle); } diff --git a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index 3701e07bd82..28d14b6da87 100644 --- a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -458,40 +458,6 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType( return static_cast<jint>(type); } -JNIEXPORT jint JNICALL -Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationZeroPoint( - JNIEnv* env, jclass clazz, jlong handle, jint output_idx) { - tflite_api_dispatcher::Interpreter* interpreter = - convertLongToInterpreter(env, handle); - if (interpreter == nullptr) return 0; - const int idx = static_cast<int>(output_idx); - if (output_idx < 0 || output_idx >= interpreter->outputs().size()) { - ThrowException(env, kIllegalArgumentException, - "Failed to get %d-th output out of %d outputs", output_idx, - interpreter->outputs().size()); - return 0; - } - TfLiteTensor* target = interpreter->tensor(interpreter->outputs()[idx]); - return static_cast<jint>(target->params.zero_point); -} - -JNIEXPORT jfloat JNICALL -Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationScale( - JNIEnv* env, jclass clazz, jlong handle, jint output_idx) { - tflite_api_dispatcher::Interpreter* interpreter = - convertLongToInterpreter(env, handle); - if (interpreter == nullptr) return 1.0f; - const int idx = static_cast<int>(output_idx); - if (output_idx < 0 || output_idx >= interpreter->outputs().size()) { - ThrowException(env, kIllegalArgumentException, - "Failed to get %d-th output out of %d outputs", output_idx, - interpreter->outputs().size()); - return 1.0f; - } - TfLiteTensor* target = interpreter->tensor(interpreter->outputs()[idx]); - return static_cast<jfloat>(target->params.scale); -} - JNIEXPORT jboolean JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput( JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, diff --git a/tensorflow/lite/java/src/main/native/tensor_jni.cc b/tensorflow/lite/java/src/main/native/tensor_jni.cc index 9a38e85acd1..1a50b3233ca 100644 --- a/tensorflow/lite/java/src/main/native/tensor_jni.cc +++ b/tensorflow/lite/java/src/main/native/tensor_jni.cc @@ -482,6 +482,26 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_index(JNIEnv* env, return GetTensorIndexFromHandle(env, handle); } +JNIEXPORT jobject JNICALL +Java_org_tensorflow_lite_Tensor_quantizationParameters(JNIEnv* env, + jclass clazz, + jlong handle) { + const TfLiteTensor* tensor = GetTensorFromHandle(env, handle); + + // For tensor that are not quantized, the values of scale and zero_point are + // both 0. 
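The QuantizationParams javadoc above gives the dequantization relation f = scale * (q - zero_point), which is exactly what the scale and zero point surfaced by the JNI bridge feed into. With the values the tests further below read from quantized.bin (scale 0.25, zero point 127), the uint8 range maps back to roughly [-31.75, 32]:

    // Worked example of f = scale * (q - zero_point) using the parameters the
    // tests below read from quantized.bin (scale 0.25, zero point 127).
    #include <cstdio>

    int main() {
      const float scale = 0.25f;
      const int zero_point = 127;
      const int qs[] = {0, 127, 255};
      for (int q : qs) {
        std::printf("q=%3d -> f=%7.2f\n", q, scale * (q - zero_point));
      }
      // q=0 -> -31.75, q=127 -> 0.00, q=255 -> 32.00
      return 0;
    }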
+ jfloat scale = static_cast<jfloat>(tensor->params.scale); + jlong zero_point = static_cast<jint>(tensor->params.zero_point); + + jclass quantization_params_class = + env->FindClass("org/tensorflow/lite/Tensor$QuantizationParams"); + jmethodID quantization_params_constructor = + env->GetMethodID(quantization_params_class, "<init>", "(FI)V"); + + return env->NewObject(quantization_params_class, + quantization_params_constructor, scale, zero_point); +} + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java index 6d522197b2d..bab39793130 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -46,9 +46,6 @@ public final class NativeInterpreterWrapperTest { private static final String STRING_MODEL_PATH = "tensorflow/lite/java/src/testdata/string.bin"; - private static final String QUANTIZED_MODEL_PATH = - "tensorflow/lite/java/src/testdata/quantized.bin"; - private static final String INVALID_MODEL_PATH = "tensorflow/lite/java/src/testdata/invalid_model.bin"; @@ -561,16 +558,4 @@ public final class NativeInterpreterWrapperTest { assertThat(wrapper.getInputTensor(0).shape()).isEqualTo(expectedDims); } } - - @Test - public void testGetOutputQuantizationParams() { - try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH)) { - assertThat(wrapper.getOutputQuantizationZeroPoint(0)).isEqualTo(0); - assertThat(wrapper.getOutputQuantizationScale(0)).isWithin(1e-6f).of(0.0f); - } - try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(QUANTIZED_MODEL_PATH)) { - assertThat(wrapper.getOutputQuantizationZeroPoint(0)).isEqualTo(127); - assertThat(wrapper.getOutputQuantizationScale(0)).isWithin(1e-6f).of(0.25f); - } - } } diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java index 09e9b1cbc8f..f828f26f4c5 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java @@ -31,6 +31,7 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.tensorflow.lite.Tensor.QuantizationParams; /** Unit tests for {@link org.tensorflow.lite.Tensor}. */ @RunWith(JUnit4.class) @@ -45,6 +46,9 @@ public final class TensorTest { private static final String LONG_MODEL_PATH = "tensorflow/lite/java/src/testdata/int64.bin"; + private static final String QUANTIZED_MODEL_PATH = + "tensorflow/lite/java/src/testdata/quantized.bin"; + private NativeInterpreterWrapper wrapper; private Tensor tensor; @@ -451,4 +455,27 @@ public final class TensorTest { // Expected failure. 
} } + + @Test + public void testQuantizationParameters_floatModel() { + QuantizationParams quantizationParams = tensor.quantizationParams(); + float scale = quantizationParams.getScale(); + long zeroPoint = quantizationParams.getZeroPoint(); + + assertThat(scale).isWithin(1e-6f).of(0.0f); + assertThat(zeroPoint).isEqualTo(0); + } + + @Test + public void testQuantizationParameters_quantizedModel() { + wrapper = new NativeInterpreterWrapper(QUANTIZED_MODEL_PATH); + tensor = wrapper.getOutputTensor(0); + + QuantizationParams quantizationParams = tensor.quantizationParams(); + float scale = quantizationParams.getScale(); + long zeroPoint = quantizationParams.getZeroPoint(); + + assertThat(scale).isWithin(1e-6f).of(0.25f); + assertThat(zeroPoint).isEqualTo(127); + } } diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h index 6308131409f..9f967070413 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h @@ -58,8 +58,6 @@ inline void ConvPerChannel( const bool need_im2col = stride_width != 1 || stride_height != 1 || filter_width != 1 || filter_height != 1; const int8 input_zero_point = -input_offset; - TFLITE_DCHECK_GE(input_zero_point, output_activation_min); - TFLITE_DCHECK_LE(input_zero_point, output_activation_max); const uint8 zero_point_byte = *reinterpret_cast<const uint8*>(&input_zero_point); if (need_dilated_im2col) { diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h index 26d7548e5c3..642d7577a1b 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h @@ -1814,8 +1814,9 @@ inline void DepthwiseConvWithRounding( const auto ruy_paths = ruy_context != nullptr ? ruy_context->GetRuntimeEnabledPaths() : ruy::Path::kNone; + // TODO(b/150208140): Re-enable once erroneous activation in test is resolved. const bool has_dot_product_instructions = - (ruy_paths & ruy::Path::kNeonDotprod) != ruy::Path::kNone; + false && (ruy_paths & ruy::Path::kNeonDotprod) != ruy::Path::kNone; // Dispatch to dot-product 3x3 kernels when supported. if (has_dot_product_instructions) { diff --git a/tensorflow/lite/kernels/internal/reference/strided_slice.h b/tensorflow/lite/kernels/internal/reference/strided_slice.h index 921c49ea77b..ba6d4c22554 100644 --- a/tensorflow/lite/kernels/internal/reference/strided_slice.h +++ b/tensorflow/lite/kernels/internal/reference/strided_slice.h @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/strided_slice_logic.h" #include "tensorflow/lite/kernels/internal/types.h" - namespace tflite { namespace reference_ops { @@ -28,47 +27,60 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params, const T* input_data, const RuntimeShape& unextended_output_shape, T* output_data) { + using strided_slice::LoopCondition; + using strided_slice::StartForAxis; + using strided_slice::StopForAxis; // Note that the output_shape is not used herein. 
tflite::StridedSliceParams params_copy = op_params; - TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 5); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 5); const RuntimeShape input_shape = - RuntimeShape::ExtendedShape(4, unextended_input_shape); + RuntimeShape::ExtendedShape(5, unextended_input_shape); const RuntimeShape output_shape = - RuntimeShape::ExtendedShape(4, unextended_output_shape); + RuntimeShape::ExtendedShape(5, unextended_output_shape); - // Reverse and pad to 4 dimensions because that is what the runtime code - // requires (ie. all shapes must be 4D and are given backwards). - strided_slice::StridedSlicePadIndices(¶ms_copy, 4); + // Reverse and pad to 5 dimensions because that is what the runtime code + // requires (ie. all shapes must be 5D and are given backwards). + strided_slice::StridedSlicePadIndices(¶ms_copy, 5); - const int start_b = strided_slice::StartForAxis(params_copy, input_shape, 0); - const int stop_b = - strided_slice::StopForAxis(params_copy, input_shape, 0, start_b); - const int start_h = strided_slice::StartForAxis(params_copy, input_shape, 1); - const int stop_h = - strided_slice::StopForAxis(params_copy, input_shape, 1, start_h); - const int start_w = strided_slice::StartForAxis(params_copy, input_shape, 2); - const int stop_w = - strided_slice::StopForAxis(params_copy, input_shape, 2, start_w); - const int start_d = strided_slice::StartForAxis(params_copy, input_shape, 3); - const int stop_d = - strided_slice::StopForAxis(params_copy, input_shape, 3, start_d); + const int start_0 = StartForAxis(params_copy, input_shape, 0); + const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0); + const int start_1 = StartForAxis(params_copy, input_shape, 1); + const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1); + const int start_2 = StartForAxis(params_copy, input_shape, 2); + const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2); + const int start_3 = StartForAxis(params_copy, input_shape, 3); + const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3); + const int start_4 = StartForAxis(params_copy, input_shape, 4); + const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4); T* out_ptr = output_data; - for (int in_b = start_b; - !strided_slice::LoopCondition(in_b, stop_b, params_copy.strides[0]); - in_b += params_copy.strides[0]) { - for (int in_h = start_h; - !strided_slice::LoopCondition(in_h, stop_h, params_copy.strides[1]); - in_h += params_copy.strides[1]) { - for (int in_w = start_w; - !strided_slice::LoopCondition(in_w, stop_w, params_copy.strides[2]); - in_w += params_copy.strides[2]) { - for (int in_d = start_d; !strided_slice::LoopCondition( - in_d, stop_d, params_copy.strides[3]); - in_d += params_copy.strides[3]) { - *out_ptr++ = input_data[Offset(input_shape, in_b, in_h, in_w, in_d)]; + for (int offset_0 = start_0 * input_shape.Dims(1), + end_0 = stop_0 * input_shape.Dims(1), + step_0 = params_copy.strides[0] * input_shape.Dims(1); + !LoopCondition(offset_0, end_0, params_copy.strides[0]); + offset_0 += step_0) { + for (int offset_1 = (offset_0 + start_1) * input_shape.Dims(2), + end_1 = (offset_0 + stop_1) * input_shape.Dims(2), + step_1 = params_copy.strides[1] * input_shape.Dims(2); + !LoopCondition(offset_1, end_1, params_copy.strides[1]); + offset_1 += step_1) { + for (int offset_2 = (offset_1 + start_2) * 
input_shape.Dims(3), + end_2 = (offset_1 + stop_2) * input_shape.Dims(3), + step_2 = params_copy.strides[2] * input_shape.Dims(3); + !LoopCondition(offset_2, end_2, params_copy.strides[2]); + offset_2 += step_2) { + for (int offset_3 = (offset_2 + start_3) * input_shape.Dims(4), + end_3 = (offset_2 + stop_3) * input_shape.Dims(4), + step_3 = params_copy.strides[3] * input_shape.Dims(4); + !LoopCondition(offset_3, end_3, params_copy.strides[3]); + offset_3 += step_3) { + for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4; + !LoopCondition(offset_4, end_4, params_copy.strides[4]); + offset_4 += params_copy.strides[4]) { + *out_ptr++ = input_data[offset_4]; + } } } } diff --git a/tensorflow/lite/kernels/internal/strided_slice_logic.h b/tensorflow/lite/kernels/internal/strided_slice_logic.h index 3022ac7b8e9..12dd33d3296 100644 --- a/tensorflow/lite/kernels/internal/strided_slice_logic.h +++ b/tensorflow/lite/kernels/internal/strided_slice_logic.h @@ -35,7 +35,7 @@ inline int Clamp(const int v, const int lo, const int hi) { inline void StridedSlicePadIndices(tflite::StridedSliceParams* p, int dim_count) { // Add indices and mask bits to fully include extra dimensions - TFLITE_CHECK_LE(dim_count, 4); + TFLITE_CHECK_LE(dim_count, 5); TFLITE_CHECK_GE(dim_count, p->start_indices_count); TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count); TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count); diff --git a/tensorflow/lite/kernels/internal/types.h b/tensorflow/lite/kernels/internal/types.h index e96e50209bb..f4d2b5c0dad 100644 --- a/tensorflow/lite/kernels/internal/types.h +++ b/tensorflow/lite/kernels/internal/types.h @@ -128,9 +128,9 @@ struct Dims { class RuntimeShape { public: - // Shapes with dimensions up to 4 are stored directly in the structure, while + // Shapes with dimensions up to 5 are stored directly in the structure, while // larger shapes are separately allocated. - static constexpr int kMaxSmallSize = 4; + static constexpr int kMaxSmallSize = 5; RuntimeShape& operator=(RuntimeShape const&) = delete; @@ -207,8 +207,8 @@ class RuntimeShape { inline const int32* DimsData() const { return size_ > kMaxSmallSize ? dims_pointer_ : dims_; } - // The caller must ensure that the shape is no bigger than 4-D. - inline const int32* DimsDataUpTo4D() const { return dims_; } + // The caller must ensure that the shape is no bigger than 5-D. 
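The rewritten loop above folds each axis's start index into a running flat offset, so the innermost loop reads `input_data[offset_4]` directly instead of recomputing `Offset()` per element. A minimal NumPy sketch of the same indexing scheme, assuming unit strides and the 2x2x2x1x2 input used by the new 5-D unit tests later in this change (names here are illustrative only, not part of the patch):

```python
# Cross-check of the flattened-offset indexing used by the 5-D StridedSlice
# loop, with unit strides. offN mirrors the offset_N variables in the kernel.
import numpy as np

shape = (2, 2, 2, 1, 2)
x = np.arange(1, 1 + np.prod(shape)).reshape(shape)
begin, end = (0, 0, 0, 0, 0), (2, 1, 2, 1, 2)

flat, out = x.reshape(-1), []
for i0 in range(begin[0], end[0]):
    off0 = i0 * shape[1]
    for i1 in range(begin[1], end[1]):
        off1 = (off0 + i1) * shape[2]
        for i2 in range(begin[2], end[2]):
            off2 = (off1 + i2) * shape[3]
            for i3 in range(begin[3], end[3]):
                off3 = (off2 + i3) * shape[4]
                for i4 in range(begin[4], end[4]):
                    out.append(int(flat[off3 + i4]))

# Same elements, same order as plain slicing: [1, 2, 3, 4, 9, 10, 11, 12].
assert out == x[0:2, 0:1, 0:2, 0:1, 0:2].reshape(-1).tolist()
```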
+ inline const int32* DimsDataUpTo5D() const { return dims_; } inline void Resize(int dimensions_count) { if (size_ > kMaxSmallSize) { @@ -378,7 +378,7 @@ inline size_t ReducedOutputOffset(const int num_dims, const int* dims, inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) { TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4); - const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo4D()); + const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo5D()); TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]); TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]); TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]); @@ -1049,11 +1049,11 @@ struct SqueezeParams { struct StridedSliceParams { int8 start_indices_count; - int32 start_indices[4]; + int32 start_indices[5]; int8 stop_indices_count; - int32 stop_indices[4]; + int32 stop_indices[5]; int8 strides_count; - int32 strides[4]; + int32 strides[5]; int16 begin_mask; int16 ellipsis_mask; diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index e8eebd81025..51534375e5f 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -156,7 +156,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE()); AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE(), /* min_version */ 1, - /* max_version */ 3); + /* max_version */ 4); AddBuiltin(BuiltinOperator_EXP, Register_EXP()); AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2(), /* min_version */ 1, diff --git a/tensorflow/lite/kernels/strided_slice.cc b/tensorflow/lite/kernels/strided_slice.cc index ba39b016624..e2ca812d193 100644 --- a/tensorflow/lite/kernels/strided_slice.cc +++ b/tensorflow/lite/kernels/strided_slice.cc @@ -142,8 +142,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32); TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32); TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32); - TF_LITE_ENSURE_MSG(context, op_context.dims <= 4, - "StridedSlice op only supports 1D-4D input arrays."); + TF_LITE_ENSURE_MSG(context, op_context.dims <= 5, + "StridedSlice op only supports 1D-5D input arrays."); // TODO(b/138098220): Remove when bug is resolved. 
// Currently, working on using the compiler to cannonize strided_slice, diff --git a/tensorflow/lite/kernels/strided_slice_test.cc b/tensorflow/lite/kernels/strided_slice_test.cc index 83093a09eed..8db98dba0e9 100644 --- a/tensorflow/lite/kernels/strided_slice_test.cc +++ b/tensorflow/lite/kernels/strided_slice_test.cc @@ -82,9 +82,9 @@ TYPED_TEST_SUITE(StridedSliceOpTest, DataTypes); #ifdef GTEST_HAS_DEATH_TEST TYPED_TEST(StridedSliceOpTest, UnsupportedInputSize) { - EXPECT_DEATH(StridedSliceOpModel<TypeParam>({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0, - 0, 0, 0, 0), - "StridedSlice op only supports 1D-4D input arrays."); + EXPECT_DEATH(StridedSliceOpModel<TypeParam>({2, 2, 2, 2, 2, 2}, {5}, {5}, {5}, + 0, 0, 0, 0, 0), + "StridedSlice op only supports 1D-5D input arrays."); } TYPED_TEST(StridedSliceOpTest, UnssupportedArgs) { @@ -612,5 +612,29 @@ TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1int8) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } + +TYPED_TEST(StridedSliceOpTest, In5D_Identity) { + StridedSliceOpModel<TypeParam> m({2, 2, 2, 1, 2}, {5}, {5}, {5}, 0, 0, 0, 0, + 0); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.SetBegin({0, 0, 0, 0, 0}); + m.SetEnd({2, 1, 2, 1, 2}); + m.SetStrides({1, 1, 1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2, 1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 9, 10, 11, 12})); +} + +TYPED_TEST(StridedSliceOpTest, In5D_IdentityShrinkAxis1) { + StridedSliceOpModel<TypeParam> m({2, 2, 2, 1, 2}, {5}, {5}, {5}, 0, 0, 0, 0, + 1); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.SetBegin({0, 0, 0, 0, 0}); + m.SetEnd({2, 1, 2, 1, 2}); + m.SetStrides({1, 1, 1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4})); +} } // namespace } // namespace tflite diff --git a/tensorflow/lite/micro/examples/micro_speech/esp/Makefile.inc b/tensorflow/lite/micro/examples/micro_speech/esp/Makefile.inc index f03f199d215..a658c2fad94 100644 --- a/tensorflow/lite/micro/examples/micro_speech/esp/Makefile.inc +++ b/tensorflow/lite/micro/examples/micro_speech/esp/Makefile.inc @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-ifeq ($(TARGET), "esp") +ifeq ($(TARGET), esp) # Adding some esp specific files in the main CMakeLists.txt ESP_MICRO_SPEECH_SRCS := \ @@ -26,34 +26,6 @@ CXXFLAGS += -Wno-return-type -Wno-strict-aliasing -Wno-ignored-qualifiers MICRO_SPEECH_SRCS += $(ESP_MICRO_SPEECH_SRCS) MICRO_SPEECH_HDRS += $(ESP_MICRO_SPEECH_HDRS) MAIN_SRCS += $(ESP_MICRO_SPEECH_SRCS) -# Adding the microfrontend lib in the CMakeLists.txt of tfmicro -PROJECT_INCLUDES += \ -tensorflow/lite/experimental/microfrontend/lib - -MICRO_FRONTEND_LIB_SRCS := \ -tensorflow/lite/experimental/microfrontend/lib/fft.cc \ -tensorflow/lite/experimental/microfrontend/lib/fft_util.cc \ -tensorflow/lite/experimental/microfrontend/lib/filterbank.c \ -tensorflow/lite/experimental/microfrontend/lib/filterbank_util.c \ -tensorflow/lite/experimental/microfrontend/lib/frontend.c \ -tensorflow/lite/experimental/microfrontend/lib/frontend_util.c \ -tensorflow/lite/experimental/microfrontend/lib/log_lut.c \ -tensorflow/lite/experimental/microfrontend/lib/log_scale.c \ -tensorflow/lite/experimental/microfrontend/lib/log_scale_util.c \ -tensorflow/lite/experimental/microfrontend/lib/noise_reduction.c \ -tensorflow/lite/experimental/microfrontend/lib/noise_reduction_util.c \ -tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control.c \ -tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_util.c \ -tensorflow/lite/experimental/microfrontend/lib/window.c \ -tensorflow/lite/experimental/microfrontend/lib/window_util.c - -# Adding the micro frontend lib srcs into the CMakeLists.txt of tfmicro -MICROLITE_CC_SRCS += $(MICRO_FRONTEND_LIB_SRCS) -# Adding the kissfft srcs in the CMakeLists.txt of tfmicro -THIRD_PARTY_CC_SRCS += $(KISSFFT_LIB_SRCS) -# stopping microfrontend srcs from being included in the main srcs -MICRO_SPEECH_SRCS := $(filter-out $(MICRO_FRONTEND_LIB_SRCS), $(MICRO_SPEECH_SRCS)) -MICRO_SPEECH_SRCS := $(filter-out $(KISSFFT_LIB_SRCS), $(MICRO_SPEECH_SRCS)) ESP_PROJECT_FILES += \ sdkconfig.defaults diff --git a/tensorflow/lite/micro/examples/micro_speech/esp/audio_provider.cc b/tensorflow/lite/micro/examples/micro_speech/esp/audio_provider.cc index 5c23381c3cc..3596246d1e3 100644 --- a/tensorflow/lite/micro/examples/micro_speech/esp/audio_provider.cc +++ b/tensorflow/lite/micro/examples/micro_speech/esp/audio_provider.cc @@ -18,12 +18,17 @@ limitations under the License. #include <cstdlib> #include <cstring> +// FreeRTOS.h must be included before some of the following dependencies. +// Solves b/150260343. 
+// clang-format off +#include "freertos/FreeRTOS.h" +// clang-format on + #include "driver/i2s.h" #include "esp_log.h" #include "esp_spi_flash.h" #include "esp_system.h" #include "esp_timer.h" -#include "freertos/FreeRTOS.h" #include "freertos/task.h" #include "ringbuf.h" #include "tensorflow/lite/micro/examples/micro_speech/micro_features/micro_model_settings.h" diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index 1dc45f88cb9..5546dfc95ab 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -1,4 +1,3 @@ - ifneq (3.82,$(firstword $(sort $(MAKE_VERSION) 3.82))) $(error "Requires make version 3.82 or later (current is $(MAKE_VERSION))") endif @@ -65,18 +64,24 @@ TEST_SCRIPT := tensorflow/lite/micro/testing/test_linux_binary.sh MICROLITE_LIBS := -lm -# There are no rules for compiling objects for the host system (since we don't -# generate things like the protobuf compiler that require that), so all of -# these settings are for the target compiler. -CXXFLAGS := -O3 -CXXFLAGS += -std=c++11 -g -DTF_LITE_STATIC_MEMORY -CXXFLAGS += -fno-rtti -CCFLAGS := -g -DTF_LITE_STATIC_MEMORY -LDOPTS := -L/usr/local/lib +# TODO(b/150240249): Add in -fno-rtti once that works for the Xtensa toolchain. +CXXFLAGS := -std=c++11 -DTF_LITE_STATIC_MEMORY +CCFLAGS := -std=c11 -DTF_LITE_STATIC_MEMORY ARFLAGS := -r TARGET_TOOLCHAIN_PREFIX := CC_PREFIX := +ifeq ($(BUILD_TYPE), debug) + CXXFLAGS += -DDEBUG -g + CCFLAGS += -DDEBUG -g +else ifeq ($(BUILD_TYPE), release) + CXXFLAGS += -DNDEBUG -O3 -DTF_LITE_STRIP_ERROR_STRINGS + CCFLAGS += -DNDEBUG -O3 -DTF_LITE_STRIP_ERROR_STRINGS +else + CXXFLAGS += -O3 + CCFLAGS += -O3 +endif + # This library is the main target for this makefile. It will contain a minimal # runtime that can be linked in to other programs. MICROLITE_LIB_NAME := libtensorflow-microlite.a diff --git a/tensorflow/lite/micro/tools/make/helper_functions.inc b/tensorflow/lite/micro/tools/make/helper_functions.inc index be6d943f4e9..40c8c1c0aeb 100644 --- a/tensorflow/lite/micro/tools/make/helper_functions.inc +++ b/tensorflow/lite/micro/tools/make/helper_functions.inc @@ -326,10 +326,15 @@ $(PRJDIR)$(2)/esp-idf/sdkconfig.defaults: tensorflow/lite/micro/examples/$(2)/es @cp $$< $$@ $(PRJDIR)$(2)/esp-idf/%: tensorflow/lite/micro/tools/make/templates/esp/%.tpl - $(eval MAIN_SRCS_RELATIVE := $(patsubst tensorflow/lite/micro/examples/$(2)/%,%,$(5))) + # Split the sources into 2 components: + # - Main component contains only the example's sources, relative from its dir. + $(eval MAIN_SRCS := $(filter tensorflow/lite/micro/examples/%,$(5))) + $(eval MAIN_SRCS_RELATIVE := $(patsubst tensorflow/lite/micro/examples/$(2)/%,%,$(MAIN_SRCS))) + # - TFL Micro component contains everything but the example sources. 
+ $(eval TFLM_SRCS := $(filter-out tensorflow/lite/micro/examples/%,$(5)) $(3)) @mkdir -p $$(dir $$@) - @sed -E 's#\%\{COMPONENT_SRCS\}\%#$(3)#g' $$< | \ + @sed -E 's#\%\{COMPONENT_SRCS\}\%#$(TFLM_SRCS)#g' $$< | \ sed -E 's#\%\{MAIN_SRCS\}\%#$(MAIN_SRCS_RELATIVE)#g' | \ sed -E 's#\%\{EXECUTABLE\}\%#$(2)#g' | \ sed -E 's#\%\{COMPONENT_INCLUDES\}\%#$(10)#g' | \ diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc index 0ccad72692d..5836aea417d 100644 --- a/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc @@ -3,17 +3,14 @@ # - RI2019.2 Toolkit (for xt-clang/xt-clang++). # - XTENSA_CORE: The name of the core to use, will cause a compiler exception # without providing a core. + ifeq ($(TARGET), xtensa-xpg) TARGET_ARCH := xtensa-xpg PLATFORM_ARGS = \ - -DTF_LITE_STATIC_MEMORY \ - -DTF_LITE_STRIP_ERROR_STRINGS \ - -DNDEBUG \ -DTF_LITE_MCU_DEBUG_LOG \ --xtensa-core=$(XTENSA_CORE) \ -mcoproc \ - -O3 \ -DXTENSA -DMAX_RFFT_PWR=9 -DMIN_RFFT_PWR=MAX_RFFT_PWR \ -fdata-sections \ -ffunction-sections \ @@ -25,18 +22,10 @@ ifeq ($(TARGET), xtensa-xpg) CXX_TOOL := clang++ CC_TOOL := clang - CXXFLAGS = $(PLATFORM_ARGS) -std=c++11 - CCFLAGS = $(PLATFORM_ARGS) -std=c11 + CXXFLAGS += $(PLATFORM_ARGS) + CCFLAGS += $(PLATFORM_ARGS) LDFLAGS += -Wl,-gc-sections TEST_SCRIPT := tensorflow/lite/micro/testing/test_xtensa_xpg_binary.sh - - # These are microcontroller-specific rules for converting the ELF output - # of the linker into a binary image that can be loaded directly. - OBJCOPY := $(TARGET_TOOLCHAIN_PREFIX)objcopy - - $(BINDIR)/%.bin: $(BINDIR)/% - @mkdir -p $(dir $@) - $(OBJCOPY) $< $@ -O binary endif diff --git a/tensorflow/lite/micro/xtensa-xpg/debug_log.cc b/tensorflow/lite/micro/xtensa-xpg/debug_log.cc index a95a084978b..ca8b251c069 100644 --- a/tensorflow/lite/micro/xtensa-xpg/debug_log.cc +++ b/tensorflow/lite/micro/xtensa-xpg/debug_log.cc @@ -39,7 +39,10 @@ limitations under the License. #include <cstdio> extern "C" void DebugLog(const char* s) { -#ifndef NDEBUG +#ifndef TF_LITE_STRIP_ERROR_STRINGS + // Reusing TF_LITE_STRIP_ERROR_STRINGS to disable DebugLog completely to get + // maximum reduction in binary size. This is because we have DebugLog calls + // via TF_LITE_CHECK that are not stubbed out by TF_LITE_REPORT_ERROR. fprintf(stderr, "%s", s); #endif } diff --git a/tensorflow/lite/testing/op_tests/strided_slice_np_style.py b/tensorflow/lite/testing/op_tests/strided_slice_np_style.py index 95f7acabdf7..45f2e4b867a 100644 --- a/tensorflow/lite/testing/op_tests/strided_slice_np_style.py +++ b/tensorflow/lite/testing/op_tests/strided_slice_np_style.py @@ -68,6 +68,18 @@ def make_strided_slice_np_style_tests(options): [slice(1, 11, 3), Ellipsis, slice(3, 7, 2)]], }, + # Ellipsis 5d. + { + "dtype": [tf.float32], + "shape": [[11, 21, 15, 7, 9]], + "spec": [[ + slice(3, 7, 2), + slice(None), + slice(None), + slice(None), + slice(None) + ], [Ellipsis, slice(3, 7, 2)]], + }, # All combinations. 
{ "dtype": [tf.float32], diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index 3238d8ef032..0c310d15020 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -1165,6 +1165,15 @@ class StridedSlice op->new_axis_mask = options.new_axis_mask(); op->shrink_axis_mask = options.shrink_axis_mask(); } + + int GetVersion(const OperatorSignature& op_signature) const override { + const auto& ss_op = + static_cast<const StridedSliceOperator&>(*op_signature.op); + ::tflite::OpSignature op_sig = + GetVersioningOpSig(builtin_op(), op_signature); + op_sig.options.strided_slice.num_dims = ss_op.start_indices.size(); + return ::tflite::GetBuiltinOperatorVersion(op_sig); + } }; class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options, diff --git a/tensorflow/lite/tools/benchmark/README.md b/tensorflow/lite/tools/benchmark/README.md index 89a64ab2751..286ddf69cab 100644 --- a/tensorflow/lite/tools/benchmark/README.md +++ b/tensorflow/lite/tools/benchmark/README.md @@ -65,6 +65,14 @@ and the following optional parameters: This is available on recent Android devices. Note that some Android P devices will fail to use NNAPI for models in `/data/local/tmp/` and this benchmark tool will not correctly use NNAPI. +* `max_delegated_partitions`: `int` (default=0, i.e. no limit) \ + The maximum number of partitions that will be delegated. \ + Currently supported only by the NNAPI Delegate and it won't work \ + if `use_legacy_nnapi` has been selected. +* `disable_nnapi_cpu`: `bool` (default=false) \ + Excludes the [NNAPI CPU reference implementation](https://developer.android.com/ndk/guides/neuralnetworks#device-assignment) + from the possible devices to be used by NNAPI to execute the model. + This option is ignored if `nnapi_accelerator_name` is specified. * `use_gpu`: `bool` (default=false) \ Whether to use the [GPU accelerator delegate](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/gpu). This option is currently only available on Android and iOS devices. 
diff --git a/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc b/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc index eccd389af6f..46620ae3372 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc @@ -240,6 +240,8 @@ void BenchmarkPerformanceOptions::ResetPerformanceOptions() { single_option_run_params_->Set<bool>("gpu_precision_loss_allowed", true); single_option_run_params_->Set<bool>("use_nnapi", false); single_option_run_params_->Set<std::string>("nnapi_accelerator_name", ""); + single_option_run_params_->Set<bool>("disable_nnapi_cpu", false); + single_option_run_params_->Set<int>("max_delegated_partitions", 0); #endif #if defined(TFLITE_ENABLE_HEXAGON) single_option_run_params_->Set<bool>("use_hexagon", false); @@ -305,6 +307,10 @@ void BenchmarkPerformanceOptions::CreatePerformanceOptions() { params.AddParam("use_nnapi", BenchmarkParam::Create<bool>(true)); params.AddParam("nnapi_accelerator_name", BenchmarkParam::Create<std::string>(name)); + params.AddParam("disable_nnapi_cpu", + BenchmarkParam::Create<bool>(false)); + params.AddParam("max_delegated_partitions", + BenchmarkParam::Create<int>(0)); all_run_params_.emplace_back(std::move(params)); } } diff --git a/tensorflow/lite/tools/benchmark/benchmark_test.cc b/tensorflow/lite/tools/benchmark/benchmark_test.cc index 81b27c092ec..0d49fc8baec 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_test.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_test.cc @@ -77,6 +77,8 @@ BenchmarkParams CreateParams(int32_t num_runs, float min_secs, float max_secs, BenchmarkParam::Create<std::string>("")); params.AddParam("nnapi_execution_preference", BenchmarkParam::Create<std::string>("")); + params.AddParam("disable_nnapi_cpu", BenchmarkParam::Create<bool>(false)); + params.AddParam("max_delegated_partitions", BenchmarkParam::Create<int>(0)); params.AddParam("profiling_output_csv_file", BenchmarkParam::Create<std::string>("")); return params; diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc index 6b1e9819312..d563724efe4 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc @@ -215,6 +215,8 @@ BenchmarkParams BenchmarkTfLiteModel::DefaultParams() { BenchmarkParam::Create<int32_t>(1024)); default_params.AddParam("profiling_output_csv_file", BenchmarkParam::Create<std::string>("")); + default_params.AddParam("max_delegated_partitions", + BenchmarkParam::Create<int32_t>(0)); for (const auto& delegate_util : GetRegisteredDelegateProviders()) { delegate_util->AddParams(&default_params); @@ -257,7 +259,9 @@ std::vector<Flag> BenchmarkTfLiteModel::GetFlags() { CreateFlag<std::string>( "profiling_output_csv_file", &params_, "File path to export profile data as CSV, if not set " - "prints to stdout.")}; + "prints to stdout."), + CreateFlag<int>("max_delegated_partitions", &params_, + "Max partitions to be delegated.")}; flags.insert(flags.end(), specific_flags.begin(), specific_flags.end()); @@ -295,6 +299,8 @@ void BenchmarkTfLiteModel::LogParams() { TFLITE_LOG(INFO) << "CSV File to export profiling data to: [" << params_.Get<std::string>("profiling_output_csv_file") << "]"; + TFLITE_LOG(INFO) << "Max number of delegated partitions : [" + << params_.Get<int32_t>("max_delegated_partitions") << "]"; for (const auto& delegate_util : GetRegisteredDelegateProviders()) {
delegate_util->LogParams(params_); diff --git a/tensorflow/lite/tools/benchmark/nnapi_delegate_provider.cc b/tensorflow/lite/tools/benchmark/nnapi_delegate_provider.cc index 4ac50b9771f..3f87de863e7 100644 --- a/tensorflow/lite/tools/benchmark/nnapi_delegate_provider.cc +++ b/tensorflow/lite/tools/benchmark/nnapi_delegate_provider.cc @@ -50,7 +50,10 @@ std::vector<Flag> NnapiDelegateProvider::CreateFlags( "sustained_speed, low_power, undefined"), CreateFlag<std::string>( "nnapi_accelerator_name", params, - "the name of the nnapi accelerator to use (requires Android Q+)")}; + "the name of the nnapi accelerator to use (requires Android Q+)"), + CreateFlag<bool>("disable_nnapi_cpu", params, + "Disable the NNAPI CPU device")}; + return flags; } @@ -60,17 +63,18 @@ void NnapiDelegateProvider::AddParams(BenchmarkParams* params) const { BenchmarkParam::Create<std::string>("")); params->AddParam("nnapi_accelerator_name", BenchmarkParam::Create<std::string>("")); + params->AddParam("disable_nnapi_cpu", BenchmarkParam::Create<bool>(false)); } void NnapiDelegateProvider::LogParams(const BenchmarkParams& params) const { #if defined(__ANDROID__) TFLITE_LOG(INFO) << "Use nnapi : [" << params.Get<bool>("use_nnapi") << "]"; - if (!params.Get<std::string>("nnapi_execution_preference").empty()) { - TFLITE_LOG(INFO) << "nnapi execution preference: [" - << params.Get<std::string>("nnapi_execution_preference") - << "]"; - } if (params.Get<bool>("use_nnapi")) { + if (!params.Get<std::string>("nnapi_execution_preference").empty()) { + TFLITE_LOG(INFO) << "nnapi execution preference: [" + << params.Get<std::string>("nnapi_execution_preference") + << "]"; + } std::string log_string = "nnapi accelerator name: [" + params.Get<std::string>("nnapi_accelerator_name") + "]"; @@ -80,6 +84,10 @@ void NnapiDelegateProvider::LogParams(const BenchmarkParams& params) const { log_string += " (Available: " + string_device_names_list + ")"; } TFLITE_LOG(INFO) << log_string; + if (params.Get<bool>("disable_nnapi_cpu")) { + TFLITE_LOG(INFO) << "disable_nnapi_cpu: [" + << params.Get<bool>("disable_nnapi_cpu") << "]"; + } } #endif } @@ -93,6 +101,8 @@ TfLiteDelegatePtr NnapiDelegateProvider::CreateTfLiteDelegate( params.Get<std::string>("nnapi_accelerator_name"); if (!accelerator_name.empty()) { options.accelerator_name = accelerator_name.c_str(); + } else if (params.Get<bool>("disable_nnapi_cpu")) { + options.disallow_nnapi_cpu = true; } std::string string_execution_preference = params.Get<std::string>("nnapi_execution_preference"); @@ -121,6 +131,10 @@ TfLiteDelegatePtr NnapiDelegateProvider::CreateTfLiteDelegate( } options.execution_preference = execution_preference; } + int max_delegated_partitions = params.Get<int>("max_delegated_partitions"); + if (max_delegated_partitions > 0) { + options.max_number_delegated_partitions = max_delegated_partitions; + } delegate = evaluation::CreateNNAPIDelegate(options); if (!delegate.get()) { TFLITE_LOG(WARN) << "NNAPI acceleration is unsupported on this platform."; diff --git a/tensorflow/lite/tools/make/Makefile b/tensorflow/lite/tools/make/Makefile index 1597b8e7f2e..b78fb14b785 100644 --- a/tensorflow/lite/tools/make/Makefile +++ b/tensorflow/lite/tools/make/Makefile @@ -56,9 +56,8 @@ LIBS := \ # There are no rules for compiling objects for the host system (since we don't # generate things like the protobuf compiler that require that), so all of # these settings are for the target compiler. 
-CXXFLAGS := -O3 -DNDEBUG -fPIC -CXXFLAGS += $(EXTRA_CXXFLAGS) -CFLAGS := ${CXXFLAGS} +CFLAGS := -O3 -DNDEBUG -fPIC +CXXFLAGS := $(CFLAGS) --std=c++11 $(EXTRA_CXXFLAGS) LDOPTS := -L/usr/local/lib ARFLAGS := -r TARGET_TOOLCHAIN_PREFIX := @@ -68,10 +67,6 @@ ifeq ($(HOST_OS),windows) CXXFLAGS += -fext-numeric-literals -D__LITTLE_ENDIAN__ endif -ifeq ($(TARGET),ios) -CXXFLAGS += --std=c++11 -endif - # Auto-detect optimization opportunity if building natively. ifeq ($(HOST_OS),$(TARGET)) ifeq ($(HOST_ARCH),$(TARGET_ARCH)) diff --git a/tensorflow/lite/tools/pip_package/Makefile b/tensorflow/lite/tools/pip_package/Makefile index eaca6e131b3..ff18303284c 100644 --- a/tensorflow/lite/tools/pip_package/Makefile +++ b/tensorflow/lite/tools/pip_package/Makefile @@ -46,9 +46,10 @@ docker-build: docker-image --volume $(TENSORFLOW_DIR):/tensorflow \ --volume $(OUT_DIR):/out \ $(TAG_IMAGE) \ - /bin/bash -c "tensorflow/tensorflow/lite/tools/pip_package/build_pip_package.sh && \ - (cp ${MAKEFILE_DIR}/gen/tflite_pip/*.deb ${MAKEFILE_DIR}/gen/tflite_pip/python3/dist/{*.whl,*.tar.gz} /out 2>/dev/null || true)" + /bin/bash -c "/tensorflow/tensorflow/lite/tools/pip_package/build_pip_package.sh && \ + (cp /tensorflow/tensorflow/lite/tools/pip_package/gen/tflite_pip/*.deb \ + /tensorflow/tensorflow/lite/tools/pip_package/gen/tflite_pip/${PYTHON}/dist/{*.whl,*.tar.gz} \ + /out 2>/dev/null || true)" clean: rm -rf $(CURDIR)/out - diff --git a/tensorflow/lite/tools/pip_package/build_pip_package.sh b/tensorflow/lite/tools/pip_package/build_pip_package.sh index 925c6142be0..5ba2cf954f6 100755 --- a/tensorflow/lite/tools/pip_package/build_pip_package.sh +++ b/tensorflow/lite/tools/pip_package/build_pip_package.sh @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -set -e -set -x +set -ex SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" PYTHON="${PYTHON:-python3}" @@ -23,7 +22,7 @@ export TENSORFLOW_DIR="${SCRIPT_DIR}/../../../.." TENSORFLOW_LITE_DIR="${TENSORFLOW_DIR}/tensorflow/lite" TENSORFLOW_VERSION=$(grep "_VERSION = " "${TENSORFLOW_DIR}/tensorflow/tools/pip_package/setup.py" | cut -d= -f2 | sed "s/[ '-]//g") export PACKAGE_VERSION="${TENSORFLOW_VERSION}${VERSION_SUFFIX}" -BUILD_DIR="${SCRIPT_DIR}/gen/tflite_pip/python3" +BUILD_DIR="${SCRIPT_DIR}/gen/tflite_pip/${PYTHON}" # Build source tree. 
rm -rf "${BUILD_DIR}" && mkdir -p "${BUILD_DIR}/tflite_runtime" diff --git a/tensorflow/lite/tools/pip_package/setup.py b/tensorflow/lite/tools/pip_package/setup.py index 19c9993e5fa..f99a5b043dc 100644 --- a/tensorflow/lite/tools/pip_package/setup.py +++ b/tensorflow/lite/tools/pip_package/setup.py @@ -39,47 +39,30 @@ PACKAGE_NAME = 'tflite_runtime' PACKAGE_VERSION = os.environ['PACKAGE_VERSION'] DOCLINES = __doc__.split('\n') TENSORFLOW_DIR = os.environ['TENSORFLOW_DIR'] - -# Setup cross compiling -TARGET = os.environ.get('TENSORFLOW_TARGET', None) -if TARGET == 'rpi': - os.environ['CXX'] = 'arm-linux-gnueabihf-g++' - os.environ['CC'] = 'arm-linux-gnueabihf-gcc' -elif TARGET == 'aarch64': - os.environ['CXX'] = 'aarch64-linux-gnu-g++' - os.environ['CC'] = 'aarch64-linux-gnu-gcc' -MAKE_CROSS_OPTIONS = ['TARGET=%s' % TARGET] if TARGET else [] - -TARGET_ARCH = ( - os.environ['TENSORFLOW_TARGET_ARCH'] \ - if 'TENSORFLOW_TARGET_ARCH' in os.environ - else None) -MAKE_CROSS_OPTIONS += ['TARGET_ARCH=%s' % TARGET_ARCH] \ - if TARGET_ARCH else [] - -CC_PREFIX = ( - os.environ['TENSORFLOW_CC_PREFIX'] \ - if 'TENSORFLOW_CC_PREFIX' in os.environ - else None) -MAKE_CROSS_OPTIONS += ['CC_PREFIX=%s' % CC_PREFIX] \ - if CC_PREFIX else [] - -EXTRA_CXXFLAGS = ( - os.environ['TENSORFLOW_EXTRA_CXXFLAGS'] \ - if 'TENSORFLOW_EXTRA_CXXFLAGS' in os.environ - else None) -MAKE_CROSS_OPTIONS += ['EXTRA_CXXFLAGS=%s' % EXTRA_CXXFLAGS] \ - if EXTRA_CXXFLAGS else [] - RELATIVE_MAKE_DIR = os.path.join('tensorflow', 'lite', 'tools', 'make') MAKE_DIR = os.path.join(TENSORFLOW_DIR, RELATIVE_MAKE_DIR) DOWNLOADS_DIR = os.path.join(MAKE_DIR, 'downloads') RELATIVE_MAKEFILE_PATH = os.path.join(RELATIVE_MAKE_DIR, 'Makefile') DOWNLOAD_SCRIPT_PATH = os.path.join(MAKE_DIR, 'download_dependencies.sh') +# Setup cross compiling +TARGET = os.environ.get('TENSORFLOW_TARGET') +if TARGET == 'rpi': + os.environ['CXX'] = 'arm-linux-gnueabihf-g++' + os.environ['CC'] = 'arm-linux-gnueabihf-gcc' +elif TARGET == 'aarch64': + os.environ['CXX'] = 'aarch64-linux-gnu-g++' + os.environ['CC'] = 'aarch64-linux-gnu-gcc' + +MAKE_CROSS_OPTIONS = [] +for name in ['TARGET', 'TARGET_ARCH', 'CC_PREFIX', 'EXTRA_CXXFLAGS']: + value = os.environ.get('TENSORFLOW_%s' % name) + if value: + MAKE_CROSS_OPTIONS.append('%s=%s' % (name, value)) + # Check physical memory and if we are on a reasonable non small SOC machine -# with more than 4GB, use all the CPUs, otherwisxe only 1. +# with more than 4GB, use all the CPUs, otherwise only 1. def get_build_cpus(): physical_bytes = os.sysconf('SC_PAGESIZE') * os.sysconf('SC_PHYS_PAGES') if physical_bytes < (1<<30) * 4: @@ -156,6 +139,7 @@ ext = Extension( 'interpreter_wrapper/numpy.cc', 'interpreter_wrapper/python_error_reporter.cc', 'interpreter_wrapper/python_utils.cc'], + extra_compile_args=['--std=c++11'], swig_opts=['-c++', '-I%s' % TENSORFLOW_DIR, '-module', 'interpreter_wrapper', diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index b699f0dbc9b..b3ef46503b3 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -266,6 +266,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { } return 1; case BuiltinOperator_STRIDED_SLICE: + if (op_sig.options.strided_slice.num_dims > 4) { + return 4; + } // If the op takes bool input, it is version 3. 
if (op_sig.input_types.at(0) == TensorType_BOOL) { return 3; @@ -431,6 +434,11 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, resize_bilinear_option->half_pixel_centers(); } } break; + // TODO(b/150176627): Add tests for GetOpSignature. + case BuiltinOperator_STRIDED_SLICE: { + op_sig.options.strided_slice.num_dims = + subgraph->tensors()->Get(op->inputs()->Get(0))->shape()->size(); + } break; default: break; diff --git a/tensorflow/lite/tools/versioning/op_version.h b/tensorflow/lite/tools/versioning/op_version.h index 7fbc5a056e5..364d1a299cc 100644 --- a/tensorflow/lite/tools/versioning/op_version.h +++ b/tensorflow/lite/tools/versioning/op_version.h @@ -49,6 +49,9 @@ typedef struct { struct { bool half_pixel_centers; } resize_bilinear; + struct { + int32_t num_dims; + } strided_slice; } options; } OpSignature; diff --git a/tensorflow/lite/tools/visualize.py b/tensorflow/lite/tools/visualize.py index b78695be5a5..1f89f9c5448 100644 --- a/tensorflow/lite/tools/visualize.py +++ b/tensorflow/lite/tools/visualize.py @@ -265,7 +265,9 @@ class TensorMapper(object): html += str(i) + " " html += NameListToString(tensor["name"]) + " " html += TensorTypeToName(tensor["type"]) + " " - html += (repr(tensor["shape"]) if "shape" in tensor else "[]") + "<br>" + html += (repr(tensor["shape"]) if "shape" in tensor else "[]") + html += (repr(tensor["shape_signature"]) + if "shape_signature" in tensor else "[]") + "<br>" html += "</span>" html += repr(x) html += "</span>" @@ -447,9 +449,9 @@ def CreateHtmlFile(tflite_input, html_output): ("builtin_options", None), ("opcode_index", opcode_mapper)] tensor_keys_to_display = [("name", NameListToString), - ("type", TensorTypeToName), - ("shape", None), - ("buffer", None), ("quantization", None)] + ("type", TensorTypeToName), ("shape", None), + ("shape_signature", None), ("buffer", None), + ("quantization", None)] html += "<h2>Subgraph %d</h2>\n" % subgraph_idx diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 3b99884c603..6f7ec6389a3 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -623,9 +623,10 @@ tf_python_pybind_extension( "@pybind11", "//third_party/python_runtime:headers", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework", "//tensorflow/core:core_cpu_headers_lib", - "//tensorflow/core:lib_headers_for_pybind", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "@com_google_absl//absl/types:optional", ] + if_static( extra_deps = [ @@ -1642,6 +1643,7 @@ py_library( srcs = ["framework/auto_control_deps.py"], srcs_version = "PY2AND3", deps = [ + ":auto_control_deps_utils", ":control_flow_ops", ":framework_ops", ":sparse_tensor", @@ -1650,6 +1652,15 @@ py_library( ], ) +py_library( + name = "auto_control_deps_utils", + srcs = ["framework/auto_control_deps_utils.py"], + srcs_version = "PY2AND3", + deps = [ + ":dtypes", + ], +) + tf_py_test( name = "auto_control_deps_test", size = "small", @@ -2155,6 +2166,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":array_ops", + ":auto_control_deps_utils", ":constant_op", ":control_flow_ops", ":framework_ops", @@ -2882,6 +2894,7 @@ tf_gen_op_wrapper_private_py( name = "resource_variable_ops_gen", visibility = [ "//tensorflow/compiler/tf2xla:internal", + "//tensorflow/python/distribute:__pkg__", ], ) @@ -3352,6 +3365,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":array_ops", + ":auto_control_deps_utils", ":c_api_util", ":control_flow_util_v2", ":framework_ops", @@ 
-3378,6 +3392,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":array_ops", + ":auto_control_deps_utils", ":constant_op", ":control_flow_ops", ":control_flow_util", @@ -3983,6 +3998,7 @@ py_library( deps = [ ":array_ops", ":array_ops_gen", + ":auto_control_deps_utils", ":dtypes", ":framework_ops", ":pywrap_tf_session", @@ -4179,6 +4195,7 @@ cuda_py_test( srcs = ["ops/stateful_random_ops_test.py"], python_version = "PY3", xla_enable_strict_auto_jit = False, + xla_enabled = True, deps = [ ":client_testlib", ":config", @@ -7495,7 +7512,7 @@ tf_python_pybind_extension( ":pybind11_status", "@pybind11", "//tensorflow/core:core_cpu_headers_lib", - "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework", "//tensorflow/core:gpu_id", "//tensorflow/core:protos_all_cc", ] + if_not_windows(["//tensorflow/core/grappler/costs:graph_properties"]), # b/148556093, @@ -7631,9 +7648,9 @@ tf_python_pybind_extension( deps = [ ":pybind11_status", "//tensorflow/core:core_cpu_headers_lib", - "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework", "//tensorflow/core:gpu_id", - "//tensorflow/core:lib_headers_for_pybind", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@pybind11", ], diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index e376480d188..5e0355d644f 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 2, 24) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 2, 26) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py index 96b3b764864..535cf884dc6 100644 --- a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py @@ -161,17 +161,49 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) - @combinations.generate(test_base.default_test_combinations()) - def testWriteSnapshotRepeatAfterwards(self): + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(compression=[ + snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, + snapshot.COMPRESSION_SNAPPY + ]))) + def testWriteSnapshotRepeatAfterwards(self, compression): tmpdir = self.snapshot_dir dataset = dataset_ops.Dataset.range(10) - dataset = dataset.apply(snapshot.snapshot(tmpdir)) + dataset = dataset.apply(snapshot.snapshot(tmpdir, compression=compression)) dataset = dataset.repeat(10) self.assertDatasetProduces(dataset, list(range(10)) * 10) self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(compression=[ + snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, + snapshot.COMPRESSION_SNAPPY + ]))) + def testWriteSnapshotMixTypes(self, compression): + tmpdir = self.snapshot_dir + + dataset = dataset_ops.Dataset.range(10) + + 
def map_fn(x): + return (x, string_ops.as_string(x), string_ops.as_string(2 * x), 2 * x) + + dataset = dataset.map(map_fn) + dataset = dataset.apply(snapshot.snapshot(tmpdir, compression=compression)) + dataset = dataset.repeat(10) + + expected = [] + for i in range(10): + expected.append((i, str(i), str(2 * i), 2 * i)) + self.assertDatasetProduces(dataset, expected * 10) + + self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) + @combinations.generate(test_base.default_test_combinations()) def testSpecifySnapshotNameWriteAndRead(self): tmpdir = self.snapshot_dir @@ -365,8 +397,14 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, res3 = self.evaluate(next3()) self.assertEqual(res2, res3) - @combinations.generate(test_base.default_test_combinations()) - def testReadSnapshotParallelAfterWrite(self): + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(compression=[ + snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, + snapshot.COMPRESSION_SNAPPY + ]))) + def testReadSnapshotParallelAfterWrite(self, compression): self.setUpTFRecord(10, 4000) filenames = self.test_filenames @@ -383,7 +421,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, tmpdir, shard_size_bytes=1024 * 1024, num_reader_threads=2, - reader_buffer_size=10)) + reader_buffer_size=10, + compression=compression)) self.assertDatasetProduces(dataset, expected, assert_items_equal=True) # remove the original files and try to read the data back only from @@ -396,7 +435,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, tmpdir, shard_size_bytes=1024 * 1024, num_reader_threads=2, - reader_buffer_size=10)) + reader_buffer_size=10, + compression=compression)) self.assertDatasetProduces(dataset2, expected, assert_items_equal=True) # Not testing Snappy here because Snappy reads currently require a lot of @@ -514,21 +554,31 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, self.evaluate(next2()) self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1) - @combinations.generate(test_base.default_test_combinations()) - def testSpecifyShardSize(self): + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(compression=[ + snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, + snapshot.COMPRESSION_SNAPPY + ]))) + def testSpecifyShardSize(self, compression): tmpdir = self.snapshot_dir dataset = dataset_ops.Dataset.from_tensor_slices([1.0]) dataset = dataset.map(lambda x: gen_array_ops.broadcast_to(x, [1024, 1024])) dataset = dataset.repeat(10) dataset = dataset.apply( - snapshot.snapshot(tmpdir, shard_size_bytes=10 * 1024 * 1024)) + snapshot.snapshot( + tmpdir, shard_size_bytes=10 * 1024 * 1024, compression=compression)) next_fn = self.getNext(dataset) for _ in range(10): self.evaluate(next_fn()) - self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 3) + num_files = 1 + if compression == snapshot.COMPRESSION_NONE: + num_files = 3 + self.assertSnapshotDirectoryContains(tmpdir, 1, 1, num_files) @combinations.generate(test_base.default_test_combinations()) def testAdditionalOperationsAfterReadBack(self): diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 2bd34b195e4..6596bfa8607 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -44,6 +44,7 @@ from tensorflow.python.data.util import traverse from 
tensorflow.python.eager import context from tensorflow.python.eager import function as eager_function from tensorflow.python.framework import auto_control_deps +from tensorflow.python.framework import auto_control_deps_utils as acd_utils from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -4469,43 +4470,67 @@ def _collect_resource_inputs(op): """Collects resource inputs for the given ops (and its variant inputs).""" def _process(op_queue, seen_ops): - """Processes the next element of the op queue.""" + """Processes the next element of the op queue. - result = [] + Args: + op_queue: Queue of Dataset operations to process. + seen_ops: Already processed set of Operations. + + Returns: + A 2-tuple containing sets of resource handles. The first tuple entry + contains read-only handles and the second entry contains read-write + handles. + """ + + reads = [] + writes = [] op = op_queue.pop() if op in seen_ops: - return result + return reads, writes seen_ops.add(op) for t in op.inputs: if t.dtype == dtypes.variant: # Conservatively assume that any variant inputs are datasets. op_queue.append(t.op) elif t.dtype == dtypes.resource: - result.append(t) - return result + # TODO(b/150139257): This always returns True right now since we have + # not updated the functional ops to set the special attribute that ACD + # uses to figure out which of the op's inputs are read-only. + if acd_utils.op_writes_to_resource(t, op): + writes.append(t) + else: + reads.append(t) + return reads, writes op_queue = [op] seen_ops = set() - resource_inputs = [] + all_reads = [] + all_writes = [] while op_queue: - resource_inputs.extend(_process(op_queue, seen_ops)) + reads, writes = _process(op_queue, seen_ops) + all_reads.extend(reads) + all_writes.extend(writes) - return resource_inputs + return all_reads, all_writes @auto_control_deps.register_acd_resource_resolver -def _resource_resolver(op, resource_inputs): +def _resource_resolver(op, resource_reads, resource_writes): """Updates resource inputs for tf.data ops with indirect dependencies.""" updated = False if op.type in [ "DatasetToSingleElement", "DatasetToTFRecord", "ReduceDataset" ]: - indirect_resource_inputs = _collect_resource_inputs(op) - for inp in indirect_resource_inputs: - if inp not in resource_inputs: + reads, writes = _collect_resource_inputs(op) + for inp in reads: + if inp not in resource_reads: updated = True - resource_inputs.add(inp) + resource_reads.add(inp) + for inp in writes: + if inp not in resource_writes: + updated = True + resource_writes.add(inp) if op.type in [ "IteratorGetNext", "IteratorGetNextSync", "IteratorGetNextAsOptional" @@ -4516,10 +4541,14 @@ def _resource_resolver(op, resource_inputs): ] if len(make_iterator_ops) == 1: - indirect_resource_inputs = _collect_resource_inputs(make_iterator_ops[0]) - for inp in indirect_resource_inputs: - if inp not in resource_inputs: + reads, writes = _collect_resource_inputs(make_iterator_ops[0]) + for inp in reads: + if inp not in resource_reads: updated = True - resource_inputs.add(inp) + resource_reads.add(inp) + for inp in writes: + if inp not in resource_writes: + updated = True + resource_writes.add(inp) return updated diff --git a/tensorflow/python/debug/lib/dumping_callback_test.py b/tensorflow/python/debug/lib/dumping_callback_test.py index 5382965ebc4..b76077e8def 100644 --- a/tensorflow/python/debug/lib/dumping_callback_test.py +++ 
b/tensorflow/python/debug/lib/dumping_callback_test.py @@ -1128,7 +1128,7 @@ class TracingCallbackTest( # 1st element: tensor ID; 2nd element: 0 indicating no inf or nan. self.assertAllClose(trace.debug_tensor_value, [tensor_id, 0]) elif tensor_debug_mode == "CONCISE_HEALTH": - for tensor_value in tensor_values: + for trace in graph_exec_traces: tensor_id = reader.graph_execution_trace_to_tensor_id(trace) # 1st element: tensor ID. # 2nd element: element count. Remaining elements: all zero because there diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 7f4bc9641ae..efa10d6ee45 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -63,9 +63,11 @@ py_library( srcs = ["cross_device_ops.py"], srcs_version = "PY2AND3", deps = [ + ":collective_util", ":cross_device_utils", ":device_util", ":reduce_util", + ":tpu_values", ":values", "//tensorflow/python:array_ops", "//tensorflow/python:device_lib", @@ -96,6 +98,7 @@ py_library( "//tensorflow/python:gradients", "//tensorflow/python:math_ops", "//tensorflow/python:nccl_ops", + "//tensorflow/python:platform", ], ) @@ -144,6 +147,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":collective_util", ":device_util", ":numpy_dataset", ":reduce_util", @@ -531,6 +535,7 @@ py_library( ":input_lib", ":numpy_dataset", ":reduce_util", + ":tpu_values", ":values", "//tensorflow/compiler/xla/experimental/xla_sharding", "//tensorflow/python:array_ops", @@ -578,6 +583,15 @@ py_library( ], ) +py_library( + name = "collective_util", + srcs = ["collective_util.py"], + deps = [ + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + ], +) + py_library( name = "shared_variable_creator", srcs = ["shared_variable_creator.py"], @@ -612,18 +626,36 @@ py_library( deps = [ ":device_util", ":distribute_lib", + ":reduce_util", "//tensorflow/python:array_ops", + "//tensorflow/python:composite_tensor", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:tensor_util", + "//tensorflow/python:type_spec", "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", "//tensorflow/python/eager:context", - "//tensorflow/python/tpu:tpu_lib", + "//tensorflow/python/eager:tape", "//tensorflow/python/training/saving:saveable_object", "//tensorflow/python/training/saving:saveable_object_util", "//tensorflow/python/training/tracking:base", - "@six_archive//:six", + ], +) + +py_library( + name = "tpu_values", + srcs = ["tpu_values.py"], + deps = [ + ":values", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops_gen", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:tape", + "//tensorflow/python/tpu:tpu_lib", ], ) @@ -775,7 +807,9 @@ cuda_py_test( name = "cross_device_utils_test", srcs = ["cross_device_utils_test.py"], deps = [ + "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python/distribute:combinations", @@ -795,6 +829,7 @@ cuda_py_test( ], deps = [ ":collective_all_reduce_strategy", + ":collective_util", ":mirrored_strategy", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", @@ -883,7 +918,7 @@ distribute_py_test( srcs = ["values_test.py"], main = 
"values_test.py", tags = [ - "no_oss", # http://b/119349471 + "multi_and_single_gpu", ], deps = [ ":mirrored_strategy", diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py index e77fce0f2a5..ab9721e1bfb 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py @@ -95,6 +95,7 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy): TFConfigClusterResolver which is instantiated from the TF_CONFIG env var. """ + # TODO(b/150151677): consider move communication to CollectiveHints. super(CollectiveAllReduceStrategy, self).__init__( CollectiveAllReduceExtended( self, @@ -505,7 +506,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): return updated_config - def _reduce_to(self, reduce_op, value, destinations): + def _reduce_to(self, reduce_op, value, destinations, experimental_hints): if (isinstance(value, values.Mirrored) and reduce_op == reduce_util.ReduceOp.MEAN): return value @@ -526,7 +527,10 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): return cross_device_ops_lib.reduce_non_distributed_value( reduce_op, value, destinations, len(self.worker_devices)) return self._get_cross_device_ops().reduce( - reduce_op, value, destinations=destinations) + reduce_op, + value, + destinations=destinations, + experimental_hints=experimental_hints) def _warn_nccl_no_gpu(self): if ((self._communication == diff --git a/tensorflow/python/distribute/collective_util.py b/tensorflow/python/distribute/collective_util.py new file mode 100644 index 00000000000..fb7008d1636 --- /dev/null +++ b/tensorflow/python/distribute/collective_util.py @@ -0,0 +1,63 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for collectives.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("distribute.experimental.CollectiveHints") +class Hints(object): + """Hints for collective operations like AllReduce. + + This can be passed to methods like + `tf.distribute.get_replica_context().all_reduce()` to optimize collective + operation performance. Note that these are only hints, which may or may not + change the actual behavior. Some options only apply to certain strategy and + are ignored by others. + + One common optimization is to break gradients all-reduce into multiple packs + so that weight updates can overlap with gradient all-reduce. 
+ + Example: + + ```python + hints = tf.distribute.experimental.CollectiveHints( + bytes_per_pack=50 * 1024 * 1024) + grads = tf.distribute.get_replica_context().all_reduce( + 'sum', grads, experimental_hints=hints) + optimizer.apply_gradients(zip(grads, vars), all_reduce_sum_gradients=False) + ``` + + """ + + def __init__(self, bytes_per_pack=0): + """Creates a CollectiveHints. + + Args: + bytes_per_pack: A non-negative integer. Breaks collective operations into + packs of certain size. If it's zero, the value is determined + automatically. This only applies to all-reduce with + `MultiWorkerMirroredStrategy` currently. + + Raises: + ValueError: When arguments have invalid value. + """ + if bytes_per_pack < 0: + raise ValueError("bytes_per_pack must be non-negative") + self.bytes_per_pack = bytes_per_pack diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py index eae15cf292c..eeb45428bab 100644 --- a/tensorflow/python/distribute/cross_device_ops.py +++ b/tensorflow/python/distribute/cross_device_ops.py @@ -19,14 +19,16 @@ from __future__ import division from __future__ import print_function import collections -import enum +import enum import six from tensorflow.python.client import device_lib +from tensorflow.python.distribute import collective_util from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import device_util from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import tpu_values from tensorflow.python.distribute import values as value_lib from tensorflow.python.eager import context from tensorflow.python.eager import def_function @@ -62,8 +64,8 @@ def validate_destinations(destinations): if not isinstance( destinations, (value_lib.DistributedValues, ops.Tensor, value_lib.AggregatingVariable, - six.string_types, value_lib.TPUMirroredVariable) - ) and not resource_variable_ops.is_resource_variable(destinations): + six.string_types, tpu_values.TPUMirroredVariable + )) and not resource_variable_ops.is_resource_variable(destinations): raise ValueError("destinations must be one of a `DistributedValues` object," " a tf.Variable object, or a device string.") @@ -221,7 +223,11 @@ class CrossDeviceOps(object): # Returns 1 by default, the value may be overridden by sub classes. return 1 - def reduce(self, reduce_op, per_replica_value, destinations): + def reduce(self, + reduce_op, + per_replica_value, + destinations, + experimental_hints=None): """Reduce `per_replica_value` to `destinations`. It runs the reduction operation defined by `reduce_op` and put the @@ -230,8 +236,10 @@ class CrossDeviceOps(object): Args: reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how per_replica_value will be reduced. - per_replica_value: a PerReplica object or a tensor with device set. + per_replica_value: A PerReplica object or a tensor with device set. destinations: the reduction destinations. + experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints + to perform collective operations. Returns: a Mirrored object. 
@@ -253,10 +261,15 @@ class CrossDeviceOps(object): per_replica_value.values, wrap_class=value_lib.Mirrored) + if experimental_hints is None: + experimental_hints = collective_util.Hints() return self.reduce_implementation(reduce_op, per_replica_value, - destinations) + destinations, experimental_hints) - def batch_reduce(self, reduce_op, value_destination_pairs): + def batch_reduce(self, + reduce_op, + value_destination_pairs, + experimental_hints=None): """Reduce PerReplica objects in a batch. Reduce each first element in `value_destination_pairs` to each second @@ -266,10 +279,12 @@ class CrossDeviceOps(object): fuse several tensors into one or multiple packs before reduction. Args: - reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how - the `per_replica_value` will be reduced. - value_destination_pairs: a list or a tuple of PerReplica objects - (or tensors with device set if there is one device) and destinations. + reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how the + `per_replica_value` will be reduced. + value_destination_pairs: A list or a tuple of PerReplica objects (or + tensors with device set if there is one device) and destinations. + experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints + to perform collective operations. Returns: a list of Mirrored objects. @@ -298,7 +313,10 @@ class CrossDeviceOps(object): for v, _ in value_destination_pairs ] - return self.batch_reduce_implementation(reduce_op, value_destination_pairs) + if experimental_hints is None: + experimental_hints = collective_util.Hints() + return self.batch_reduce_implementation(reduce_op, value_destination_pairs, + experimental_hints) def broadcast(self, tensor, destinations): """Broadcast the `tensor` to destinations. @@ -314,7 +332,8 @@ class CrossDeviceOps(object): return self.broadcast_implementation(tensor, destinations) @doc_controls.for_subclass_implementers - def reduce_implementation(self, reduce_op, per_replica_value, destinations): + def reduce_implementation(self, reduce_op, per_replica_value, destinations, + experimental_hints): """The implementation of reduce of `per_replica_value` to `destinations`. Overriding this method is useful for subclass implementers. @@ -325,8 +344,10 @@ class CrossDeviceOps(object): Args: reduce_op: An instance `tf.distribute.ReduceOp` that indicates of how per_replica_value will be reduced. - per_replica_value: a PerReplica object or a tensor with device set. + per_replica_value: A PerReplica object or a tensor with device set. destinations: the reduction destinations. + experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints + to perform collective operations. Returns: a Mirrored object. @@ -339,7 +360,8 @@ class CrossDeviceOps(object): "_reduce method must be implemented in descendants.") @doc_controls.for_subclass_implementers - def batch_reduce_implementation(self, reduce_op, value_destination_pairs): + def batch_reduce_implementation(self, reduce_op, value_destination_pairs, + experimental_hints): """Implementation of reduce PerReplica objects in a batch. Overriding this method is useful for subclass implementers. @@ -350,8 +372,10 @@ class CrossDeviceOps(object): Args: reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how per_replica_value will be reduced. - value_destination_pairs: an iterable of tuples of PerReplica objects + value_destination_pairs: An iterable of tuples of PerReplica objects (or tensors with device set if there is one device) and destinations. 
+ experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints + to perform collective operations. Returns: a list of Mirrored objects. @@ -361,7 +385,8 @@ class CrossDeviceOps(object): tuples of PerReplica objects and destinations """ raise NotImplementedError( - "_batch_reduce method must be implemented in descendants.") + "batch_reduce_implementation method must be implemented in descendants." + ) @doc_controls.for_subclass_implementers def broadcast_implementation(self, tensor, destinations): @@ -402,7 +427,9 @@ class ReductionToOneDevice(CrossDeviceOps): self.accumulation_fn = accumulation_fn or math_ops.add_n super(ReductionToOneDevice, self).__init__() - def reduce_implementation(self, reduce_op, per_replica_value, destinations): + def reduce_implementation(self, reduce_op, per_replica_value, destinations, + experimental_hints): + del experimental_hints # Unused. if check_destinations(destinations): devices = get_devices_from(destinations) else: @@ -415,9 +442,11 @@ class ReductionToOneDevice(CrossDeviceOps): self.accumulation_fn, reduce_op) return self.broadcast(reduced, destinations) - def batch_reduce_implementation(self, reduce_op, value_destination_pairs): + def batch_reduce_implementation(self, reduce_op, value_destination_pairs, + experimental_hints): return [ - self.reduce_implementation(reduce_op, t, destinations=v) + self.reduce_implementation( + reduce_op, t, destinations=v, experimental_hints=experimental_hints) for t, v in value_destination_pairs ] @@ -625,21 +654,24 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): self._simple_cross_replica_ops = ReductionToOneDevice() super(AllReduceCrossDeviceOps, self).__init__() - def reduce_implementation(self, reduce_op, per_replica_value, destinations): + def reduce_implementation(self, reduce_op, per_replica_value, destinations, + experimental_hints): + del experimental_hints # Unused. if _devices_match(per_replica_value, destinations): return self._batch_all_reduce(reduce_op, [per_replica_value])[0] else: return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value, destinations) - def batch_reduce_implementation(self, reduce_op, value_destination_pairs): + def batch_reduce_implementation(self, reduce_op, value_destination_pairs, + experimental_hints): if _all_devices_match(value_destination_pairs): return self._batch_all_reduce(reduce_op, [v[0] for v in value_destination_pairs]) else: return [ - self.reduce_implementation(reduce_op, t, destinations=v) - for t, v in value_destination_pairs + self.reduce_implementation(reduce_op, value, dest, experimental_hints) + for value, dest in value_destination_pairs ] def _batch_all_reduce(self, reduce_op, per_replica_values): @@ -903,7 +935,6 @@ class CollectiveAllReduce(CrossDeviceOps): def __init__(self, num_workers=1, num_gpus_per_worker=0, - num_packs=1, collective_keys=None, communication=CollectiveCommunication.AUTO): """Initializes the object. @@ -911,13 +942,11 @@ class CollectiveAllReduce(CrossDeviceOps): Args: num_workers: number of workers in the between-graph replicated training. num_gpus_per_worker: number of GPUs per worker. - num_packs: gradients will be packed into `num_packs` chunks. collective_keys: an optional CollectiveKey object. communication: indicates which collective communication to use. 
""" self._num_workers = num_workers self._num_gpus_per_worker = num_gpus_per_worker - self._num_packs = num_packs self._collective_keys = (collective_keys or cross_device_utils.CollectiveKeys()) self._communication = communication @@ -927,8 +956,10 @@ class CollectiveAllReduce(CrossDeviceOps): def _num_between_graph_workers(self): return self._num_workers - def reduce_implementation(self, reduce_op, per_replica_value, destinations): - all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value])[0] + def reduce_implementation(self, reduce_op, per_replica_value, destinations, + experimental_hints): + all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value], + experimental_hints)[0] devices = get_devices_from(destinations) if (isinstance(all_reduced, value_lib.Mirrored) and @@ -957,11 +988,13 @@ class CollectiveAllReduce(CrossDeviceOps): index.append(array_ops.identity(all_reduced._primary)) # pylint: disable=protected-access return value_lib.regroup(index, wrap_class=value_lib.Mirrored) - def batch_reduce_implementation(self, reduce_op, value_destination_pairs): + def batch_reduce_implementation(self, reduce_op, value_destination_pairs, + experimental_hints): all_devices_match = _all_devices_match(value_destination_pairs) if all_devices_match: return self._batch_all_reduce(reduce_op, - [v[0] for v in value_destination_pairs]) + [v[0] for v in value_destination_pairs], + experimental_hints) else: if not all_devices_match: logging.log_first_n( @@ -969,47 +1002,18 @@ class CollectiveAllReduce(CrossDeviceOps): "destinations are different.", 10) return [ - self.reduce_implementation(reduce_op, t, destinations=v) - for t, v in value_destination_pairs + self.reduce_implementation(reduce_op, value, dest, experimental_hints) + for value, dest in value_destination_pairs ] - def _make_gradient_chunks(self, per_replica_values, num_packs): - """Make `per_replica_values` into chunks.""" - chunked_by_device = _group_value_by_device(per_replica_values) - chunked_by_var = list(zip(*chunked_by_device)) - # chunked_by_var is chunked by variables and takes the following format: - # [((grad0_gpu0, v0_gpu0), (grad0_gpu1, v0_gpu1), (grad0_gpu2, v0_gpu2) ..), - # ((grad1_gpu0, v1_gpu0), (grad1_gpu1, v1_gpu1), (grad1_gpu0, v1_gpu2) ..), - # ((grad2_gpu0, v2_gpu0), (grad2_gpu1, v2_gpu1), (grad2_gpu0, v2_gpu2) ..), - # ... - # ] - - # No chunking if number of variables is fewer than number of packs. - if len(chunked_by_var) < num_packs: - return [chunked_by_var] - - # First n-1 chunks get `chunk_size` grads, last chunk gets leftover grads. - # This strategy can cause the last chunk to have larger size compared to the - # first n-1 chunks. Alternatively, we can increment chunk_size by 1 to get - # slightly larger first n-1 chunks and smaller last chunk. - # TODO(ayushd): compare different packing strategies. 
- chunk_size = len(chunked_by_var) // num_packs - leftover_size = len(chunked_by_var) - chunk_size * (num_packs - 1) - assert leftover_size > 0 - chunked_gv = [ - chunked_by_var[x:x + chunk_size] - for x in range(0, len(chunked_by_var) - leftover_size, chunk_size) - ] - chunked_gv.append(chunked_by_var[-leftover_size:]) - - return chunked_gv - - def _batch_all_reduce(self, reduce_op, per_replica_values): + def _batch_all_reduce(self, reduce_op, per_replica_values, + experimental_hints): """All reduce algorithm in a batch.""" dense_values, dense_indices, sparse_values, sparse_indices = ( cross_device_utils.split_by_sparsity(per_replica_values)) if dense_values: - dense_results = self._do_batch_all_reduce_dense(reduce_op, dense_values) + dense_results = self._do_batch_all_reduce_dense(reduce_op, dense_values, + experimental_hints) else: dense_results = [] if sparse_values: @@ -1017,83 +1021,84 @@ class CollectiveAllReduce(CrossDeviceOps): sparse_values) else: sparse_results = [] - return cross_device_utils.stitch_values(((dense_results, dense_indices), - (sparse_results, sparse_indices))) + return cross_device_utils.stitch_values( + ((dense_results, dense_indices), (sparse_results, sparse_indices))) - def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values): + def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values, + experimental_hints): """All-reduce across all workers in a batch.""" - chunked_gv = self._make_gradient_chunks(per_replica_values, self._num_packs) - # Actual number of packs may be different from `self._num_packs`. e.g. if - # there are fewer tensors than `self._num_packs`. - num_actual_packs = len(chunked_gv) - batch_size = len(per_replica_values) # Pass self._communication to the runtime as a communication hint. - communication_hint = self._communication.value + communication = self._communication.value # For now, we use NCCL only when batch_size > 1. # TODO(b/132575814): switch to NCCL for all collectives when communication # is NCCL. if self._communication == CollectiveCommunication.NCCL and batch_size == 1: - communication_hint = CollectiveCommunication.AUTO.value + communication = CollectiveCommunication.AUTO.value + + # Reverse the lists so that there's better chance that values follows + # the order in which they are calculated (e.g. when they're gradients), so + # as to overlap calculation with communication. However, this may not be + # optimal for cases like gradients of complicated non-sequential models. + # + # Note that we reverse the list before packing so that the first pack won't + # be too small, since it's more likely for first few packs to have long + # queuing time due to concurrent intense computation. + # + # TODO(b/147393503): explore solutions for optimal ordering. 
+ packs = cross_device_utils.pack_by_size( + list(reversed(per_replica_values)), experimental_hints.bytes_per_pack) if batch_size > 1: logging.info( "Collective batch_all_reduce: %d all-reduces, num_workers = %d, " - "communication_hint = %s, num_packs = %d" % - (batch_size, self._num_workers, communication_hint, num_actual_packs)) + "communication_hint = %s, num_packs = %d", batch_size, + self._num_workers, communication, len(packs)) else: logging.log_first_n( logging.INFO, "Collective batch_all_reduce: %d all-reduces, " "num_workers = %d, communication_hint = %s, num_packs = %d" % - (batch_size, self._num_workers, communication_hint, num_actual_packs), - 10) + (batch_size, self._num_workers, communication, len(packs)), 10) def batch_fn(): """Wrapper function around batched all-reduce calls.""" - reduced_gv_list = [] - # Reverse the gradient lists so that the gradient grouping roughly follows - # the order in which gradients are calculated in backprop. This should - # enable overlapping gradient all-reduce with backprop for most models. - # However, it is likely that for some complicated non-sequential models - # this grouping is not optimal. - # - # TODO(b/147393503): explore solutions for optimal gradient grouping. - for chunk in reversed(chunked_gv): - # By placing all CollectiveReduce ops in a chunk under single name - # scope, we ensure they will be picked up by the `ScopedAllocator` - # grappler optimizer and packed into a single all-reduce. + reduced_values = [] + for pack in packs: + # By placing all CollectiveReduce ops in a pack under single name scope, + # we ensure they will be picked up by the `ScopedAllocator` grappler + # optimizer and packed into a single all-reduce. with ops.name_scope("allreduce"): - for grad_and_vars in reversed(chunk): - # Gradients for the same variable but from different devices. - grads = [g for g, _ in grad_and_vars] + for per_replica in pack: # Add control dependencies per device from the last gradients to the # current set, in order to serialize NCCL launches. - if (communication_hint == CollectiveCommunication.NCCL.value and - reduced_gv_list): - control_input_grads = [g for g, _ in reduced_gv_list[-1]] + if (communication == CollectiveCommunication.NCCL.value and + reduced_values): + control_inputs = [g for g in reduced_values[-1]] else: - control_input_grads = None - collective_reduced = cross_device_utils.build_collective_reduce( - grads, self._num_workers, self._collective_keys, "Add", "Id", - communication_hint, control_input_grads) - result = [] - for (_, v), g in zip(grad_and_vars, collective_reduced): - result.append([g, v]) - reduced_gv_list.append(result) - # Reverse the batch reduced gradients to (approximately) recover the order - # in the input per_replica_values. - reduced_gv_list.reverse() - return reduced_gv_list + control_inputs = None + reduced_values.append( + cross_device_utils.build_collective_reduce( + per_replica.values, self._num_workers, + self._collective_keys, "Add", "Id", communication, + control_inputs)) + return reduced_values + if context.executing_eagerly(): batch_fn = def_function.function(batch_fn) - new_device_grads = [list(x) for x in zip(*batch_fn())] - return _ungroup_and_make_mirrored( - new_device_grads, - per_replica_values[0], - reduce_op, - num_between_graph_workers=self._num_workers) + reduced_values = batch_fn() + mirrored = [] + # Reverse the order of reduced value to recover the order in the input. 
+ for value in reversed(reduced_values): + if reduce_op == reduce_util.ReduceOp.MEAN: + # Assume each worker has the same number of replicas. + num_replicas = len(value) * self._num_workers + for i, v in enumerate(value): + with ops.device(v.device): + value[i] = v / num_replicas + mirrored.append(value_lib.regroup(value, wrap_class=value_lib.Mirrored)) + return mirrored def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values): """All-reduce IndexedSlices across all workers in a batch.""" @@ -1105,48 +1110,31 @@ class CollectiveAllReduce(CrossDeviceOps): # Pass self._communication to the runtime as a communication hint. communication_hint = self._communication.value - # For now, we use NCCL only when batch_size > 1 and num_packs is 1. - # TODO(b/132575814): Enable NCCL if num_packs > 1. - # TODO(b/132575814): Switch to NCCL for all collectives when communication + # For now, we use NCCL only when batch_size > 1. + # TODO(b/132575814): switch to NCCL for all collectives when communication # is NCCL. - if self._communication == CollectiveCommunication.NCCL and ( - len(per_replica_values) == 1 or self._num_packs != 1): + if self._communication == CollectiveCommunication.NCCL and len( + per_replica_values) == 1: communication_hint = CollectiveCommunication.AUTO.value - chunked_gv = self._make_gradient_chunks(per_replica_values, self._num_packs) + gathered_values = [] + with ops.name_scope("allreduce"): + for per_replica in per_replica_values: + gathered_values.append( + cross_device_utils.build_collective_gather_indexed_slices( + per_replica.values, self._num_workers, self._collective_keys, + communication_hint)) - reduced_gv_list = [] - for chunk in chunked_gv: - # By placing all CollectiveReduce ops in a chunk under single name scope, - # we ensure they will be picked up by the `ScopedAllocator` grappler - # optimizer and packed into a single all-reduce. - with ops.name_scope("allreduce"): - for grad_and_vars in chunk: - grads = [g for g, _ in grad_and_vars] - - # Add control dependencies per device from the last gradients to the - # current set, in order to serialize NCCL launches. - if (communication_hint == CollectiveCommunication.NCCL.value and - reduced_gv_list): - control_input_grads = [g for g, _ in reduced_gv_list[-1]] - else: - control_input_grads = None - - collective_reduced = ( - cross_device_utils.build_collective_gather_indexed_slices( - grads, self._num_workers, self._collective_keys, - communication_hint, control_input_grads)) - result = [] - for (_, v), g in zip(grad_and_vars, collective_reduced): - result.append([g, v]) - reduced_gv_list.append(result) - - new_device_grads = [list(x) for x in zip(*reduced_gv_list)] - return _ungroup_and_make_mirrored( - new_device_grads, - per_replica_values[0], - reduce_op, - num_between_graph_workers=self._num_workers) + mirrored = [] + for value in gathered_values: + if reduce_op == reduce_util.ReduceOp.MEAN: + # Assume each worker has the same number of replicas. 
+ num_replicas = len(value) * self._num_workers + for i, v in enumerate(value): + with ops.device(v.device): + value[i].values = value[i].values / num_replicas + mirrored.append(value_lib.regroup(value, wrap_class=value_lib.Mirrored)) + return mirrored def choose_the_best(devices, session_config=None): diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py index 3a9d7b7ec44..aba5316f7b5 100644 --- a/tensorflow/python/distribute/cross_device_ops_test.py +++ b/tensorflow/python/distribute/cross_device_ops_test.py @@ -24,6 +24,7 @@ from absl.testing import parameterized import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.python.distribute import collective_all_reduce_strategy +from tensorflow.python.distribute import collective_util from tensorflow.python.distribute import combinations from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import cross_device_utils @@ -463,8 +464,7 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase, num_gpus=0, communication=CollectiveCommunication.AUTO, use_strategy_object=False, - local_mode=False, - num_packs=1): + local_mode=False): collective_keys = cross_device_utils.CollectiveKeys( group_key_start=10 + CollectiveAllReduceTest.collective_key_base, op_instance_key_start=100 + CollectiveAllReduceTest.collective_key_base, @@ -487,7 +487,6 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase, 1, num_gpus, collective_keys=collective_keys, - num_packs=num_packs, communication=communication) return collective_all_reduce_ops, devices, "" else: @@ -520,7 +519,6 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase, NUM_WORKERS, num_gpus, collective_keys=collective_keys, - num_packs=num_packs, communication=communication) return (collective_all_reduce_ops, devices, "grpc://" + self._cluster_spec[task_type][task_id]) @@ -532,15 +530,14 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase, communication, use_strategy_object=False, local_mode=False, - num_packs=1): + hints=None): collective_all_reduce, devices, master_target = self._get_test_objects( task_type, task_id, num_gpus, communication=communication, use_strategy_object=use_strategy_object, - local_mode=local_mode, - num_packs=num_packs) + local_mode=local_mode) if local_mode: num_workers = 1 worker_device = None @@ -553,17 +550,19 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase, if use_strategy_object: with test_object.scope(): return test_object.extended.reduce_to(reduce_op, per_replica, - destinations) + destinations, hints) else: - return test_object.reduce(reduce_op, per_replica, destinations) + return test_object.reduce(reduce_op, per_replica, destinations, hints) def _batch_reduce(test_object, reduce_op, value_destination_pairs): if use_strategy_object: with test_object.scope(): return test_object.extended.batch_reduce_to(reduce_op, - value_destination_pairs) + value_destination_pairs, + hints) else: - return test_object.batch_reduce(reduce_op, value_destination_pairs) + return test_object.batch_reduce(reduce_op, value_destination_pairs, + hints) with ops.Graph().as_default(), \ ops.device(worker_device), \ @@ -724,16 +723,17 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase, mode=["graph"], required_gpus=[0, 1, 2], use_strategy_object=[True, False], - num_packs=[1, 2])) + bytes_per_pack=[0, 1, 4])) 
def testReductionDistributed(self, required_gpus, use_strategy_object, - num_packs): + bytes_per_pack): + hints = collective_util.Hints(bytes_per_pack=bytes_per_pack) self._run_between_graph_clients( self._test_reduction, self._cluster_spec, required_gpus, communication=CollectiveCommunication.RING, use_strategy_object=use_strategy_object, - num_packs=num_packs) + hints=hints) @combinations.generate( combinations.combine( diff --git a/tensorflow/python/distribute/cross_device_utils.py b/tensorflow/python/distribute/cross_device_utils.py index fa6f612af17..f9917385b59 100644 --- a/tensorflow/python/distribute/cross_device_utils.py +++ b/tensorflow/python/distribute/cross_device_utils.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import collective_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nccl_ops +from tensorflow.python.platform import tf_logging as logging OP_INSTANCE_KEY_START_NUMBER = 100 @@ -896,6 +897,67 @@ def stitch_values(values_and_indices_list): return result +def per_replica_num_elements(per_replica): + """Returns the static number of elements of one replica. + + Args: + per_replica: A PerReplica of Tensor or IndexedSlices. + + Returns: + Number of elements. None if some replica has a different or unknown shape. + """ + + values = per_replica._values # pylint: disable=protected-access + s0 = values[0].shape + for v in values: + assert not isinstance(v, ops.IndexedSlices) + if v.shape != s0: + return None + return s0.num_elements() + + +def pack_by_size(per_replica_list, bytes_per_pack): + """Packs `per_replica_list` into chunks of `bytes_per_pack`. + + The method preserves the original order of `per_replica_list`. The packing is + best effort, each pack could have more or less bytes than `bytes_per_pack`. + It only packs values with known shape. Note that, the usage is different from + `cross_device_ops._pack_tensors`, this function is intended to work with the + ScopeAllocator style batching used in `CollectiveAllReduce`. + + Args: + per_replica_list: A list of PerReplica. + bytes_per_pack: Bytes per pack. + + Returns: + A list of packs of PerReplica. All values are packed into one pack if + `bytes_per_pack` is zero or any of the value has unknown shape. + """ + + if bytes_per_pack == 0: + return [per_replica_list] + packs = [] + last_pack_size = 0 + for value in per_replica_list: + num_elements = per_replica_num_elements(value) + if num_elements is None: + # Can't pack values with unknown shape. + logging.warning( + 'not packing values due to the unknown or inconsistent shape of %s', + value) + return [per_replica_list] + size = num_elements * value._primary.dtype.size # pylint: disable=protected-access + # Try to keep each pack as close to bytes_per_pack as possible, while each + # pack is at least bytes_per_pack large. I.E. we err on the side of having + # few but large packs. + if not packs or last_pack_size > bytes_per_pack: + packs.append([]) + last_pack_size = 0 + packs[-1].append(value) + last_pack_size += size + return packs + + def _control_input(inputs, control_inputs, idx): """Returns the `idx`-th item in control_inputs to be used in ops.control_dependencies. 
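The `pack_by_size` helper added above packs greedily: a new pack is started only once the running size of the current pack exceeds `bytes_per_pack`, so every pack except possibly the last is at least that large ("few but large packs"). Below is a standalone sketch of the same heuristic over plain byte counts, illustrative only; it mirrors the `testPreferLargerPack` case in the test file that follows.

```python
def greedy_pack(sizes, bytes_per_pack):
  """Illustration of pack_by_size's grouping, over plain byte counts."""
  if bytes_per_pack == 0:
    return [list(sizes)]  # 0 means: put everything into a single pack.
  packs, last_pack_size = [], 0
  for size in sizes:
    # Open a new pack only after the current one has exceeded the threshold,
    # erring on the side of few but large packs.
    if not packs or last_pack_size > bytes_per_pack:
      packs.append([])
      last_pack_size = 0
    packs[-1].append(size)
    last_pack_size += size
  return packs

# Sizes 128, 32, 800 and 4 bytes with a 200-byte threshold give two packs,
# matching testPreferLargerPack below: [[128, 32, 800], [4]].
print(greedy_pack([128, 32, 800, 4], bytes_per_pack=200))
```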
diff --git a/tensorflow/python/distribute/cross_device_utils_test.py b/tensorflow/python/distribute/cross_device_utils_test.py index 217883ea21b..ae0f3b8fe76 100644 --- a/tensorflow/python/distribute/cross_device_utils_test.py +++ b/tensorflow/python/distribute/cross_device_utils_test.py @@ -26,8 +26,11 @@ from tensorflow.python.distribute import device_util from tensorflow.python.distribute import values as value_lib from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import input_layer +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -133,8 +136,86 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): self.assertIsInstance(result, ops.IndexedSlices) self._assert_values_equal(t, result) - self.assertEqual(device_util.resolve(destination), - device_util.resolve(result.device)) + self.assertEqual( + device_util.resolve(destination), device_util.resolve(result.device)) + + +class PackBySizeTest(test.TestCase): + + def assertShape(self, per_replica, shape): + for v in per_replica._values: # pylint: disable=protected-access + self.assertEqual(v.shape, shape) + + def testPreferLargerPack(self): + # Each packs except the last one should be equal or larger than + # bytes_per_pack. + values = [ + # size = 2 * 4 * 4 * 4 = 128 + array_ops.ones([2, 4, 4], dtype=dtypes.float32), + # size = 8 * 4 = 32 + array_ops.ones([8], dtype=dtypes.int32), + # size = 10 * 10 * 8 = 800 + array_ops.ones([10, 10], dtype=dtypes.int64), + # size = 1 * 4 = 4 + array_ops.ones([1], dtype=dtypes.int32), + ] + per_replica_values = [value_lib.PerReplica([v, v]) for v in values] + packs = cross_device_utils.pack_by_size( + per_replica_values, bytes_per_pack=200) + self.assertLen(packs, 2) + self.assertLen(packs[0], 3) + self.assertShape(packs[0][0], [2, 4, 4]) + self.assertShape(packs[0][1], [8]) + self.assertShape(packs[0][2], [10, 10]) + self.assertLen(packs[1], 1) + self.assertShape(packs[1][0], [1]) + + def testZeroBytesPerPack(self): + values = [ + array_ops.ones([1], dtype=dtypes.float32), + array_ops.ones([2], dtype=dtypes.float32), + ] + per_replica_values = [value_lib.PerReplica([v, v]) for v in values] + packs = cross_device_utils.pack_by_size( + per_replica_values, bytes_per_pack=0) + self.assertLen(packs, 1) + self.assertLen(packs[0], 2) + self.assertShape(packs[0][0], [1]) + self.assertShape(packs[0][1], [2]) + + def testUnknownShape(self): + per_replica_values = [ + value_lib.PerReplica([ + array_ops.ones([10, 10], dtype=dtypes.float32), + array_ops.ones([10, 10], dtype=dtypes.float32), + ]), + value_lib.PerReplica([ + array_ops.ones([10, 10], dtype=dtypes.float32), + input_layer.Input( + shape=(10), batch_size=None, dtype=dtypes.float32), + ]), + ] + packs = cross_device_utils.pack_by_size( + per_replica_values, bytes_per_pack=1) + self.assertLen(packs, 1) + self.assertEqual(packs[0], per_replica_values) + + def testInconsistentShape(self): + per_replica_values = [ + value_lib.PerReplica([ + array_ops.ones([10, 10], dtype=dtypes.float32), + array_ops.ones([10, 10], dtype=dtypes.float32), + ]), + value_lib.PerReplica([ + array_ops.ones([10, 10], dtype=dtypes.float32), + input_layer.Input( + shape=(10), batch_size=None, dtype=dtypes.float32), + ]), + ] + packs = cross_device_utils.pack_by_size( + per_replica_values, bytes_per_pack=1) + 
self.assertLen(packs, 1) + self.assertEqual(packs[0], per_replica_values) if __name__ == "__main__": diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index 7fac3fa3d9e..1d1e44f97c9 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -108,6 +108,7 @@ import six from tensorflow.python.autograph.core import ag_ctx as autograph_ctx from tensorflow.python.autograph.impl import api as autograph from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import collective_util from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import numpy_dataset @@ -1719,10 +1720,10 @@ class StrategyExtendedV2(object): def _reduce(self, reduce_op, value): # Default implementation until we have an implementation for each strategy. return self._local_results( - self._reduce_to(reduce_op, value, - device_util.current() or "/device:CPU:0"))[0] + self.reduce_to(reduce_op, value, + device_util.current() or "/device:CPU:0"))[0] - def reduce_to(self, reduce_op, value, destinations): + def reduce_to(self, reduce_op, value, destinations, experimental_hints=None): """Combine (via e.g. sum or mean) values across replicas. Args: @@ -1732,6 +1733,8 @@ class StrategyExtendedV2(object): string. The return value will be copied to all destination devices (or all the devices where the `destinations` value resides). To perform an all-reduction, pass `value` to `destinations`. + experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints + to perform collective operations. Returns: A tensor or value mirrored to `destinations`. @@ -1744,18 +1747,25 @@ class StrategyExtendedV2(object): reduce_op = reduce_util.ReduceOp(reduce_op.upper()) assert (reduce_op == reduce_util.ReduceOp.SUM or reduce_op == reduce_util.ReduceOp.MEAN) - return self._reduce_to(reduce_op, value, destinations) + if experimental_hints is None: + experimental_hints = collective_util.Hints() + return self._reduce_to(reduce_op, value, destinations, experimental_hints) - def _reduce_to(self, reduce_op, value, destinations): + def _reduce_to(self, reduce_op, value, destinations, experimental_hints): raise NotImplementedError("must be implemented in descendants") - def batch_reduce_to(self, reduce_op, value_destination_pairs): + def batch_reduce_to(self, + reduce_op, + value_destination_pairs, + experimental_hints=None): """Combine multiple `reduce_to` calls into one for faster execution. Args: reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum. - value_destination_pairs: A sequence of (value, destinations) - pairs. See `reduce_to()` for a description. + value_destination_pairs: A sequence of (value, destinations) pairs. See + `reduce_to()` for a description. + experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints + to perform collective operations. Returns: A list of mirrored values, one per pair in `value_destination_pairs`. 
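A hedged usage sketch of the public surface documented above, assuming the `tf.distribute.experimental.CollectiveHints` export referenced in the `collective_util.py` docstring; the strategy choice and values are illustrative only.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
# bytes_per_pack currently only takes effect for collective-based all-reduce
# (MultiWorkerMirroredStrategy); other CrossDeviceOps implementations accept
# and ignore the hints.
hints = tf.distribute.experimental.CollectiveHints(
    bytes_per_pack=32 * 1024 * 1024)

with strategy.scope():
  value = tf.identity(1.0)
  # Cross-replica reduction to one device; the hints are forwarded to the
  # underlying CrossDeviceOps.reduce() shown earlier in this patch.
  reduced = strategy.extended.reduce_to(
      tf.distribute.ReduceOp.SUM, value, destinations="/cpu:0",
      experimental_hints=hints)
```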
@@ -1765,11 +1775,16 @@ class StrategyExtendedV2(object): assert not isinstance(reduce_op, variable_scope.VariableAggregation) if isinstance(reduce_op, six.string_types): reduce_op = reduce_util.ReduceOp(reduce_op.upper()) - return self._batch_reduce_to(reduce_op, value_destination_pairs) + if experimental_hints is None: + experimental_hints = collective_util.Hints() + return self._batch_reduce_to(reduce_op, value_destination_pairs, + experimental_hints) - def _batch_reduce_to(self, reduce_op, value_destination_pairs): + def _batch_reduce_to(self, reduce_op, value_destination_pairs, + experimental_hints): return [ - self.reduce_to(reduce_op, t, destinations=v) + self.reduce_to( + reduce_op, t, destinations=v, experimental_hints=experimental_hints) for t, v in value_destination_pairs ] @@ -2267,7 +2282,7 @@ class ReplicaContext(object): require_replica_context(self) return (device_util.current(),) - def all_reduce(self, reduce_op, value): + def all_reduce(self, reduce_op, value, experimental_hints=None): """All-reduces the given `value Tensor` nest across replicas. If `all_reduce` is called in any replica, it must be called in all replicas. @@ -2289,16 +2304,21 @@ class ReplicaContext(object): reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum. value: The nested structure of `Tensor`s to all-reduce. The structure must be compatible with `tf.nest`. + experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints + to perform collective operations. Returns: A `Tensor` nest with the reduced `value`s from each replica. """ if isinstance(reduce_op, six.string_types): reduce_op = reduce_util.ReduceOp(reduce_op.upper()) + if experimental_hints is None: + experimental_hints = collective_util.Hints() def batch_all_reduce(strategy, *value_flat): return strategy.extended.batch_reduce_to( - reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat]) + reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat], + experimental_hints) if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]: # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad. @@ -2449,9 +2469,9 @@ class _DefaultDistributionExtended(StrategyExtendedV1): replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)): return fn(*args, **kwargs) - def _reduce_to(self, reduce_op, value, destinations): + def _reduce_to(self, reduce_op, value, destinations, experimental_hints): # TODO(josh11b): Use destinations? 
- del reduce_op, destinations + del reduce_op, destinations, experimental_hints return value def _update(self, var, fn, args, kwargs, group): diff --git a/tensorflow/python/distribute/distribute_lib_test.py b/tensorflow/python/distribute/distribute_lib_test.py index 8c7ad0ae40d..bac623ada52 100644 --- a/tensorflow/python/distribute/distribute_lib_test.py +++ b/tensorflow/python/distribute/distribute_lib_test.py @@ -95,8 +95,8 @@ class _TestExtended(distribute_lib.StrategyExtendedV1): def _local_results(self, value): return (value,) - def _reduce_to(self, reduce_op, value, destinations): - del reduce_op, destinations + def _reduce_to(self, reduce_op, value, destinations, experimental_hints): + del reduce_op, destinations, experimental_hints return value def _experimental_make_numpy_dataset(self, numpy_input, session): diff --git a/tensorflow/python/distribute/distributed_file_utils.py b/tensorflow/python/distribute/distributed_file_utils.py index 9e4a2b202f1..b8e3fd8a8c9 100644 --- a/tensorflow/python/distribute/distributed_file_utils.py +++ b/tensorflow/python/distribute/distributed_file_utils.py @@ -50,7 +50,6 @@ from __future__ import print_function import os from tensorflow.python.distribute import distribution_strategy_context -from tensorflow.python.distribute import multi_worker_util from tensorflow.python.lib.io import file_io @@ -80,7 +79,7 @@ def write_dirpath(dirpath, strategy=None): return dirpath if not strategy.extended._in_multi_worker_mode(): # pylint: disable=protected-access return dirpath - if multi_worker_util.is_chief(): + if strategy.extended.should_checkpoint: return dirpath # If this worker is not chief and hence should not save file, save it to a # temporary directory to be removed later. @@ -96,8 +95,10 @@ def remove_temp_dirpath(dirpath, strategy=None): # If strategy is still not available, this is not in distributed training. # Fallback to no-op. return - if strategy.extended._in_multi_worker_mode(): # pylint: disable=protected-access - if not multi_worker_util.is_chief(): + # TODO(anjalisridhar): Consider removing the check for multi worker mode since + # it is redundant when used with the should_checkpoint property. + if (strategy.extended._in_multi_worker_mode() and # pylint: disable=protected-access + not strategy.extended.should_checkpoint): # If this worker is not chief and hence should not save file, remove # the temporary directory. - file_io.delete_recursively(_get_temp_dir(dirpath, strategy)) + file_io.delete_recursively(_get_temp_dir(dirpath, strategy)) diff --git a/tensorflow/python/distribute/distribution_strategy_context.py b/tensorflow/python/distribute/distribution_strategy_context.py index 24d6a67187f..29593d65c5d 100644 --- a/tensorflow/python/distribute/distribution_strategy_context.py +++ b/tensorflow/python/distribute/distribution_strategy_context.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import contextlib import threading from tensorflow.python.framework import ops @@ -266,6 +267,20 @@ def experimental_set_strategy(strategy): ops.get_default_graph()._global_distribute_strategy_scope = new_scope # pylint: disable=protected-access +# ------------------------------------------------------------------------------ +# Internal helpers. 
+ + +@contextlib.contextmanager +def enter_or_assert_strategy(strategy): + if not has_strategy(): + with strategy.scope(): + yield + else: + _assert_strategy(strategy) + yield + + # ------------------------------------------------------------------------------ # Defaults that are used when no tf.distribute.Strategy is explicitly created. # We create them lazily in a function so that we can workaround the circular @@ -284,6 +299,17 @@ _default_replica_context_lock = threading.Lock() _default_replica_mode_lock = threading.Lock() +def _assert_strategy(strategy): + if not has_strategy(): + raise RuntimeError('Need to be inside "with strategy.scope()" for %s' % + (strategy,)) + current_strategy = get_strategy() + if current_strategy is not strategy: + raise RuntimeError( + "Mixing different tf.distribute.Strategy objects: %s is not %s" % + (current_strategy, strategy)) + + def _get_default_strategy(): if _defaults["strategy"] is None: # Avoid race condition causing two defaults to be created diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index e57c656139a..baa6e1ac76e 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -788,7 +788,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): def _get_cross_device_ops(self): return self._cross_device_ops or self._inferred_cross_device_ops - def _reduce_to(self, reduce_op, value, destinations): + def _reduce_to(self, reduce_op, value, destinations, experimental_hints): if (isinstance(value, values.Mirrored) and reduce_op == reduce_util.ReduceOp.MEAN): return value @@ -801,11 +801,16 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): return cross_device_ops_lib.reduce_non_distributed_value( reduce_op, value, destinations, self._num_replicas_in_sync) return self._get_cross_device_ops().reduce( - reduce_op, value, destinations=destinations) + reduce_op, + value, + destinations=destinations, + experimental_hints=experimental_hints) - def _batch_reduce_to(self, reduce_op, value_destination_pairs): + def _batch_reduce_to(self, reduce_op, value_destination_pairs, + experimental_hints): return self._get_cross_device_ops().batch_reduce(reduce_op, - value_destination_pairs) + value_destination_pairs, + experimental_hints) def _update(self, var, fn, args, kwargs, group): # TODO(josh11b): In eager mode, use one thread per device. 
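Every `_reduce_to`/`_batch_reduce_to` override now takes the extra `experimental_hints` argument, even where it is simply discarded (the OneDeviceStrategy, default-strategy, and test-double hunks elsewhere in this patch just `del` it). The following hypothetical minimal override is sketched only to show the updated signatures that subclasses are expected to provide.

```python
from tensorflow.python.distribute import distribute_lib


class _PassThroughExtended(distribute_lib.StrategyExtendedV1):
  """Hypothetical extended object; not part of this patch."""

  def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
    # The hints may be ignored, but the parameter must be accepted because
    # StrategyExtendedV2.reduce_to() always forwards a Hints object.
    del reduce_op, destinations, experimental_hints
    return value

  def _batch_reduce_to(self, reduce_op, value_destination_pairs,
                       experimental_hints):
    # Same shape as the base-class default: reduce each pair individually.
    return [
        self._reduce_to(reduce_op, v, d, experimental_hints)
        for v, d in value_destination_pairs
    ]
```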
diff --git a/tensorflow/python/distribute/mirrored_strategy_test.py b/tensorflow/python/distribute/mirrored_strategy_test.py index 9f4b07a3e75..0ab4018ce13 100644 --- a/tensorflow/python/distribute/mirrored_strategy_test.py +++ b/tensorflow/python/distribute/mirrored_strategy_test.py @@ -1356,14 +1356,14 @@ class FunctionTest(test.TestCase, parameterized.TestCase): def forward(x, w, b): return x * w + b - x = constant_op.constant([1.0], name="x_useless") + x = array_ops.identity([1.0], name="x_useless") concrete_forward = forward.get_concrete_function(x, w._primary, b._primary) with distribution.scope(): def replica_fn(): with backprop.GradientTape() as t: - x = constant_op.constant([1.0], name="x") + x = array_ops.identity([1.0], name="x") loss = concrete_forward(x, w._get(), b._get()) - [1.0] return t.gradient(loss, [w, b]) diff --git a/tensorflow/python/distribute/one_device_strategy.py b/tensorflow/python/distribute/one_device_strategy.py index 1fb3538b517..5a6973a699b 100644 --- a/tensorflow/python/distribute/one_device_strategy.py +++ b/tensorflow/python/distribute/one_device_strategy.py @@ -356,8 +356,8 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1): with ops.device(self._device), _OneDeviceReplicaContext(strategy): return fn(*args, **kwargs) - def _reduce_to(self, reduce_op, value, destinations): - del reduce_op, destinations + def _reduce_to(self, reduce_op, value, destinations, experimental_hints): + del reduce_op, destinations, experimental_hints return value def _update(self, var, fn, args, kwargs, group): diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py index d27bacf6be7..f785a0c6266 100644 --- a/tensorflow/python/distribute/parameter_server_strategy.py +++ b/tensorflow/python/distribute/parameter_server_strategy.py @@ -466,20 +466,25 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1): "Cannot reduce to another worker: %r, current worker is %r" % (d, self._input_workers.worker_devices[0])) - def _reduce_to(self, reduce_op, value, destinations): + def _reduce_to(self, reduce_op, value, destinations, experimental_hints): self._verify_destinations_not_different_worker(destinations) if not isinstance(value, values.DistributedValues): # pylint: disable=protected-access return cross_device_ops_lib.reduce_non_distributed_value( reduce_op, value, destinations, self._num_replicas_in_sync) return self._cross_device_ops.reduce( - reduce_op, value, destinations=destinations) + reduce_op, + value, + destinations=destinations, + experimental_hints=experimental_hints) - def _batch_reduce_to(self, reduce_op, value_destination_pairs): + def _batch_reduce_to(self, reduce_op, value_destination_pairs, + experimental_hints): for _, destinations in value_destination_pairs: self._verify_destinations_not_different_worker(destinations) return self._cross_device_ops.batch_reduce(reduce_op, - value_destination_pairs) + value_destination_pairs, + experimental_hints) def _select_single_value(self, structured): """Select any single value in `structured`.""" diff --git a/tensorflow/python/distribute/strategy_test_lib.py b/tensorflow/python/distribute/strategy_test_lib.py index c889484ae68..00730959d4e 100644 --- a/tensorflow/python/distribute/strategy_test_lib.py +++ b/tensorflow/python/distribute/strategy_test_lib.py @@ -688,9 +688,9 @@ class RemoteSingleWorkerMirroredStrategyBase(DistributionTestBase): def _testDeviceScope(self, distribution): with distribution.scope(): - a = 
constant_op.constant(1.) + a = array_ops.identity(1.) with ops.device("/cpu:0"): - b = constant_op.constant(1.) + b = array_ops.identity(1.) if context.executing_eagerly(): device = "/job:worker/replica:0/task:0/device:CPU:0" else: diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index 54e2028ccaf..7176a2e2dc9 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -34,6 +34,7 @@ from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import tpu_values from tensorflow.python.distribute import values from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver from tensorflow.python.eager import context @@ -543,7 +544,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): self._logical_device_stack.append(logical_device_id) try: - if values._enclosing_tpu_context() is None: # pylint: disable=protected-access + if tpu_values.enclosing_tpu_context() is None: yield else: with ops.device(tpu.core(logical_device_id)): @@ -648,20 +649,20 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): with context.device_policy(context.DEVICE_PLACEMENT_SILENT): v = next_creator(**kwargs) - assert not isinstance(v, values.TPUMirroredVariable) + assert not isinstance(v, tpu_values.TPUMirroredVariable) value_list.append(v) return value_list return values.create_mirrored_variable(self._container_strategy(), _real_mirrored_creator, - values.TPUMirroredVariable, - values.TPUSyncOnReadVariable, + tpu_values.TPUMirroredVariable, + tpu_values.TPUSyncOnReadVariable, **kwargs) - def _reduce_to(self, reduce_op, value, destinations): + def _reduce_to(self, reduce_op, value, destinations, experimental_hints): if (isinstance(value, values.DistributedValues) or tensor_util.is_tensor(value) - ) and values._enclosing_tpu_context() is not None: # pylint: disable=protected-access + ) and tpu_values.enclosing_tpu_context() is not None: if reduce_op == reduce_util.ReduceOp.MEAN: # TODO(jhseu): Revisit once we support model-parallelism. value *= (1. / self._num_replicas_in_sync) @@ -701,9 +702,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): return output def _update(self, var, fn, args, kwargs, group): - assert isinstance(var, values.TPUVariableMixin) or isinstance( + assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance( var, resource_variable_ops.BaseResourceVariable) - if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access + if tpu_values.enclosing_tpu_context() is not None: if group: return fn(var, *args, **kwargs) else: @@ -724,7 +725,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): return values.update_regroup(self, updates, group) def read_var(self, var): - assert isinstance(var, values.TPUVariableMixin) or isinstance( + assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance( var, resource_variable_ops.BaseResourceVariable) return var.read_value() @@ -745,7 +746,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): # since the `1` gets broadcast as an int32 but global_step is int64. 
if isinstance(tensor, (float, int)): return tensor - if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access + if tpu_values.enclosing_tpu_context() is not None: broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)] result = tpu_ops.all_to_all( broadcast_tensor, diff --git a/tensorflow/python/distribute/tpu_values.py b/tensorflow/python/distribute/tpu_values.py new file mode 100644 index 00000000000..871c85405e2 --- /dev/null +++ b/tensorflow/python/distribute/tpu_values.py @@ -0,0 +1,245 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Various classes representing TPU distributed values. + +Note that the tests are in values_test.py . + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib + +from tensorflow.python.distribute import distribution_strategy_context as ds_context +from tensorflow.python.distribute import values +from tensorflow.python.eager import context +from tensorflow.python.eager import tape +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_resource_variable_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.tpu import tpu + + +@contextlib.contextmanager +def _maybe_enter_graph(tensor): + # Note: might have an eager tensor but not be executing eagerly when + # building functions. + if (context.executing_eagerly() or isinstance(tensor, ops.EagerTensor) or + ops.has_default_graph()): + yield + else: + with tensor.graph.as_default(): + yield + + +def _make_raw_assign_fn(raw_assign_fn): # pylint: disable=missing-docstring + + def assign_fn(var, value, use_locking=False, name=None, read_value=True): # pylint: disable=missing-docstring + del use_locking # Unused. + + with _maybe_enter_graph(var.handle): + op = raw_assign_fn( + var.handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name) + + with ops.control_dependencies([op]): + return var._read_variable_op() if read_value else op # pylint: disable=protected-access + + return assign_fn + + +class TPUVariableMixin(object): + """Mixin for TPU variables.""" + + def __init__(self, *args, **kwargs): + super(TPUVariableMixin, self).__init__(*args, **kwargs) + + # Handle ID is needed for `get_replicated_var_handle` to cache the variables + # correctly since in eager mode different variables can have the same name. 
+ if ops.executing_eagerly_outside_functions(): + self._handle_id = self._common_name + "_" + str(id(self._primary)) + else: + self._handle_id = self._common_name + + def __getattr__(self, name): + if enclosing_tpu_context() is None: + return super(TPUVariableMixin, self).__getattr__(name) + else: + raise AttributeError( + "'{}' not accessible within a TPU context.".format(name)) + + def get(self): + if enclosing_tpu_context() is None: + return super(TPUVariableMixin, self).get() + else: + raise NotImplementedError( + "`TPUVariableMixin.get()` is not supported within a TPU context.") + + def _get_as_operand(self): + return self.read_value() + + def _get_closest(self): + if enclosing_tpu_context() is None: + return super(TPUVariableMixin, self)._get_closest() + else: + return self._primary + + def numpy(self): + if context.executing_eagerly(): + return self.read_value().numpy() + else: + raise NotImplementedError( + "numpy() is only available when eager execution is enabled.") + + def _is_mirrored(self): + raise NotImplementedError( + "`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.") + + @property + def handle(self): + # If we're in a tpu.rewrite(), return the replicated handle. + tpu_context = enclosing_tpu_context() + if tpu_context is None: + return self._get_closest().handle + else: + return tpu_context.get_replicated_var_handle(self._handle_id, + self._values, + self._is_mirrored()) + + @property + def device(self): + return self.handle.device + + def _read_variable_op(self): + if self.trainable: + tape.variable_accessed(self) + return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype) + + def read_value(self): + if enclosing_tpu_context() is None: + return super(TPUVariableMixin, self).read_value() + else: + return self._read_variable_op() + + def value(self): + if enclosing_tpu_context() is None: + return super(TPUVariableMixin, self).value() + else: + return self._read_variable_op() + + def _as_graph_element(self): + if enclosing_tpu_context() is None: + return super(TPUVariableMixin, self)._as_graph_element() # pylint: disable=protected-access + else: + return None + + @property + def op(self): + return values.DistributedVarOp(self._primary.op.name, + self._primary.op.graph, + self._primary.op.traceback, + self._primary.op.type) + + def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): + """Converts a variable to a tensor.""" + # pylint: disable=protected-access + if enclosing_tpu_context() is None: + return super(TPUVariableMixin, self)._dense_var_to_tensor( + dtype=dtype, name=name, as_ref=as_ref) + # pylint: enable=protected-access + elif dtype is not None and dtype != self.dtype: + return math_ops.cast(self.read_value(), dtype) + else: + return self.handle if as_ref else self.read_value() + + +def enclosing_tpu_context(): + """Returns the TPUReplicateContext, which exists inside a tpu.rewrite().""" + graph = ops.get_default_graph() + while graph is not None: + # pylint: disable=protected-access + context_ = graph._get_control_flow_context() + # pylint: enable=protected-access + while context_ is not None: + if isinstance(context_, tpu.TPUReplicateContext): + return context_ + context_ = context_.outer_context + # This may be a FuncGraph due to defuns or v2 control flow. We need to + # find the original graph with the XLAControlFlowContext. 
+ graph = getattr(graph, "outer_graph", None) + return None + + +class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable): + """Holds a map from replica to TPU variables whose values are kept in sync.""" + + def _assign_func(self, *args, **kwargs): + with ds_context.enter_or_assert_strategy(self._distribute_strategy): + if (ds_context.in_cross_replica_context() and + (enclosing_tpu_context() is not None)): + f = kwargs.pop("f") + return self._distribute_strategy.extended.update( + self, f, args=args, kwargs=kwargs) + else: + return values.MirroredVariable._assign_func(self, *args, **kwargs) + + def assign_sub(self, *args, **kwargs): + assign_sub_fn = _make_raw_assign_fn( + gen_resource_variable_ops.assign_sub_variable_op) + return self._assign_func(f=assign_sub_fn, *args, **kwargs) + + def assign_add(self, *args, **kwargs): + assign_add_fn = _make_raw_assign_fn( + gen_resource_variable_ops.assign_add_variable_op) + return self._assign_func(f=assign_add_fn, *args, **kwargs) + + def assign(self, *args, **kwargs): + assign_fn = _make_raw_assign_fn( + gen_resource_variable_ops.assign_variable_op) + return self._assign_func(f=assign_fn, *args, **kwargs) + + def _is_mirrored(self): + return True + + +class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable): + """Holds a map from replica to variables whose values are reduced on save.""" + + def assign_sub(self, *args, **kwargs): + if enclosing_tpu_context() is None: + return values.SyncOnReadVariable.assign_sub(self, *args, **kwargs) + else: + return _make_raw_assign_fn( + gen_resource_variable_ops.assign_sub_variable_op)(self, *args, + **kwargs) + + def assign_add(self, *args, **kwargs): + if enclosing_tpu_context() is None: + return values.SyncOnReadVariable.assign_add(self, *args, **kwargs) + else: + return _make_raw_assign_fn( + gen_resource_variable_ops.assign_add_variable_op)(self, *args, + **kwargs) + + def assign(self, *args, **kwargs): + if enclosing_tpu_context() is None: + return values.SyncOnReadVariable.assign(self, *args, **kwargs) + else: + return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)( + self, *args, **kwargs) + + def _is_mirrored(self): + return False diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 48adb5159bf..c23819cde11 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -19,12 +19,11 @@ from __future__ import division from __future__ import print_function import collections -import contextlib import weakref from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib -from tensorflow.python.distribute import distribution_strategy_context +from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import context from tensorflow.python.eager import tape @@ -34,11 +33,9 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as variables_lib -from tensorflow.python.tpu import tpu from tensorflow.python.training.saving import saveable_object from tensorflow.python.training.saving 
import saveable_object_util from tensorflow.python.training.tracking import base as trackable @@ -48,7 +45,7 @@ from tensorflow.python.util.tf_export import tf_export def _get_current_replica_id_as_int(): """Returns the current replica ID as an integer, or `None`.""" - replica_context = distribution_strategy_context.get_replica_context() + replica_context = ds_context.get_replica_context() if replica_context: replica_id = replica_context.replica_id_in_sync_group if not isinstance(replica_id, int): @@ -362,7 +359,7 @@ class PerReplicaSpec(type_spec.TypeSpec): return self._value_specs def _to_components(self, value): - replica_context = distribution_strategy_context.get_replica_context() + replica_context = ds_context.get_replica_context() if replica_context is not None and replica_context.num_replicas_in_sync > 1: raise ValueError( "Flattening a PerReplica to components is not supported in replica " @@ -405,27 +402,6 @@ def _assign_sub_on_device(device, variable, tensor): return variable.assign_sub(tensor) -def _assert_strategy(strategy): - if not distribution_strategy_context.has_strategy(): - raise RuntimeError('Need to be inside "with strategy.scope()" for %s' % - (strategy,)) - current_strategy = distribution_strategy_context.get_strategy() - if current_strategy is not strategy: - raise RuntimeError( - "Mixing different tf.distribute.Strategy objects: %s is not %s" % - (current_strategy, strategy)) - - -@contextlib.contextmanager -def _enter_or_assert_strategy(strategy): - if not distribution_strategy_context.has_strategy(): - with strategy.scope(): - yield - else: - _assert_strategy(strategy) - yield - - DistributedVarOp = collections.namedtuple( "DistributedVarOp", ["name", "graph", "traceback", "type"]) @@ -578,7 +554,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable): # We want cross-replica code that does some var.op.X calls # to work (even if the current device isn't in self._devices), but # other uses of var.op in a cross-replica context to fail. - if distribution_strategy_context.in_cross_replica_context(): + if ds_context.in_cross_replica_context(): return DistributedVarOp(self._primary.op.name, self._primary.op.graph, self._primary.op.traceback, self._primary.op.type) return self._get().op @@ -588,7 +564,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable): return self._primary._in_graph_mode # pylint: disable=protected-access def read_value(self): - with _enter_or_assert_strategy(self._distribute_strategy): + with ds_context.enter_or_assert_strategy(self._distribute_strategy): return array_ops.identity(self._get()) def value(self): @@ -602,135 +578,6 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable): ops.register_dense_tensor_like_type(DistributedVariable) -@contextlib.contextmanager -def _maybe_enter_graph(tensor): - # Note: might have an eager tensor but not be executing eagerly when - # building functions. - if (context.executing_eagerly() or isinstance(tensor, ops.EagerTensor) or - ops.has_default_graph()): - yield - else: - with tensor.graph.as_default(): - yield - - -def _make_raw_assign_fn(raw_assign_fn): # pylint: disable=missing-docstring - - def assign_fn(var, value, use_locking=False, name=None, read_value=True): # pylint: disable=missing-docstring - del use_locking # Unused. 
- - with _maybe_enter_graph(var.handle): - op = raw_assign_fn( - var.handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name) - - with ops.control_dependencies([op]): - return var._read_variable_op() if read_value else op # pylint: disable=protected-access - - return assign_fn - - -class TPUVariableMixin(object): - """Mixin for TPU variables.""" - - def __init__(self, *args, **kwargs): - super(TPUVariableMixin, self).__init__(*args, **kwargs) - - # Handle ID is needed for `get_replicated_var_handle` to cache the variables - # correctly since in eager mode different variables can have the same name. - if ops.executing_eagerly_outside_functions(): - self._handle_id = self._common_name + "_" + str(id(self._primary)) - else: - self._handle_id = self._common_name - - def __getattr__(self, name): - if _enclosing_tpu_context() is None: - return super(TPUVariableMixin, self).__getattr__(name) - else: - raise AttributeError( - "'{}' not accessible within a TPU context.".format(name)) - - def get(self): - if _enclosing_tpu_context() is None: - return super(TPUVariableMixin, self).get() - else: - raise NotImplementedError( - "`TPUVariableMixin.get()` is not supported within a TPU context.") - - def _get_as_operand(self): - return self.read_value() - - def _get_closest(self): - if _enclosing_tpu_context() is None: - return super(TPUVariableMixin, self)._get_closest() - else: - return self._primary - - def numpy(self): - if context.executing_eagerly(): - return self.read_value().numpy() - else: - raise NotImplementedError( - "numpy() is only available when eager execution is enabled.") - - def _is_mirrored(self): - raise NotImplementedError( - "`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.") - - @property - def handle(self): - # If we're in a tpu.rewrite(), return the replicated handle. 
- tpu_context = _enclosing_tpu_context() - if tpu_context is None: - return self._get_closest().handle - else: - return tpu_context.get_replicated_var_handle( - self._handle_id, self._values, self._is_mirrored()) - - @property - def device(self): - return self.handle.device - - def _read_variable_op(self): - if self.trainable: - tape.variable_accessed(self) - return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype) - - def read_value(self): - if _enclosing_tpu_context() is None: - return super(TPUVariableMixin, self).read_value() - else: - return self._read_variable_op() - - def value(self): - if _enclosing_tpu_context() is None: - return super(TPUVariableMixin, self).value() - else: - return self._read_variable_op() - - def _as_graph_element(self): - if _enclosing_tpu_context() is None: - return super(TPUVariableMixin, self)._as_graph_element() # pylint: disable=protected-access - else: - return None - - @property - def op(self): - return DistributedVarOp(self._primary.op.name, self._primary.op.graph, - self._primary.op.traceback, self._primary.op.type) - - def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): - """Converts a variable to a tensor.""" - # pylint: disable=protected-access - if _enclosing_tpu_context() is None: - return super(TPUVariableMixin, self)._dense_var_to_tensor( - dtype=dtype, name=name, as_ref=as_ref) - # pylint: enable=protected-access - elif dtype is not None and dtype != self.dtype: - return math_ops.cast(self.read_value(), dtype) - else: - return self.handle if as_ref else self.read_value() - - def _validate_colocate_extended(v, extended): variable_strategy = v._distribute_strategy # pylint: disable=protected-access if variable_strategy.extended is not extended: @@ -888,9 +735,9 @@ class MirroredVariable(DistributedVariable, Mirrored): # update_non_slot() function (like OptimizerV2._finish), which can # update several non-slot variables in one call. def _assign_func(self, *args, **kwargs): - with _enter_or_assert_strategy(self._distribute_strategy): + with ds_context.enter_or_assert_strategy(self._distribute_strategy): f = kwargs.pop("f") - if distribution_strategy_context.in_cross_replica_context(): + if ds_context.in_cross_replica_context(): update_replica_id = distribute_lib.get_update_replica_id() if update_replica_id is not None: # We are calling an assign function on the mirrored variable in an @@ -933,7 +780,7 @@ class MirroredVariable(DistributedVariable, Mirrored): return strategy.extended.update( self, f, args=(v,) + other_args, kwargs=other_kwargs) - return distribution_strategy_context.get_replica_context().merge_call( + return ds_context.get_replica_context().merge_call( merge_fn, args=args, kwargs=kwargs) def assign_sub(self, *args, **kwargs): @@ -1003,60 +850,11 @@ ops.register_tensor_conversion_function(Mirrored, _tensor_conversion_mirrored_val) -def _enclosing_tpu_context(): - """Returns the TPUReplicateContext, which exists inside a tpu.rewrite().""" - graph = ops.get_default_graph() - while graph is not None: - # pylint: disable=protected-access - context_ = graph._get_control_flow_context() - # pylint: enable=protected-access - while context_ is not None: - if isinstance(context_, tpu.TPUReplicateContext): - return context_ - context_ = context_.outer_context - # This may be a FuncGraph due to defuns or v2 control flow. We need to - # find the original graph with the XLAControlFlowContext. 
- graph = getattr(graph, "outer_graph", None) - return None - - def is_distributed_variable(v): """Determine if a variable is ds variable or TPU mirrored variable.""" return isinstance(v, DistributedVariable) -class TPUMirroredVariable(TPUVariableMixin, MirroredVariable): - """Holds a map from replica to TPU variables whose values are kept in sync.""" - - def _assign_func(self, *args, **kwargs): - with _enter_or_assert_strategy(self._distribute_strategy): - if (distribution_strategy_context.in_cross_replica_context() and - (_enclosing_tpu_context() is not None)): - f = kwargs.pop("f") - return self._distribute_strategy.extended.update( - self, f, args=args, kwargs=kwargs) - else: - return MirroredVariable._assign_func(self, *args, **kwargs) - - def assign_sub(self, *args, **kwargs): - assign_sub_fn = _make_raw_assign_fn( - gen_resource_variable_ops.assign_sub_variable_op) - return self._assign_func(f=assign_sub_fn, *args, **kwargs) - - def assign_add(self, *args, **kwargs): - assign_add_fn = _make_raw_assign_fn( - gen_resource_variable_ops.assign_add_variable_op) - return self._assign_func(f=assign_add_fn, *args, **kwargs) - - def assign(self, *args, **kwargs): - assign_fn = _make_raw_assign_fn( - gen_resource_variable_ops.assign_variable_op) - return self._assign_func(f=assign_fn, *args, **kwargs) - - def _is_mirrored(self): - return True - - class _SyncOnReadSaveable(saveable_object.SaveableObject): """Class for defining how to restore a SyncOnReadVariable.""" @@ -1094,7 +892,7 @@ class _SyncOnReadSaveable(saveable_object.SaveableObject): def _assert_replica_context(strategy): - replica_context = distribution_strategy_context.get_replica_context() + replica_context = ds_context.get_replica_context() if not replica_context: raise RuntimeError( "Replica-local variables may only be assigned in a replica context.") @@ -1111,8 +909,8 @@ class SyncOnReadVariable(DistributedVariable): self._aggregation = aggregation def assign_sub(self, *args, **kwargs): - with _enter_or_assert_strategy(self._distribute_strategy): - if distribution_strategy_context.in_cross_replica_context(): + with ds_context.enter_or_assert_strategy(self._distribute_strategy): + if ds_context.in_cross_replica_context(): if self._aggregation == vs.VariableAggregation.SUM: raise ValueError( "SyncOnReadVariable does not support `assign_sub` in " @@ -1126,8 +924,8 @@ class SyncOnReadVariable(DistributedVariable): return self._get().assign_sub(*args, **kwargs) def assign_add(self, *args, **kwargs): - with _enter_or_assert_strategy(self._distribute_strategy): - if distribution_strategy_context.in_cross_replica_context(): + with ds_context.enter_or_assert_strategy(self._distribute_strategy): + if ds_context.in_cross_replica_context(): if self._aggregation == vs.VariableAggregation.SUM: raise ValueError( "SyncOnReadVariable does not support `assign_add` in " @@ -1141,8 +939,8 @@ class SyncOnReadVariable(DistributedVariable): return self._get().assign_add(*args, **kwargs) def assign(self, *args, **kwargs): - with _enter_or_assert_strategy(self._distribute_strategy): - if distribution_strategy_context.in_cross_replica_context(): + with ds_context.enter_or_assert_strategy(self._distribute_strategy): + if ds_context.in_cross_replica_context(): # To preserve the sum across save and restore, we have to divide the # total across all devices when restoring a variable that was summed # when saving. 
@@ -1155,8 +953,8 @@ class SyncOnReadVariable(DistributedVariable): return self._get().assign(*args, **kwargs) def value(self): - with _enter_or_assert_strategy(self._distribute_strategy): - if distribution_strategy_context.in_cross_replica_context(): + with ds_context.enter_or_assert_strategy(self._distribute_strategy): + if ds_context.in_cross_replica_context(): return self._get_cross_replica() else: # _get_closest() returns a Variable. @@ -1177,7 +975,7 @@ class SyncOnReadVariable(DistributedVariable): if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: return self._primary - with _enter_or_assert_strategy(self._distribute_strategy): + with ds_context.enter_or_assert_strategy(self._distribute_strategy): return self._distribute_strategy.reduce( reduce_util.ReduceOp.from_variable_aggregation(self.aggregation), self, @@ -1185,8 +983,8 @@ class SyncOnReadVariable(DistributedVariable): def _as_graph_element(self): # pylint: disable=protected-access - with _enter_or_assert_strategy(self._distribute_strategy): - if distribution_strategy_context.in_cross_replica_context(): + with ds_context.enter_or_assert_strategy(self._distribute_strategy): + if ds_context.in_cross_replica_context(): return ops.convert_to_tensor(self._get_cross_replica()) return self._get()._as_graph_element() @@ -1207,7 +1005,7 @@ class SyncOnReadVariable(DistributedVariable): def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): """Converts a variable to a tensor.""" - with _enter_or_assert_strategy(self._distribute_strategy): + with ds_context.enter_or_assert_strategy(self._distribute_strategy): return ops.convert_to_tensor( self._get(), dtype=dtype, name=name, as_ref=as_ref) @@ -1222,36 +1020,6 @@ ops.register_tensor_conversion_function(SyncOnReadVariable, _tensor_conversion_sync_on_read) -class TPUSyncOnReadVariable(TPUVariableMixin, SyncOnReadVariable): - """Holds a map from replica to variables whose values are reduced on save.""" - - def assign_sub(self, *args, **kwargs): - if _enclosing_tpu_context() is None: - return SyncOnReadVariable.assign_sub(self, *args, **kwargs) - else: - return _make_raw_assign_fn( - gen_resource_variable_ops.assign_sub_variable_op)(self, *args, - **kwargs) - - def assign_add(self, *args, **kwargs): - if _enclosing_tpu_context() is None: - return SyncOnReadVariable.assign_add(self, *args, **kwargs) - else: - return _make_raw_assign_fn( - gen_resource_variable_ops.assign_add_variable_op)(self, *args, - **kwargs) - - def assign(self, *args, **kwargs): - if _enclosing_tpu_context() is None: - return SyncOnReadVariable.assign(self, *args, **kwargs) - else: - return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)( - self, *args, **kwargs) - - def _is_mirrored(self): - return False - - def regroup(values, wrap_class=PerReplica): """Makes a nest per-replica into a nest of PerReplica/Mirrored values.""" v0 = values[0] @@ -1444,9 +1212,9 @@ class AggregatingVariable(variables_lib.Variable): return getattr(self._v, name) def _assign_func(self, *args, **kwargs): - with _enter_or_assert_strategy(self._distribute_strategy): + with ds_context.enter_or_assert_strategy(self._distribute_strategy): f = kwargs.pop("f") - if distribution_strategy_context.in_cross_replica_context(): + if ds_context.in_cross_replica_context(): if distribute_lib.get_update_replica_id() is not None: # We are calling an assign function in an update context. 
return f(self._v, *args, **kwargs) @@ -1456,7 +1224,7 @@ class AggregatingVariable(variables_lib.Variable): return self._distribute_strategy.extended.update( self, f, args=args, kwargs=kwargs) else: - replica_context = distribution_strategy_context.get_replica_context() + replica_context = ds_context.get_replica_context() assert replica_context # We are calling an assign function in replica context. # We reduce the value we want to assign/add/sub. More details about how diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py index f2922e6e53a..d66726424a1 100644 --- a/tensorflow/python/distribute/values_test.py +++ b/tensorflow/python/distribute/values_test.py @@ -29,6 +29,7 @@ from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import tpu_strategy +from tensorflow.python.distribute import tpu_values from tensorflow.python.distribute import values from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver from tensorflow.python.eager import context @@ -824,7 +825,7 @@ def _make_replica_local(method, strategy=None): name=n, initializer=init, use_resource=True)) if (strategy is not None) and isinstance(strategy, _TPU_STRATEGIES): - var_cls = values.TPUSyncOnReadVariable + var_cls = tpu_values.TPUSyncOnReadVariable else: var_cls = values.SyncOnReadVariable replica_local = var_cls(strategy, v, method) diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 7aef5da11f2..a6bc0b7ec56 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -878,6 +878,24 @@ tpu_py_test( ], ) +tpu_py_test( + name = "remote_cloud_tpu_pod_test", + srcs = ["remote_cloud_tpu_test.py"], + args = ["--num_tpu_devices=32"], + main = "remote_cloud_tpu_test.py", + python_version = "PY3", + tags = [ + "notap", + "tpu_pod", + ], + deps = [ + ":context", + ":remote", + "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", + "//tensorflow/python/tpu:tpu_strategy_util", + ], +) + cuda_py_test( name = "device_placement_test", size = "small", diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index 273777090c6..47b3966827f 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -79,7 +79,6 @@ class TFETest(test_util.TensorFlowTestCase): def setUp(self): super(TFETest, self).setUp() - ops.device(None).__enter__() context._reset_context() configure_virtual_cpus() @@ -395,7 +394,7 @@ class TFETest(test_util.TensorFlowTestCase): def testMultiCpuPlacement(self): with ops.device('cpu:1'): - x = constant_op.constant(1.0) + x = array_ops.identity(1.0) with ops.device('cpu:0'): y = array_ops.identity(x) self.assertEqual(x.device, '/job:localhost/replica:0/task:0/device:CPU:1') @@ -1064,7 +1063,6 @@ class SendRecvTest(test_util.TensorFlowTestCase): def setUp(self): super(SendRecvTest, self).setUp() - ops.device(None).__enter__() context._reset_context() configure_virtual_cpus() @@ -1084,7 +1082,7 @@ class SendRecvTest(test_util.TensorFlowTestCase): def testLocalCrossDevice(self): gpu_device_name = '/job:localhost/replica:0/task:0/device:GPU:0' with ops.device('GPU:0'): - t0 = constant_op.constant(1.0) + t0 = array_ops.identity(1.0) self._send(t0, 't0', self.cpu_device) with ops.device('cpu:0'): self.assertAllEqual( @@ -1101,7 +1099,6 @@ class 
EagerTensorCacheTest(test_util.TensorFlowTestCase): def setUp(self): super(EagerTensorCacheTest, self).setUp() - ops.device(None).__enter__() context._reset_context() configure_virtual_cpus() @@ -1115,4 +1112,5 @@ class EagerTensorCacheTest(test_util.TensorFlowTestCase): if __name__ == '__main__': + context.set_log_device_placement(True) test.main() diff --git a/tensorflow/python/eager/device_placement_test.py b/tensorflow/python/eager/device_placement_test.py index 32ca6d3a826..af6c68243b4 100644 --- a/tensorflow/python/eager/device_placement_test.py +++ b/tensorflow/python/eager/device_placement_test.py @@ -18,24 +18,26 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import remote from tensorflow.python.eager import test from tensorflow.python.framework import config from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -class SoftDevicePlacementTest(test.TestCase): +class SoftDevicePlacementTest(test.TestCase, parameterized.TestCase): def setUp(self): super(SoftDevicePlacementTest, self).setUp() - context._context = None - ops.enable_eager_execution_internal() + context._reset_context() config.set_soft_device_placement(enabled=True) context.context().log_device_placement = True @@ -90,13 +92,21 @@ class SoftDevicePlacementTest(test.TestCase): # We don't support nested device placement right now. 
self.assertIn('GPU:0', c.device) + @parameterized.named_parameters(('float', 1.0, None), + ('int32', [1], dtypes.int32), + ('string', ['a'], None)) + def testSoftPlacedCPUConstant(self, value, dtype): + with ops.device('GPU:0'): + a = constant_op.constant(value, dtype=dtype) + self.assertIn('CPU:0', a.device) + self.assertIn('CPU:0', a.backing_device) -class HardDevicePlacementTest(test.TestCase): + +class HardDevicePlacementTest(test.TestCase, parameterized.TestCase): def setUp(self): super(HardDevicePlacementTest, self).setUp() - context._context = None - ops.enable_eager_execution_internal() + context._reset_context() config.set_soft_device_placement(enabled=False) context.context().log_device_placement = True self.assertEqual(config.get_soft_device_placement(), False) @@ -114,13 +124,27 @@ class HardDevicePlacementTest(test.TestCase): self.assertIn('GPU:0', y.device) self.assertIn('GPU:0', y.backing_device) + @parameterized.named_parameters(('float_cpu0', 'CPU:0', 1.0, None), + ('int32_cpu0', 'CPU:0', [1], dtypes.int32), + ('string_cpu0', 'CPU:0', ['a'], None), + ('float_gpu0', 'GPU:0', 1.0, None), + ('int32_gpu0', 'GPU:0', [1], dtypes.int32), + ('string_gpu0', 'GPU:0', ['a'], None), + ('float_gpu99', 'GPU:99', 1.0, None), + ('int32_gpu99', 'GPU:99', [1], dtypes.int32), + ('string_gpu99', 'GPU:99', ['a'], None)) + def testHardPlacedCPUConstant(self, device, value, dtype): + with ops.device(device): + a = constant_op.constant(value, dtype=dtype) + self.assertIn('CPU:0', a.device) + self.assertIn('CPU:0', a.backing_device) + class ClusterPlacementTest(test.TestCase): def setUp(self): super(ClusterPlacementTest, self).setUp() - context._context = None - ops.enable_eager_execution_internal() + context._reset_context() config.set_soft_device_placement(enabled=True) context.context().log_device_placement = True workers, _ = test_util.create_local_cluster(2, 0) diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 0a34d4a3852..7b599a995e2 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -1570,10 +1570,10 @@ class FunctionTest(test.TestCase, parameterized.TestCase): def testColocateWithRespected(self): # TODO(b/113291792): Use multiple CPUs instead of a GPU. 
with ops.device('cpu:0'): - x = constant_op.constant(1.0) + x = array_ops.identity(1.0) with ops.device('gpu:0'): - y = constant_op.constant(1.0) + y = array_ops.identity(1.0) @def_function.function def foo(): @@ -3239,9 +3239,9 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase): return b, a with ops.device('/device:CPU:0'): - a = constant_op.constant(3.0) + a = array_ops.identity(3.0) with ops.device('/device:GPU:0'): - b = constant_op.constant(5.0) + b = array_ops.identity(5.0) m1, m2 = func(a, b) self.assertAllEqual(m1.numpy(), 5.0) @@ -3306,9 +3306,9 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase): devices = ['/device:CPU:0', '/device:GPU:0'] for dev1, dev2 in itertools.product(devices, devices): with ops.device(dev1): - a = constant_op.constant(1.0) + a = array_ops.identity(1.0) with ops.device(dev2): - b = constant_op.constant(10.0) + b = array_ops.identity(10.0) ra, rb = func(a, b) self.assertEqual(ra.numpy(), 2.0) @@ -3469,13 +3469,13 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase): with ops.device('/device:CPU:0'): rc0 = resource_variable_ops.ResourceVariable(2.0) rc1 = resource_variable_ops.ResourceVariable(3.0) - cc0 = constant_op.constant(5.0) - cc1 = constant_op.constant(7.0) + cc0 = array_ops.identity(5.0) + cc1 = array_ops.identity(7.0) with ops.device('/device:GPU:0'): rg0 = resource_variable_ops.ResourceVariable(11.0) rg1 = resource_variable_ops.ResourceVariable(13.0) - cg0 = constant_op.constant(17.0) - cg1 = constant_op.constant(19.0) + cg0 = array_ops.identity(17.0) + cg1 = array_ops.identity(19.0) # Make sure tensors are on expected devices. for tensor in [cc0, cc1]: diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index b5c9bfb6824..f8e1fb568ac 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -278,39 +278,13 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx, } } - // Almost all TensorFlow kernels for GPU devices keep int32 tensors in host - // memory. We approximate the same behavior for eager execution - keeping - // int32 tensors in host memory. + // We always generate CPU:0 tensors, but we may need to change the device + // slightly, as for example from /job:localhost/... to /job:worker/... // - // We do so to preclude the need for callers into such kernels from having to - // explicitly place the int32 tensors in host memory. For example, without - // this, one needed: - // - // with tf.device('/gpu:0'): - // ...// code here - // with tf.device('/cpu:0'): - // shape = tf.constant(...) - // y = tf.random_uniform(shape) - // - // Without the CPU device block, tfe.ops.random_uniform would fail since the - // kernel expects the shape in host memory. - // - // With this support, we simplify the code: - // - // with tf.device('/gpu:0'): - // y = tf.random_uniform(...) - // - // The approximation is not exact there are GPU kernels which do not require - // host memory for int32 tensors. This will lead to a discrepancy between - // eager and graph execution. - // - // To support remote execution copy int32 tensors to another CPU device. - // TODO(ashankar): Fix this. + // Note that this is a shallow copy and will share the underlying buffer, + // because we are copying to the same device. 
if (device_name != nullptr && - (TFE_TensorHandleDataType(handle.get()) != TF_INT32 || - strstr(device_name, "/device:CPU:0") != nullptr)) { - // Note that this is a shallow copy and will share the underlying buffer - // if copying to the same device. + strstr(device_name, "/device:CPU:0") != nullptr) { handle = make_safe(TFE_TensorHandleCopyToDevice(handle.get(), ctx, device_name, status.get())); if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_RuntimeError)) { @@ -318,6 +292,15 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx, } } + // We always enable implicit mirroring for constants. Without this, code + // written previously under the assumption that + // + // with tf.device('GPU:0'): x = tf.constant(1.0) + // + // will be placed in the GPU will suffer a non-trivial performance regression + // (measured at ~20% for certain benchmarks). + handle->handle->EnableImplicitMirroring(); + return handle.release(); } diff --git a/tensorflow/python/eager/remote_cloud_tpu_test.py b/tensorflow/python/eager/remote_cloud_tpu_test.py index d63a8924bc8..8ba11a3e6ac 100644 --- a/tensorflow/python/eager/remote_cloud_tpu_test.py +++ b/tensorflow/python/eager/remote_cloud_tpu_test.py @@ -31,24 +31,25 @@ flags.DEFINE_string('tpu', '', 'Name of TPU to connect to.') flags.DEFINE_string('project', None, 'Name of GCP project with TPU.') flags.DEFINE_string('zone', None, 'Name of GCP zone with TPU.') +flags.DEFINE_integer('num_tpu_devices', 8, 'The expected number of TPUs.') +DEVICES_PER_TASK = 8 + EXPECTED_DEVICES_PRE_CONNECT = [ '/device:CPU:0', '/device:XLA_CPU:0', ] -EXPECTED_DEVICES_AFTER_CONNECT = [ - '/device:CPU:0', - '/device:XLA_CPU:0', - '/job:worker/replica:0/task:0/device:CPU:0', - '/job:worker/replica:0/task:0/device:XLA_CPU:0', - '/job:worker/replica:0/task:0/device:TPU_SYSTEM:0', - '/job:worker/replica:0/task:0/device:TPU:0', - '/job:worker/replica:0/task:0/device:TPU:1', - '/job:worker/replica:0/task:0/device:TPU:2', - '/job:worker/replica:0/task:0/device:TPU:3', - '/job:worker/replica:0/task:0/device:TPU:4', - '/job:worker/replica:0/task:0/device:TPU:5', - '/job:worker/replica:0/task:0/device:TPU:6', - '/job:worker/replica:0/task:0/device:TPU:7', +EXPECTED_NEW_DEVICES_AFTER_CONNECT_TEMPLATES = [ + '/job:worker/replica:0/task:{task}/device:CPU:0', + '/job:worker/replica:0/task:{task}/device:XLA_CPU:0', + '/job:worker/replica:0/task:{task}/device:TPU_SYSTEM:0', + '/job:worker/replica:0/task:{task}/device:TPU:0', + '/job:worker/replica:0/task:{task}/device:TPU:1', + '/job:worker/replica:0/task:{task}/device:TPU:2', + '/job:worker/replica:0/task:{task}/device:TPU:3', + '/job:worker/replica:0/task:{task}/device:TPU:4', + '/job:worker/replica:0/task:{task}/device:TPU:5', + '/job:worker/replica:0/task:{task}/device:TPU:6', + '/job:worker/replica:0/task:{task}/device:TPU:7', ] @@ -56,6 +57,9 @@ class RemoteCloudTPUTest(absltest.TestCase): """Test that we can connect to a real Cloud TPU.""" def test_connect(self): + # Log full diff on failure. 
+ self.maxDiff = None # pylint:disable=invalid-name + + self.assertCountEqual( EXPECTED_DEVICES_PRE_CONNECT, [device.name for device in config.list_logical_devices()]) @@ -65,8 +69,15 @@ class RemoteCloudTPUTest(absltest.TestCase): ) remote.connect_to_cluster(resolver) + expected_devices = list(EXPECTED_DEVICES_PRE_CONNECT) + for task in range(FLAGS.num_tpu_devices // DEVICES_PER_TASK): + expected_devices.extend([ + template.format(task=task) + for template in EXPECTED_NEW_DEVICES_AFTER_CONNECT_TEMPLATES + ]) + self.assertCountEqual( - EXPECTED_DEVICES_AFTER_CONNECT, + expected_devices, [device.name for device in config.list_logical_devices()]) tpu_strategy_util.initialize_tpu_system(resolver) diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py index d9ed93fc662..44af62666ee 100644 --- a/tensorflow/python/eager/remote_test.py +++ b/tensorflow/python/eager/remote_test.py @@ -264,6 +264,42 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0]) + @test_util.eager_lazy_remote_copy_on_and_off + def testMultiDeviceFunctionOnRemoteDeviceWithWait(self): + with ops.device('/job:worker/replica:0/task:1'): + variable_b = variables.Variable([1.0]) + + @def_function.function + def remote_function(i): + x = array_ops.ones([1000, 1000]) + for _ in range(1, 1000): + x = x * x + variable_b.assign_add(i) + a = 1.0 + variable_b + return a + + @def_function.function + def remote_function2(i): + variable_b.assign_add(i) + a = 1.0 + variable_b + return a + + # Runs first function: + # - on remote device + # - needs remote input + # - has side effects + # - runs much slower + with ops.device('/job:worker/replica:0/task:0'): + remote_function(constant_op.constant([2.0])) + + # Runs second function: + # - on remote device + # - has side effects + # There should be a sync point here and the next function will be executed + # only after the first function has completed. 
+ with ops.device('/job:worker/replica:0/task:2'): + self.assertAllEqual(remote_function2(constant_op.constant([3.0])), [7.0]) + @test_util.eager_lazy_remote_copy_on_and_off def testMultiDeviceFunctionOnRemoteDevice(self): with ops.device('/job:worker/replica:0/task:1'): diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index dd1f049cdcc..fe4a7933a32 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -281,9 +281,9 @@ class TFETensorTest(test_util.TensorFlowTestCase): @test_util.run_gpu_only def testStringTensorOnGPU(self): with ops.device("/device:GPU:0"): - with self.assertRaisesRegexp( - RuntimeError, "Can't copy Tensor with type string to device"): - _create_tensor("test string") + t = _create_tensor("test string") + self.assertIn("CPU", t.device) + self.assertIn("CPU", t.backing_device) def testInvalidUTF8ProducesReasonableError(self): if sys.version_info[0] < 3: diff --git a/tensorflow/python/framework/auto_control_deps.py b/tensorflow/python/framework/auto_control_deps.py index aef685e2390..2b18eb4ede0 100644 --- a/tensorflow/python/framework/auto_control_deps.py +++ b/tensorflow/python/framework/auto_control_deps.py @@ -18,7 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections +import enum + from tensorflow.python.eager import context +from tensorflow.python.framework import auto_control_deps_utils as utils from tensorflow.python.framework import dtypes as dtypes_module from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops @@ -120,6 +124,11 @@ def op_is_stateful(op): op.type in _WHITELIST_STATELESS_OPS) +class ResourceType(enum.Enum): + READ_ONLY = "read-only" + READ_WRITE = "read-write" + + class AutomaticControlDependencies(object): """Context manager to automatically add control dependencies. @@ -197,7 +206,7 @@ class AutomaticControlDependencies(object): return self def _process_switch(self, switch_op, ops_which_must_run, - last_op_using_resource_tensor, merge_for_resource): + last_write_to_resource, merge_for_resource): """Processes a switch node for a resource input. When tensorflow creates a cond, it creates a control flow context for each @@ -227,7 +236,7 @@ class AutomaticControlDependencies(object): Args: switch_op: the switch op to be processed ops_which_must_run: the set of ops which must run - last_op_using_resource_tensor: map from resource tensor to last op using + last_write_to_resource: map from resource tensor to last op updating it merge_for_resource: map from resource tensor to merge which must follow all usages of it. 
@@ -235,8 +244,8 @@ class AutomaticControlDependencies(object): inp = switch_op.inputs[0] input_id = ops.tensor_id(inp) if inp.dtype == dtypes_module.resource and inp.op.type == "Switch": - self._process_switch(inp.op, ops_which_must_run, - last_op_using_resource_tensor, merge_for_resource) + self._process_switch(inp.op, ops_which_must_run, last_write_to_resource, + merge_for_resource) output = switch_op.outputs[0] output_id = ops.tensor_id(output) if output_id in merge_for_resource: @@ -247,11 +256,11 @@ class AutomaticControlDependencies(object): switch_op._control_flow_context.outer_context) # pylint: disable=protected-access # Ensures the merge always runs ops_which_must_run.add(new_merge[0].op) - if input_id in last_op_using_resource_tensor: + if input_id in last_write_to_resource: # Ensures the switch executes after the previous op using the resource. - switch_op._add_control_input(last_op_using_resource_tensor[input_id]) # pylint: disable=protected-access + switch_op._add_control_input(last_write_to_resource[input_id]) # pylint: disable=protected-access # Ensure the next op outside the cond happens after the merge. - last_op_using_resource_tensor[input_id] = new_merge[0].op + last_write_to_resource[input_id] = new_merge[0].op if input_id in merge_for_resource: merge_for_resource[input_id]._add_control_input(new_merge[0].op) # pylint: disable=protected-access for o in switch_op.outputs: @@ -274,8 +283,11 @@ class AutomaticControlDependencies(object): self._graph._add_control_dependencies = False # pylint: enable=protected-access - # map from resource tensor to the last op which used it - last_op_using_resource_tensor = {} + # map from resource tensor to the last op which wrote to it + last_write_to_resource = {} + # map from resource tensor to the list of reads from it since the last + # write or since the beginning of the function. + reads_since_last_write_to_resource = collections.defaultdict(list) # set of conditional and loop exits ops_which_must_run = set() # merge which must depend on ops which use this resource @@ -285,7 +297,7 @@ class AutomaticControlDependencies(object): # Ensures that uses of resource tensors get serialized properly and all # execute. This is done by keeping a map from resource tensor to the last op - # in graph-construction order which used it (last_op_using_resource_tensor). + # in graph-construction order which used it (last_write_to_resource). # # Conditionals are written in TensorFlow such that every external tensor # accessed in the conditional goes through a switch op and every return @@ -317,7 +329,10 @@ class AutomaticControlDependencies(object): continue control_inputs = set() # Ensure stateful ops run - if op_def_registry.get(op.type) is None or op_is_stateful(op): + if (op_def_registry.get(op.type) is None or + (op_is_stateful(op) and op.type not in utils.RESOURCE_READ_OPS)): + # TODO(srbs): Do not add functional ops to `ops_which_must_run` if + # they only have variable reads and are otherwise stateless. 
ops_which_must_run.add(op) # Ignore switches (they're handled separately) if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource: @@ -328,15 +343,16 @@ class AutomaticControlDependencies(object): op._add_control_input(o) # pylint: disable=protected-access for inp in o.inputs: input_id = ops.tensor_id(inp) - if input_id in last_op_using_resource_tensor: - last_op_using_resource_tensor[input_id] = op + if input_id in last_write_to_resource: + last_write_to_resource[input_id] = op ops_which_must_run = set([op]) continue resource_inputs = set() # Check for any resource inputs. If we find any, we update control_inputs - # and last_op_using_resource_tensor. - for inp in _get_resource_inputs(op): + # and last_write_to_resource. + for inp, resource_type in _get_resource_inputs(op): + is_read = resource_type == ResourceType.READ_ONLY input_id = ops.tensor_id(inp) # If the op receives the same resource tensor twice as an input, we skip @@ -348,25 +364,29 @@ class AutomaticControlDependencies(object): # Deal with switches, finally. if inp.op.type == "Switch": self._process_switch(inp.op, ops_which_must_run, - last_op_using_resource_tensor, - merge_for_resource) + last_write_to_resource, merge_for_resource) is_building_function = op.graph.building_function # Ensure uses of resources are serialized - if input_id in last_op_using_resource_tensor: + if input_id in last_write_to_resource: if is_building_function or ( - last_op_using_resource_tensor[input_id]._control_flow_context # pylint: disable=protected-access + last_write_to_resource[input_id]._control_flow_context # pylint: disable=protected-access is op._control_flow_context): # pylint: disable=protected-access - control_inputs.add(last_op_using_resource_tensor[input_id]) + control_inputs.add(last_write_to_resource[input_id]) # Ensure merges happen after the closing of a cond block if input_id in merge_for_resource: merge_for_resource[input_id]._add_control_input(op) # pylint: disable=protected-access - last_op_using_resource_tensor[input_id] = op + if is_read: + reads_since_last_write_to_resource[input_id].append(op) + else: + control_inputs.update(reads_since_last_write_to_resource[input_id]) + reads_since_last_write_to_resource[input_id] = [] + last_write_to_resource[input_id] = op if (op_is_stateful(op) and not resource_inputs and op._control_flow_context is None): # pylint: disable=protected-access - if None in last_op_using_resource_tensor: - op._add_control_input(last_op_using_resource_tensor[None]) # pylint: disable=protected-access - last_op_using_resource_tensor[None] = op + if None in last_write_to_resource: + op._add_control_input(last_write_to_resource[None]) # pylint: disable=protected-access + last_write_to_resource[None] = op control_inputs = [ c for c in control_inputs if is_building_function or (c._control_flow_context is op._control_flow_context)] # pylint: disable=protected-access @@ -384,34 +404,43 @@ class AutomaticControlDependencies(object): ]) -_acd_resource_resolvers_registry = registry.Registry("acd_resouce_resolvers") +_acd_resource_resolvers_registry = registry.Registry("acd_resource_resolvers") def register_acd_resource_resolver(f): """Register a function for resolving resources touched by an op. + `f` is called for every Operation added in the ACD context with the op's + original resource reads and writes. `f` is expected to update the sets of + resource reads and writes in-place and return True if it updated either of the + sets, False otherwise. 
+ Example: @register_acd_resource_resolver - def ResolveIdentity(op, resource_inputs): + def ResolveIdentity(op, resource_reads, resource_writes): # op: The `Operation` being processed by ACD currently. - # resource_inputs: An `ObjectIdentitySet` that can be updated in-place. - if not resource_inputs: + # resource_reads: An `ObjectIdentitySet` of read-only resources. + # resource_writes: An `ObjectIdentitySet` of read-write resources. + if not resource_reads and not resource_writes: return False - to_add = [] - to_remove = [] - for t in resource_inputs: - if t.op.type == "Identity": - to_remove.append(t) - to_add.append(t.op.inputs[0]) - if not to_add and not to_remove: - return False - for t in to_remove: - resource_inputs.discard(t) - resource_inputs.update(to_add) - return True # `resource_inputs` was updated. + def update(resource_inputs): + to_add = [] + to_remove = [] + for t in resource_inputs: + if t.op.type == "Identity": + to_remove.append(t) + to_add.append(t.op.inputs[0]) + if not to_add and not to_remove: + return False + for t in to_remove: + resource_inputs.discard(t) + resource_inputs.update(to_add) + return True + return update(resource_reads) or update(resource_writes) Args: - f: Python function + f: Python function with signature + (Operation, ObjectIdentitySet, ObjectIdentitySet) -> bool Returns: The function `f` after adding it to the registry. @@ -422,8 +451,14 @@ def register_acd_resource_resolver(f): def _get_resource_inputs(op): """Returns an iterable of resources touched by this `op`.""" - resource_inputs = object_identity.ObjectIdentitySet( - t for t in op.inputs if t.dtype == dtypes_module.resource) + reads = object_identity.ObjectIdentitySet() + writes = object_identity.ObjectIdentitySet() + for t in op.inputs: + if t.dtype == dtypes_module.resource: + if utils.op_writes_to_resource(t, op): + writes.add(t) + else: + reads.add(t) saturated = False while not saturated: saturated = True @@ -432,10 +467,16 @@ def _get_resource_inputs(op): # resource_inputs. # TODO(srbs): An alternate would be to just compare the old and new set # but that may not be as fast. - updated = _acd_resource_resolvers_registry.lookup(key)(op, - resource_inputs) + updated = _acd_resource_resolvers_registry.lookup(key)(op, reads, writes) + if updated: + # Conservatively remove any resources from `reads` that are also writes. + reads = reads.difference(writes) saturated = saturated and not updated - return resource_inputs + + # Note: A resource handle that is not written to is treated as read-only. We + # don't have a special way of denoting an unused resource. 
+ return ([(t, ResourceType.READ_ONLY) for t in reads] + + [(t, ResourceType.READ_WRITE) for t in writes]) def automatic_control_dependencies(f): diff --git a/tensorflow/python/framework/auto_control_deps_test.py b/tensorflow/python/framework/auto_control_deps_test.py index 19c0606eb42..a655cadbc10 100644 --- a/tensorflow/python/framework/auto_control_deps_test.py +++ b/tensorflow/python/framework/auto_control_deps_test.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import itertools + from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function @@ -30,6 +32,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.keras.layers import core as keras_core from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -50,6 +53,509 @@ class AutomaticControlDependenciesTest(test.TestCase): val = c.mark_as_return(val) self.assertAllEqual(val.eval(), 4.0) + def testNoControlDepsBetweenVariableReads(self): + with context.graph_mode(), self.cached_session(): + v = resource_variable_ops.ResourceVariable(1.0) + self.evaluate(variables.global_variables_initializer()) + with acd.AutomaticControlDependencies(): + read_op1 = gen_resource_variable_ops.read_variable_op( + v.handle, v.dtype).op + read_op2 = gen_resource_variable_ops.read_variable_op( + v.handle, v.dtype).op + gen_resource_variable_ops.assign_variable_op(v.handle, v + 1) + self.assertNotIn(read_op1, read_op2.control_inputs) + self.assertNotIn(read_op2, read_op1.control_inputs) + + def testVariableReadThenWrite(self): + with context.graph_mode(), self.cached_session(): + v = resource_variable_ops.ResourceVariable(1.0) + self.evaluate(variables.global_variables_initializer()) + with acd.AutomaticControlDependencies(): + read_op1 = gen_resource_variable_ops.read_variable_op( + v.handle, v.dtype).op + read_op2 = gen_resource_variable_ops.read_variable_op( + v.handle, v.dtype).op + assign_op = gen_resource_variable_ops.assign_variable_op( + v.handle, v + 1) + # Writes should have control deps from "all" reads since last write + # or start of the code block. + self.assertIn(read_op1, assign_op.control_inputs) + self.assertIn(read_op2, assign_op.control_inputs) + # There should be no control deps between reads. + self.assertNotIn(read_op1, read_op2.control_inputs) + self.assertNotIn(read_op2, read_op1.control_inputs) + + def testVariableWriteThenRead(self): + with context.graph_mode(), self.cached_session(): + v = resource_variable_ops.ResourceVariable(1.0) + self.evaluate(variables.global_variables_initializer()) + with acd.AutomaticControlDependencies(): + assign_op = gen_resource_variable_ops.assign_variable_op( + v.handle, v + 1) + read_op1 = gen_resource_variable_ops.read_variable_op( + v.handle, v.dtype).op + read_op2 = gen_resource_variable_ops.read_variable_op( + v.handle, v.dtype).op + # Reads should have a control dep from the last write. + self.assertIn(assign_op, read_op1.control_inputs) + self.assertIn(assign_op, read_op2.control_inputs) + # There should be no control deps between reads. 
+ self.assertNotIn(read_op1, read_op2.control_inputs) + self.assertNotIn(read_op2, read_op1.control_inputs) + + def testVariableReadsNotInOpsWithMustRun(self): + with context.graph_mode(), self.cached_session(): + v = resource_variable_ops.ResourceVariable(1.0) + self.evaluate(variables.global_variables_initializer()) + with acd.AutomaticControlDependencies() as c: + read_op1 = gen_resource_variable_ops.read_variable_op( + v.handle, v.dtype).op + read_op2 = gen_resource_variable_ops.read_variable_op( + v.handle, v.dtype).op + assign_op = gen_resource_variable_ops.assign_variable_op( + v.handle, v + 1) + # Reads must not be in `ops_which_must_run` since those get added to the + # `control_outputs`. + self.assertNotIn(read_op1, c.ops_which_must_run) + self.assertNotIn(read_op2, c.ops_which_must_run) + # Last write must be in `ops_which_must_run`. + self.assertIn(assign_op, c.ops_which_must_run) + + def testVariableMultipleReadsAndWrites(self): + with context.graph_mode(), self.cached_session(): + v = resource_variable_ops.ResourceVariable(1.0) + self.evaluate(variables.global_variables_initializer()) + with acd.AutomaticControlDependencies() as c: + # 2 reads -> 2 writes -> 2 reads -> 2 writes. + read_op1 = gen_resource_variable_ops.read_variable_op( + v.handle, v.dtype).op + read_op2 = gen_resource_variable_ops.read_variable_op( + v.handle, v.dtype).op + assign_op1 = gen_resource_variable_ops.assign_variable_op( + v.handle, v + 1) + assign_op2 = gen_resource_variable_ops.assign_variable_op( + v.handle, v + 1) + read_op3 = gen_resource_variable_ops.read_variable_op( + v.handle, v.dtype).op + read_op4 = gen_resource_variable_ops.read_variable_op( + v.handle, v.dtype).op + assign_op3 = gen_resource_variable_ops.assign_variable_op( + v.handle, v + 1) + assign_op4 = gen_resource_variable_ops.assign_variable_op( + v.handle, v + 1) + + # Verify the control edges. + self.assertIn(read_op1, assign_op1.control_inputs) + self.assertIn(read_op2, assign_op1.control_inputs) + self.assertIn(assign_op1, assign_op2.control_inputs) + self.assertIn(assign_op2, read_op3.control_inputs) + self.assertIn(assign_op2, read_op4.control_inputs) + self.assertIn(read_op3, assign_op3.control_inputs) + self.assertIn(read_op4, assign_op3.control_inputs) + self.assertIn(assign_op3, assign_op4.control_inputs) + + # There should be no control deps between reads. + read_ops = [read_op1, read_op2, read_op3, read_op4] + for src_op, tgt_op in itertools.product(read_ops, read_ops): + self.assertNotIn(src_op, tgt_op.control_inputs) + + # Reads must not be in `ops_which_must_run`. + self.assertNotIn(read_op1, c.ops_which_must_run) + self.assertNotIn(read_op2, c.ops_which_must_run) + self.assertNotIn(read_op3, c.ops_which_must_run) + self.assertNotIn(read_op4, c.ops_which_must_run) + # Last write must be in `ops_which_must_run`. 
+ self.assertIn(assign_op4, c.ops_which_must_run) + + def _testVariableReadInFunctionalOp(self, build_functional_op, op_type): + v = resource_variable_ops.ResourceVariable(1.0) + self.evaluate(variables.global_variables_initializer()) + + @def_function.function + def read_var_in_while(): + gen_resource_variable_ops.read_variable_op( + v.handle, v.dtype, name="read1") + + result = build_functional_op(v) + gen_resource_variable_ops.read_variable_op( + v.handle, v.dtype, name="read2") + gen_resource_variable_ops.assign_variable_op(v.handle, v + 1) + return result + + func_graph = read_var_in_while.get_concrete_function().graph + assert len(func_graph.inputs) == 1 + + def get_op(op_type, sub_name): + operations = [ + op for op in func_graph.get_operations() + if op.type == op_type and sub_name in op.name + ] + assert len(operations) == 1 + return operations[0] + + read1 = get_op("ReadVariableOp", "read1") + functional_op = get_op(op_type, "") + read2 = get_op("ReadVariableOp", "read2") + assign_op = get_op("AssignVariableOp", "") + # Since the functional op only has reads, previous reads e.g. read1 do not\ + # have a control edge to it and next future reads e.g. read2 do not have a + # control edge from it. + self.assertNotIn(read1, functional_op.control_inputs) + self.assertNotIn(functional_op, read2.control_inputs) + self.assertIn(read1, assign_op.control_inputs) + self.assertIn(read2, assign_op.control_inputs) + self.assertIn(functional_op, assign_op.control_inputs) + + def testVariableReadInWhileLoop(self): + + def build_functional_op(v): + + def body(_): + return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype) + + return control_flow_ops.while_loop( + lambda i: True, body, [0.0], maximum_iterations=1) + + self._testVariableReadInFunctionalOp(build_functional_op, "While") + + def testVariableReadInCondTrueBranch(self): + + def build_functional_op(v): + + def then_branch(): + return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype) + + def else_branch(): + return array_ops.zeros([], v.dtype) + + return control_flow_ops.cond( + constant_op.constant(True), then_branch, else_branch) + + self._testVariableReadInFunctionalOp(build_functional_op, "If") + + def testVariableReadInCondFalseBranch(self): + + def build_functional_op(v): + + def then_branch(): + return array_ops.zeros([], v.dtype) + + def else_branch(): + return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype) + + return control_flow_ops.cond( + constant_op.constant(False), then_branch, else_branch) + + self._testVariableReadInFunctionalOp(build_functional_op, "If") + + def testVariableReadInCaseBranch0(self): + + def build_functional_op(v): + + def branch0(): + return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype) + + def branch1(): + return array_ops.zeros([], v.dtype) + + return control_flow_ops.switch_case( + constant_op.constant(0), [branch0, branch1]) + + self._testVariableReadInFunctionalOp(build_functional_op, "Case") + + def testVariableReadInCaseBranch1(self): + + def build_functional_op(v): + + def branch0(): + return array_ops.zeros([], v.dtype) + + def branch1(): + return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype) + + return control_flow_ops.switch_case( + constant_op.constant(0), [branch0, branch1]) + + self._testVariableReadInFunctionalOp(build_functional_op, "Case") + + def testVariableReadInFunction(self): + + def build_functional_op(v): + + @def_function.function + def fn_with_read(): + return gen_resource_variable_ops.read_variable_op(v.handle, 
v.dtype) + + return fn_with_read() + + self._testVariableReadInFunctionalOp(build_functional_op, + "StatefulPartitionedCall") + + def testVariableReadInNestedFunction(self): + + def build_functional_op(v): + + @def_function.function + def fn_with_read(): + + @def_function.function + def inner_fn(): + return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype) + + return inner_fn() + + return fn_with_read() + + self._testVariableReadInFunctionalOp(build_functional_op, + "StatefulPartitionedCall") + + def testVariableReadInWhileInInnerFunc(self): + + def build_functional_op(v): + + @def_function.function + def fn_with_read(): + + @def_function.function + def inner_fn(): + + def body(_): + return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype) + + return control_flow_ops.while_loop( + lambda i: True, body, [0.0], maximum_iterations=1) + + return inner_fn() + + return fn_with_read() + + self._testVariableReadInFunctionalOp(build_functional_op, + "StatefulPartitionedCall") + + def testVariableReadInCondInInnerFunc(self): + + def build_functional_op(v): + + @def_function.function + def fn_with_read(): + + @def_function.function + def inner_fn(): + + def then_branch(): + return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype) + + def else_branch(): + return array_ops.zeros([], v.dtype) + + return control_flow_ops.cond( + constant_op.constant(True), then_branch, else_branch) + + return inner_fn() + + return fn_with_read() + + self._testVariableReadInFunctionalOp(build_functional_op, + "StatefulPartitionedCall") + + def _testVariableWriteInFunctionalOp(self, build_functional_op, op_type): + v = resource_variable_ops.ResourceVariable(1.0) + self.evaluate(variables.global_variables_initializer()) + + @def_function.function + def write_var_in_while(): + gen_resource_variable_ops.read_variable_op( + v.handle, v.dtype, name="read1") + + result = build_functional_op(v) + gen_resource_variable_ops.read_variable_op( + v.handle, v.dtype, name="read2") + gen_resource_variable_ops.assign_variable_op(v.handle, v + 1) + return result + + func_graph = write_var_in_while.get_concrete_function().graph + assert len(func_graph.inputs) == 1 + + def get_op(op_type, sub_name): + operations = [ + op for op in func_graph.get_operations() + if op.type == op_type and sub_name in op.name + ] + assert len(operations) == 1 + return operations[0] + + read1 = get_op("ReadVariableOp", "read1") + functional_op = get_op(op_type, "") + read2 = get_op("ReadVariableOp", "read2") + assign_op = get_op("AssignVariableOp", "") + # Since the While has writes, it has control edges from previous reads + # e.g. `read1` and to future reads(`read2`) and writes(`assign_op`). 
+ self.assertIn(read1, functional_op.control_inputs) + self.assertIn(functional_op, read2.control_inputs) + self.assertIn(read2, assign_op.control_inputs) + self.assertIn(functional_op, assign_op.control_inputs) + + def testVariableWriteInWhileLoop(self): + + def build_functional_op(v): + + def body(_): + gen_resource_variable_ops.assign_variable_op(v.handle, v + 1) + return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype) + + return control_flow_ops.while_loop( + lambda i: True, body, [0.0], maximum_iterations=1) + + self._testVariableWriteInFunctionalOp(build_functional_op, "While") + + def testVariableWriteInCondTrueBranch(self): + + def build_functional_op(v): + + def then_branch(): + gen_resource_variable_ops.assign_variable_op(v.handle, v + 1) + return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype) + + def else_branch(): + return array_ops.zeros([], v.dtype) + + return control_flow_ops.cond( + constant_op.constant(True), then_branch, else_branch) + + self._testVariableWriteInFunctionalOp(build_functional_op, "If") + + def testVariableWriteInCondFalseBranch(self): + + def build_functional_op(v): + + def then_branch(): + return array_ops.zeros([], v.dtype) + + def else_branch(): + gen_resource_variable_ops.assign_variable_op(v.handle, v + 1) + return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype) + + return control_flow_ops.cond( + constant_op.constant(False), then_branch, else_branch) + + self._testVariableWriteInFunctionalOp(build_functional_op, "If") + + def testVariableWriteInCaseBranch0(self): + + def build_functional_op(v): + + def branch0(): + gen_resource_variable_ops.assign_variable_op(v.handle, v + 1) + return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype) + + def branch1(): + return array_ops.zeros([], v.dtype) + + return control_flow_ops.switch_case( + constant_op.constant(0), [branch0, branch1]) + + self._testVariableWriteInFunctionalOp(build_functional_op, "Case") + + def testVariableWriteInCaseBranch1(self): + + def build_functional_op(v): + + def branch0(): + return array_ops.zeros([], v.dtype) + + def branch1(): + gen_resource_variable_ops.assign_variable_op(v.handle, v + 1) + return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype) + + return control_flow_ops.switch_case( + constant_op.constant(0), [branch0, branch1]) + + self._testVariableWriteInFunctionalOp(build_functional_op, "Case") + + def testVariableWriteInFunction(self): + + def build_functional_op(v): + + @def_function.function + def fn_with_write(): + gen_resource_variable_ops.assign_variable_op(v.handle, v + 1) + return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype) + + return fn_with_write() + + self._testVariableWriteInFunctionalOp(build_functional_op, + "StatefulPartitionedCall") + + def testVariableWriteInNestedFunction(self): + + def build_functional_op(v): + + @def_function.function + def fn_with_write(): + + @def_function.function + def inner_fn(): + gen_resource_variable_ops.assign_variable_op(v.handle, v + 1) + return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype) + + return inner_fn() + + return fn_with_write() + + self._testVariableWriteInFunctionalOp(build_functional_op, + "StatefulPartitionedCall") + + def testVariableWriteInWhileInInnerFunc(self): + + def build_functional_op(v): + + @def_function.function + def fn_with_write(): + + @def_function.function + def inner_fn(): + + def body(_): + gen_resource_variable_ops.assign_variable_op(v.handle, v + 1) + return 
gen_resource_variable_ops.read_variable_op(v.handle, v.dtype) + + return control_flow_ops.while_loop( + lambda i: True, body, [0.0], maximum_iterations=1) + + return inner_fn() + + return fn_with_write() + + self._testVariableWriteInFunctionalOp(build_functional_op, + "StatefulPartitionedCall") + + def testVariableWriteInCondInInnerFunc(self): + + def build_functional_op(v): + + @def_function.function + def fn_with_write(): + + @def_function.function + def inner_fn(): + + def then_branch(): + gen_resource_variable_ops.assign_variable_op(v.handle, v + 1) + return gen_resource_variable_ops.read_variable_op(v.handle, v.dtype) + + def else_branch(): + return array_ops.zeros([], v.dtype) + + return control_flow_ops.cond( + constant_op.constant(True), then_branch, else_branch) + + return inner_fn() + + return fn_with_write() + + self._testVariableWriteInFunctionalOp(build_functional_op, + "StatefulPartitionedCall") + @test_util.run_v1_only("b/120545219") def testCondMustRun(self): with context.graph_mode(), self.cached_session(): diff --git a/tensorflow/python/framework/auto_control_deps_utils.py b/tensorflow/python/framework/auto_control_deps_utils.py new file mode 100644 index 00000000000..d8b4656e3d9 --- /dev/null +++ b/tensorflow/python/framework/auto_control_deps_utils.py @@ -0,0 +1,91 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for AutomaticControlDependencies.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes + +READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs" +RESOURCE_READ_OPS = set() + + +def register_read_only_resource_op(op_type): + """Declares that `op_type` does not update its touched resource.""" + RESOURCE_READ_OPS.add(op_type) + + +def resource_has_writes(handle): + """Returns whether any of the consumers of handle write to it. + + Args: + handle: Tensor of type DT_RESOURCE. + + Returns: + Returns True if at least one consumer of `handle` writes to it. + Returns False if all consumers of `handle` do not write to it or if the + `handle` has no consumers. + """ + assert handle.dtype == dtypes.resource + return any(op_writes_to_resource(handle, op) for op in handle.consumers()) + + +def op_writes_to_resource(handle, op): + """Returns whether op writes to resource handle. + + Args: + handle: Resource handle. Must be an input of `op`. + op: Operation. + + Returns: + Returns False if op is a read-only op registered using + `register_read_only_resource_op` or if `handle` is an input at one of + the indices in the `READ_ONLY_RESOURCE_INPUTS_ATTR` attr of the op, True + otherwise. + + Raises: + ValueError: if `handle` is not an input of `op`. 
+ """ + if op.type in RESOURCE_READ_OPS: + return False + input_index = _input_index(op, handle) + try: + read_only_input_indices = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR) + except ValueError: + # Attr was not set. Conservatively assume that the resource is written to. + return True + return input_index not in read_only_input_indices + + +def _input_index(op, handle): + """Returns the index of `handle` in `op.inputs`. + + Args: + op: Operation. + handle: Resource handle. + + Returns: + Index in `op.inputs` receiving the resource `handle`. + + Raises: + ValueError: If handle and its replicated input are both not found in + `op.inputs`. + """ + for i, t in enumerate(op.inputs): + if handle is t: + return i + raise ValueError("%s not in list" % str(handle)) diff --git a/tensorflow/python/framework/config_test.py b/tensorflow/python/framework/config_test.py index 72612a21cbf..2ef7d737d73 100644 --- a/tensorflow/python/framework/config_test.py +++ b/tensorflow/python/framework/config_test.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -380,10 +381,10 @@ class DeviceTest(test.TestCase): with ops.device('/device:CPU:1'): b = constant_op.constant(1.0) self.evaluate(b) - with self.assertRaisesRegexp(RuntimeError, 'unknown device'): - with ops.device('/device:CPU:2'): - c = constant_op.constant(1.0) - self.evaluate(c) + with ops.device('/device:CPU:2'): + c = constant_op.constant(1.0) + self.evaluate(c) + self.assertIn('CPU:0', c.device) # Ensure we can place ops on each of the device names for vcpu in vcpus: @@ -408,6 +409,7 @@ class DeviceTest(test.TestCase): @test_util.run_gpu_only @reset_eager def testGpuNone(self): + config.set_soft_device_placement(False) gpus = config.list_physical_devices('GPU') self.assertGreater(len(gpus), 0) @@ -427,14 +429,16 @@ class DeviceTest(test.TestCase): self.assertEqual(len(config.get_visible_devices('GPU')), 0) self.assertEqual(len(config.list_logical_devices('XLA_GPU')), 0) - with self.assertRaisesRegexp(RuntimeError, 'unknown device'): + with self.assertRaisesRegexp(errors.InvalidArgumentError, + 'Could not satisfy'): with ops.device('/device:GPU:0'): - a = constant_op.constant(1.0) + a = array_ops.identity(1.0) self.evaluate(a) - with self.assertRaisesRegexp(RuntimeError, 'unknown device'): + with self.assertRaisesRegexp(errors.InvalidArgumentError, + 'Could not satisfy'): with ops.device('/device:XLA_GPU:0'): - a = constant_op.constant(1.0) + a = array_ops.identity(1.0) self.evaluate(a) # Modifying the visible devices is not supported @@ -465,6 +469,7 @@ class DeviceTest(test.TestCase): @test_util.run_gpu_only @reset_eager def testVirtualGpu(self): + config.set_soft_device_placement(False) gpus = config.list_physical_devices('GPU') self.assertNotEqual(len(gpus), 0) @@ -479,12 +484,13 @@ class DeviceTest(test.TestCase): self.assertTrue(len(logical_gpus), len(gpus) + 1) for i in range(0, len(logical_gpus)): with ops.device('/device:GPU:' + str(i)): - a = constant_op.constant(1.0) + a = array_ops.identity(1.0) self.evaluate(a) - with self.assertRaisesRegexp(RuntimeError, 'unknown device'): + with self.assertRaisesRegexp(errors.InvalidArgumentError, + 'Could not satisfy'): with ops.device('/device:GPU:' + 
str(len(logical_gpus))): - a = constant_op.constant(1.0) + a = array_ops.identity(1.0) self.evaluate(a) # Modifying the GPU configuration is not supported diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index 4d9aa29ad60..9736bb8b78b 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -224,6 +224,10 @@ def constant(value, dtype=None, shape=None, name="Const"): ... NotImplementedError: ... + `tf.constant` will _always_ create CPU (host) tensors. In order to create + tensors on other devices, use `tf.identity`. (If the `value` is an eager + Tensor, however, the tensor will be returned unmodified as mentioned above.) + Related Ops: * `tf.convert_to_tensor` is similar but: diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index f716dfa33dd..d3df8cb973d 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -3171,6 +3171,8 @@ class Graph(object): Raises: ValueError: if another function is defined with the same name. """ + self._check_not_finalized() + name = function.name # Sanity checks on gradient definition. if (function.grad_func_name is not None) and (function.python_grad_func is @@ -3455,6 +3457,8 @@ class Graph(object): Returns: A list of the new `Operation` objects. """ + self._check_not_finalized() + # Create all Operation objects before accessing their inputs since an op may # be created before its inputs. new_ops = [ @@ -6738,3 +6742,9 @@ class _TensorIterator(object): return result next = __next__ # python2.x compatibility. + + +def set_int_list_attr(op, attr_name, ints): + """TF internal method used to set a list(int) attribute in the node_def.""" + ints_list = attr_value_pb2.AttrValue.ListValue(i=ints) + op._set_attr(attr_name, attr_value_pb2.AttrValue(list=ints_list)) # pylint:disable=protected-access diff --git a/tensorflow/python/keras/applications/inception_resnet_v2.py b/tensorflow/python/keras/applications/inception_resnet_v2.py index ab8ab71e3b0..7408b314bb6 100644 --- a/tensorflow/python/keras/applications/inception_resnet_v2.py +++ b/tensorflow/python/keras/applications/inception_resnet_v2.py @@ -370,9 +370,34 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'): @keras_export('keras.applications.inception_resnet_v2.preprocess_input') def preprocess_input(x, data_format=None): + """Preprocesses a numpy array encoding a batch of images. + + Arguments + x: A 4D numpy array consists of RGB values within [0, 255]. + + Returns + Preprocessed array. + + Raises + ValueError: In case of unknown `data_format` argument. + """ return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf') @keras_export('keras.applications.inception_resnet_v2.decode_predictions') def decode_predictions(preds, top=5): + """Decodes the prediction result from the model. + + Arguments + preds: Numpy tensor encoding a batch of predictions. + top: Integer, how many top-guesses to return. + + Returns + A list of lists of top class prediction tuples + `(class_name, class_description, score)`. + One list of tuples per sample in batch input. + + Raises + ValueError: In case of invalid shape of the `preds` array (must be 2D). 
+ """ return imagenet_utils.decode_predictions(preds, top=top) diff --git a/tensorflow/python/keras/applications/mobilenet.py b/tensorflow/python/keras/applications/mobilenet.py index 224e8c84496..128553f0d39 100644 --- a/tensorflow/python/keras/applications/mobilenet.py +++ b/tensorflow/python/keras/applications/mobilenet.py @@ -436,9 +436,34 @@ def _depthwise_conv_block(inputs, @keras_export('keras.applications.mobilenet.preprocess_input') def preprocess_input(x, data_format=None): + """Preprocesses a numpy array encoding a batch of images. + + Arguments + x: A 4D numpy array consists of RGB values within [0, 255]. + + Returns + Preprocessed array. + + Raises + ValueError: In case of unknown `data_format` argument. + """ return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf') @keras_export('keras.applications.mobilenet.decode_predictions') def decode_predictions(preds, top=5): + """Decodes the prediction result from the model. + + Arguments + preds: Numpy tensor encoding a batch of predictions. + top: Integer, how many top-guesses to return. + + Returns + A list of lists of top class prediction tuples + `(class_name, class_description, score)`. + One list of tuples per sample in batch input. + + Raises + ValueError: In case of invalid shape of the `preds` array (must be 2D). + """ return imagenet_utils.decode_predictions(preds, top=top) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 8651cf27375..aea77ca2e47 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -894,9 +894,12 @@ class History(Callback): gets returned by the `fit` method of models. """ + def __init__(self): + super(History, self).__init__() + self.history = {} + def on_train_begin(self, logs=None): self.epoch = [] - self.history = {} def on_epoch_end(self, epoch, logs=None): logs = logs or {} @@ -904,6 +907,10 @@ class History(Callback): for k, v in logs.items(): self.history.setdefault(k, []).append(v) + # Set the history attribute on the model after the epoch ends. This will + # make sure that the state which is set is the latest one. + self.model.history = self + @keras_export('keras.callbacks.ModelCheckpoint') class ModelCheckpoint(Callback): @@ -1672,6 +1679,8 @@ class TensorBoard(Callback): self._writers = {} self._start_batch, self._stop_batch = self._init_profile_batch( profile_batch) + if self._start_batch > 0: + profiler.warmup() # Improve the profiling accuracy. # True when a trace is running. 
self._is_tracing = False diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index e13ab8f0b92..148df242e48 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -51,6 +51,7 @@ from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import tf_logging as logging @@ -238,6 +239,9 @@ class Network(base_layer.Layer): outputs = outputs[0] self._nested_outputs = outputs self._nested_inputs = inputs + self._nested_inputs_are_flat_list = ( + isinstance(self._nested_inputs, (list, tuple)) and + not any(nest.is_sequence(t) for t in self._nested_inputs)) self.inputs = nest.flatten(inputs) self.outputs = nest.flatten(outputs) @@ -814,40 +818,18 @@ class Network(base_layer.Layer): # use masking, it does not interfere with regular behavior at all and you # can ignore it. - if isinstance(inputs, dict) and isinstance(self._nested_inputs, - (list, tuple)): - # Backwards compat: Allows passing a dict to a Model constructed with a - # list. Matches dict keys to input names. - inputs = [ - inputs[inp._keras_history.layer.name] for inp in self._nested_inputs - ] - else: - inputs = nest.flatten(inputs) - + inputs = self._flatten_to_reference_inputs(inputs) if mask is None: masks = [None for _ in range(len(inputs))] else: - masks = nest.flatten(mask) - + masks = self._flatten_to_reference_inputs(mask) for input_t, mask in zip(inputs, masks): input_t._keras_mask = mask # Dictionary mapping reference tensors to computed tensors. tensor_dict = {} - for x, y in zip(self.inputs, inputs): - # Set shape and dtype based on `keras.Input`s. - if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor): - try: - y.set_shape(y.shape.merge_with(x.shape)) - except ValueError: - logging.warning( - 'Model was constructed with shape {} for input {}, but it was ' - 're-called on a Tensor with incompatible shape {}.' - .format(x, x.shape, y.shape)) - if isinstance(x, (ops.Tensor, composite_tensor.CompositeTensor)): - y = math_ops.cast(y, dtype=x.dtype) - + y = self._conform_to_reference_input(y, ref_input=x) x_id = str(id(x)) tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id] @@ -925,6 +907,53 @@ class Network(base_layer.Layer): output_tensors = nest.pack_sequence_as(self._nested_outputs, output_tensors) return output_tensors + def _flatten_to_reference_inputs(self, tensors): + """Maps `tensors` to their respective `keras.Input`.""" + if self._nested_inputs_are_flat_list and isinstance(tensors, dict): + # Backwards compat: Allows passing a dict to a Model constructed with a + # list. Matches dict keys to input names. + tensors = [ + tensors[inp._keras_history.layer.name] for inp in self._nested_inputs + ] + else: + # Otherwise both self.inputs and tensors will be flattened in same order. + tensors = nest.flatten(tensors) + return tensors + + def _conform_to_reference_input(self, tensor, ref_input): + """Set shape and dtype based on `keras.Input`s.""" + # Shape handling (only for non-CompositeTensors). + if isinstance(tensor, ops.Tensor) and isinstance(ref_input, ops.Tensor): + # Allow (None,) and (None, 1) Tensors to be passed interchangably. 
Use the + # shape specified by the `keras.Input`. + if tensor.shape.rank is not None and ref_input.shape.rank is not None: + should_squeeze_last_dim = ( + tensor.shape.rank == ref_input.shape.rank + 1 and + tensor.shape[-1] == 1) + should_expand_last_dim = ( + tensor.shape.rank == ref_input.shape.rank - 1 and + ref_input.shape[-1] == 1) + if should_squeeze_last_dim: + tensor = array_ops.squeeze_v2(tensor, axis=-1) + elif should_expand_last_dim: + tensor = array_ops.expand_dims_v2(tensor, axis=-1) + + # Add shape hints to Tensors that might have None shape dims but have + # shapes defined by the `keras.Input`. + try: + tensor.set_shape(tensor.shape.merge_with(ref_input.shape)) + except ValueError: + logging.warning( + 'Model was constructed with shape {} for input {}, but it was ' + 'called on an input with incompatible shape {}.'.format( + ref_input.shape, ref_input, tensor.shape)) + + # Dtype handling. + if isinstance(ref_input, (ops.Tensor, composite_tensor.CompositeTensor)): + tensor = math_ops.cast(tensor, dtype=ref_input.dtype) + + return tensor + def get_config(self): if not self._is_graph_network: raise NotImplementedError diff --git a/tensorflow/python/keras/engine/network_test.py b/tensorflow/python/keras/engine/network_test.py index 17f08889936..d890cc118ae 100644 --- a/tensorflow/python/keras/engine/network_test.py +++ b/tensorflow/python/keras/engine/network_test.py @@ -1879,6 +1879,24 @@ class CacheCorrectnessTest(keras_parameterized.TestCase): for i in range(999, 1024): self.assertEqual(network.compute_output_shape((1, i, 32)), (1, i, 2)) + def test_2d_inputs_squeezed_to_1d(self): + input_1d = input_layer_lib.Input(shape=()) + outputs = input_1d * 2. + net = network_lib.Network(input_1d, outputs) + + x = np.ones((10, 1)) + y = net(x) + self.assertEqual(y.shape.rank, 1) + + def test_1d_inputs_expanded_to_2d(self): + input_1d = input_layer_lib.Input(shape=(1,)) + outputs = input_1d * 2. + net = network_lib.Network(input_1d, outputs) + + x = np.ones((10,)) + y = net(x) + self.assertEqual(y.shape.rank, 2) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index f9ec6f37b45..c487ba6a322 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -174,6 +174,7 @@ class Model(network.Network, version_utils.ModelVersionSelector): # Fault-tolerance handler. Set in `ModelCheckpoint`. self._training_state = None + self.history = None def get_weights(self): """Retrieves the weights of the model. @@ -768,7 +769,11 @@ class Model(network.Network, version_utils.ModelVersionSelector): step_num=step, batch_size=batch_size): callbacks.on_train_batch_begin(step) - logs = train_function(iterator) + tmp_logs = train_function(iterator) + # Catch possible OutOfRangeError here. + # TODO(b/150292341): Allow multiple async steps. + context.async_wait() + logs = tmp_logs callbacks.on_train_batch_end(step, logs) epoch_logs = {m.name: m.result() for m in self.metrics} @@ -995,7 +1000,9 @@ class Model(network.Network, version_utils.ModelVersionSelector): graph_type='test', step_num=step): callbacks.on_test_batch_begin(step) - logs = test_function(iterator) + tmp_logs = test_function(iterator) + context.async_wait() # Possible OutOfRangeError here. 
+ logs = tmp_logs callbacks.on_test_batch_end(step, logs) callbacks.on_test_end() @@ -1175,7 +1182,9 @@ class Model(network.Network, version_utils.ModelVersionSelector): with data_handler.catch_stop_iteration(): for step in data_handler.steps(): callbacks.on_predict_batch_begin(step) - batch_outputs = predict_function(iterator) + tmp_batch_outputs = predict_function(iterator) + context.async_wait() # Possible OutOfRangeError here. + batch_outputs = tmp_batch_outputs if outputs is None: outputs = nest.map_structure(lambda batch_output: [batch_output], batch_outputs) diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py index 8e0590b9f80..ca4734ef8cc 100644 --- a/tensorflow/python/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -675,8 +675,10 @@ class BatchNormalizationBase(Layer): training = bool(training) if base_layer_utils.is_in_keras_graph(): training = math_ops.logical_and(training, self._get_trainable_var()) - else: - training = math_ops.logical_and(training, self.trainable) + elif not self.trainable: + # When the layer is not trainable, it overrides the value passed from + # model. + training = self.trainable return training def call(self, inputs, training=None): diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index 1c851581a05..e5610309bec 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec from tensorflow.python.keras import backend as K from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import base_layer_utils @@ -48,6 +49,7 @@ from tensorflow.python.keras.losses import mean_squared_logarithmic_error from tensorflow.python.keras.losses import poisson from tensorflow.python.keras.losses import sparse_categorical_crossentropy from tensorflow.python.keras.losses import squared_hinge +from tensorflow.python.keras.saving.saved_model import metric_serialization from tensorflow.python.keras.utils import metrics_utils from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras.utils.generic_utils import serialize_keras_object @@ -63,7 +65,9 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import variables as tf_variables from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.ops.losses import util as tf_losses_utils +from tensorflow.python.training.tracking import base as trackable from tensorflow.python.util import nest +from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import keras_export from tensorflow.tools.docs import doc_controls @@ -154,10 +158,13 @@ class Metric(base_layer.Layer): # custom metrics in v1 need not worry about control dependencies and # return ops. 
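The `normalization.py` change above makes a frozen `BatchNormalization` layer ignore `training=True`. A short sketch of the observable behavior under eager execution (illustrative only, not part of the patch):

```python
import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
x = np.random.rand(4, 3).astype('float32')
bn(x)  # build the layer and its moving statistics

# With trainable=False the layer overrides the training argument: it
# normalizes with the moving mean/variance and does not update them.
bn.trainable = False
frozen_out = bn(x, training=True)

# With trainable=True and training=True it uses batch statistics and updates
# the moving averages as usual.
bn.trainable = True
training_out = bn(x, training=True)
```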
if (base_layer_utils.is_in_eager_or_tf_function() or - cls.__module__ == Metric.__module__): + is_built_in(cls)): update_state_fn = obj.update_state else: - update_state_fn = def_function.function(obj.update_state) + if isinstance(obj.update_state, def_function.Function): + update_state_fn = obj.update_state + else: + update_state_fn = def_function.function(obj.update_state) obj.update_state = types.MethodType( metrics_utils.update_state_wrapper(update_state_fn), obj) @@ -278,6 +285,10 @@ class Metric(base_layer.Layer): ### End: For use by subclasses ### + @property + def _trackable_saved_model_saver(self): + return metric_serialization.MetricSavedModelSaver(self) + class Reduce(Metric): """Encapsulates metrics that perform a reduce operation on the values.""" @@ -595,11 +606,27 @@ class MeanMetricWrapper(Mean): def get_config(self): config = {} + + if type(self) is MeanMetricWrapper: # pylint: disable=unidiomatic-typecheck + # Only include function argument when the object is a MeanMetricWrapper + # and not a subclass. + config['fn'] = self._fn + for k, v in six.iteritems(self._fn_kwargs): config[k] = K.eval(v) if is_tensor_or_variable(v) else v base_config = super(MeanMetricWrapper, self).get_config() return dict(list(base_config.items()) + list(config.items())) + @classmethod + def from_config(cls, config): + # Note that while MeanMetricWrapper itself isn't public, objects of this + # class may be created and added to the model by calling model.compile. + if cls is MeanMetricWrapper: + fn = get(config.pop('fn')) + return cls(fn, **config) + + return super(MeanMetricWrapper, cls).from_config(config) + @keras_export('keras.metrics.Accuracy') class Accuracy(MeanMetricWrapper): @@ -1883,7 +1910,7 @@ class AUC(Metric): else: variable_shape = tensor_shape.TensorShape( [tensor_shape.Dimension(self.num_thresholds)]) - + self._build_input_shape = shape # Create metric variables self.true_positives = self.add_weight( 'true_positives', @@ -1927,7 +1954,7 @@ class AUC(Metric): """ deps = [] if not self._built: - self._build(y_pred.shape) + self._build(tensor_shape.TensorShape(y_pred.shape)) if self.multi_label or (self.label_weights is not None): # y_true should have shape (number of examples, number of labels). @@ -2735,6 +2762,7 @@ class MeanTensor(Metric): def _build(self, shape): self._shape = tensor_shape.TensorShape(shape) + self._build_input_shape = self._shape # Create new state variables self._total = self.add_weight( 'total', shape=shape, initializer=init_ops.zeros_initializer) @@ -3251,3 +3279,7 @@ def get(identifier): error_msg = 'Could not interpret metric function identifier: {}'.format( identifier) raise ValueError(error_msg) + + +def is_built_in(cls): + return cls.__module__ == Metric.__module__ diff --git a/tensorflow/python/keras/preprocessing/__init__.py b/tensorflow/python/keras/preprocessing/__init__.py index 58b670d0b0e..1542dab6d35 100644 --- a/tensorflow/python/keras/preprocessing/__init__.py +++ b/tensorflow/python/keras/preprocessing/__init__.py @@ -23,12 +23,14 @@ from __future__ import print_function import keras_preprocessing from tensorflow.python.keras import backend +from tensorflow.python.keras.preprocessing import image +from tensorflow.python.keras.preprocessing import sequence +from tensorflow.python.keras.preprocessing import text from tensorflow.python.keras.utils import all_utils as utils # This exists for compatibility with prior version of keras_preprocessing. 
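A sketch of the `MeanMetricWrapper` round trip enabled by the `get_config`/`from_config` changes in `metrics.py` above: for the wrapper class itself (not subclasses), the config now carries the wrapped function, so an equivalent metric can be rebuilt. The import path is the internal module touched by this patch, and eager execution is assumed:

```python
from tensorflow.python.keras import metrics

# compile() wraps plain metric functions (e.g. metrics=['mae']) in
# MeanMetricWrapper; constructing one directly shows the same round trip.
m = metrics.MeanMetricWrapper(metrics.mean_absolute_error, name='mae')

config = m.get_config()  # includes 'fn' when the object is a MeanMetricWrapper
restored = metrics.MeanMetricWrapper.from_config(config)

restored.update_state([[0., 1.]], [[1., 1.]])
print(restored.name, float(restored.result()))  # mae 0.5
```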
keras_preprocessing.set_keras_submodules(backend=backend, utils=utils) - del absolute_import del division del print_function diff --git a/tensorflow/python/keras/preprocessing/sequence.py b/tensorflow/python/keras/preprocessing/sequence.py index 41a3ab6c1be..5ba2e2b47d5 100644 --- a/tensorflow/python/keras/preprocessing/sequence.py +++ b/tensorflow/python/keras/preprocessing/sequence.py @@ -24,7 +24,6 @@ from keras_preprocessing import sequence from tensorflow.python.keras.utils import data_utils from tensorflow.python.util.tf_export import keras_export -pad_sequences = sequence.pad_sequences make_sampling_table = sequence.make_sampling_table skipgrams = sequence.skipgrams # TODO(fchollet): consider making `_remove_long_seq` public. @@ -34,6 +33,7 @@ _remove_long_seq = sequence._remove_long_seq # pylint: disable=protected-access @keras_export('keras.preprocessing.sequence.TimeseriesGenerator') class TimeseriesGenerator(sequence.TimeseriesGenerator, data_utils.Sequence): """Utility class for generating batches of temporal data. + This class takes in a sequence of data-points gathered at equal intervals, along with time series parameters such as stride, length of history, etc., to produce batches for @@ -89,7 +89,74 @@ class TimeseriesGenerator(sequence.TimeseriesGenerator, data_utils.Sequence): pass -keras_export('keras.preprocessing.sequence.pad_sequences')(pad_sequences) +@keras_export('keras.preprocessing.sequence.pad_sequences') +def pad_sequences(sequences, maxlen=None, dtype='int32', + padding='pre', truncating='pre', value=0.): + """Pads sequences to the same length. + + This function transforms a list (of length `num_samples`) + of sequences (lists of integers) + into a 2D Numpy array of shape `(num_samples, num_timesteps)`. + `num_timesteps` is either the `maxlen` argument if provided, + or the length of the longest sequence in the list. + + Sequences that are shorter than `num_timesteps` + are padded with `value` until they are `num_timesteps` long. + + Sequences longer than `num_timesteps` are truncated + so that they fit the desired length. + + The position where padding or truncation happens is determined by + the arguments `padding` and `truncating`, respectively. + Pre-padding or removing values from the beginning of the sequence is the + default. + + >>> sequence = [[1], [2, 3], [4, 5, 6]] + >>> tf.keras.preprocessing.sequence.pad_sequences(sequence) + array([[0, 0, 1], + [0, 2, 3], + [4, 5, 6]], dtype=int32) + + >>> tf.keras.preprocessing.sequence.pad_sequences(sequence, value=-1) + array([[-1, -1, 1], + [-1, 2, 3], + [ 4, 5, 6]], dtype=int32) + + >>> tf.keras.preprocessing.sequence.pad_sequences(sequence, padding='post') + array([[1, 0, 0], + [2, 3, 0], + [4, 5, 6]], dtype=int32) + + >>> tf.keras.preprocessing.sequence.pad_sequences(sequence, maxlen=2) + array([[0, 1], + [2, 3], + [5, 6]], dtype=int32) + + Arguments: + sequences: List of sequences (each sequence is a list of integers). + maxlen: Optional Int, maximum length of all sequences. If not provided, + sequences will be padded to the length of the longest individual + sequence. + dtype: (Optional, defaults to int32). Type of the output sequences. + To pad sequences with variable length strings, you can use `object`. + padding: String, 'pre' or 'post' (optional, defaults to 'pre'): + pad either before or after each sequence. + truncating: String, 'pre' or 'post' (optional, defaults to 'pre'): + remove values from sequences larger than + `maxlen`, either at the beginning or at the end of the sequences. 
+ value: Float or String, padding value. (Optional, defaults to 0.) + + Returns: + Numpy array with shape `(len(sequences), maxlen)` + + Raises: + ValueError: In case of invalid values for `truncating` or `padding`, + or in case of invalid shape for a `sequences` entry. + """ + return sequence.pad_sequences( + sequences, maxlen=maxlen, dtype=dtype, + padding=padding, truncating=truncating, value=value) + keras_export( 'keras.preprocessing.sequence.make_sampling_table')(make_sampling_table) keras_export('keras.preprocessing.sequence.skipgrams')(skipgrams) diff --git a/tensorflow/python/keras/saving/BUILD b/tensorflow/python/keras/saving/BUILD index 7ab6639d118..eda5776c7a6 100644 --- a/tensorflow/python/keras/saving/BUILD +++ b/tensorflow/python/keras/saving/BUILD @@ -22,6 +22,7 @@ py_library( "saved_model/json_utils.py", "saved_model/layer_serialization.py", "saved_model/load.py", + "saved_model/metric_serialization.py", "saved_model/model_serialization.py", "saved_model/network_serialization.py", "saved_model/save.py", diff --git a/tensorflow/python/keras/saving/saved_model/layer_serialization.py b/tensorflow/python/keras/saving/saved_model/layer_serialization.py index 6dffcc65c7e..66a70e55c3b 100644 --- a/tensorflow/python/keras/saving/saved_model/layer_serialization.py +++ b/tensorflow/python/keras/saving/saved_model/layer_serialization.py @@ -52,13 +52,7 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver): dtype=policy.serialize(self.obj._dtype_policy), # pylint: disable=protected-access batch_input_shape=getattr(self.obj, '_batch_input_shape', None)) - with generic_utils.skip_failed_serialization(): - # Store the config dictionary, which may be used when reviving the object. - # When loading, the program will attempt to revive the object from config, - # and if that fails, the object will be revived from the SavedModel. - config = generic_utils.serialize_keras_object(self.obj)['config'] - if config is not None: - metadata['config'] = config + metadata.update(get_config(self.obj)) if self.obj.input_spec is not None: # Layer's input_spec has already been type-checked in the property setter. metadata['input_spec'] = nest.map_structure( @@ -109,6 +103,20 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver): return objects, functions +# TODO(kathywu): Move serialization utils (and related utils from +# generic_utils.py) to a separate file. +def get_config(obj): + with generic_utils.skip_failed_serialization(): + # Store the config dictionary, which may be used when reviving the object. + # When loading, the program will attempt to revive the object from config, + # and if that fails, the object will be revived from the SavedModel. 
+ config = generic_utils.serialize_keras_object(obj)['config'] + + if config is not None: + return {'config': config} + return {} + + class InputLayerSavedModelSaver(base_serialization.SavedModelSaver): """InputLayer serialization.""" diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py index af511e6586a..fddd6c46759 100644 --- a/tensorflow/python/keras/saving/saved_model/load.py +++ b/tensorflow/python/keras/saving/saved_model/load.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import re +import types from tensorflow.python.eager import context from tensorflow.python.eager import function as defun @@ -32,6 +33,7 @@ from tensorflow.python.keras.saving.saved_model import json_utils from tensorflow.python.keras.saving.saved_model import utils from tensorflow.python.keras.saving.saved_model.serialized_attributes import CommonEndpoints from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.keras.utils import metrics_utils from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import load as tf_load from tensorflow.python.saved_model import nested_structure_coder @@ -69,6 +71,8 @@ training_lib = LazyLoader( training_lib_v1 = LazyLoader( "training_lib_v1", globals(), "tensorflow.python.keras.engine.training_v1") +metrics = LazyLoader("metrics", globals(), + "tensorflow.python.keras.metrics") # pylint:enable=g-inconsistent-quotes @@ -77,9 +81,9 @@ PUBLIC_ATTRIBUTES = CommonEndpoints.all_functions.union( PUBLIC_ATTRIBUTES.add(constants.KERAS_ATTR) -KERAS_OBJECT_IDENTIFIERS = ( - '_tf_keras_layer', '_tf_keras_input_layer', '_tf_keras_network', - '_tf_keras_model', '_tf_keras_sequential') +KERAS_OBJECT_IDENTIFIERS = ('_tf_keras_layer', '_tf_keras_input_layer', + '_tf_keras_network', '_tf_keras_model', + '_tf_keras_sequential', '_tf_keras_metric') def load(path, compile=True): # pylint: disable=redefined-builtin @@ -179,6 +183,20 @@ class KerasObjectLoader(tf_load.Loader): super(KerasObjectLoader, self).__init__(*args, **kwargs) + # Now that the node object has been fully loaded, and the checkpoint has + # been restored, the object no longer needs to track objects added from + # SerializedAttributes. (Note that saving a training checkpoint still + # functions correctly, because layers and variables are tracked separately + # by the Layer object.) + # TODO(kathywu): Instead of outright deleting these nodes (which would + # make restoring from a different checkpoint tricky), mark them as extra + # dependencies that are OK to overwrite. + for node in self._nodes: + if not isinstance(node, base_layer.Layer): + continue + for name in PUBLIC_ATTRIBUTES: + delete_tracking(node, name) + def _load_all(self): """Reconstruct the object graph from the SavedModel.""" # Load layer and model objects from either config or SavedModel. The objects @@ -192,19 +210,6 @@ class KerasObjectLoader(tf_load.Loader): # Finish setting up layers and models. See function docstring for more info. self._finalize_objects() - # Now that the node object has been fully loaded, the object no longer needs - # to track objects added from SerializedAttributes. (Note that saving a - # training checkpoint still functions correctly, because layers and - # variables are tracked separately by the Layer object.) 
- # TODO(kathywu): Instead of outright deleting these nodes (which would - # make restoring from a different checkpoint tricky), mark them as extra - # dependencies that are OK to overwrite. - for node in self._nodes: - if not isinstance(node, base_layer.Layer): - continue - for name in PUBLIC_ATTRIBUTES: - delete_tracking(node, name) - @property def _expect_partial_checkpoint(self): return True @@ -230,10 +235,30 @@ class KerasObjectLoader(tf_load.Loader): return self._traversed_nodes_from_config.append(node_id) obj._maybe_initialize_trackable() + if isinstance(obj, base_layer.Layer) and not obj.built: + metadata = json_utils.decode(proto.user_object.metadata) + self._try_build_layer(obj, node_id, metadata.get('build_input_shape')) + # Create list of all possible children + children = [] + # Look for direct children for reference in proto.children: obj_child = obj._lookup_dependency(reference.local_name) - child_id = reference.node_id + children.append((obj_child, reference.node_id)) + + # Add metrics that may have been added to the layer._metrics list. + # This is stored in the SavedModel as layer.keras_api.layer_metrics in + # SavedModels created after Tf 2.2. + metric_list_node_id = self._search_for_child_node( + node_id, [constants.KERAS_ATTR, 'layer_metrics'], raise_error=False) + if metric_list_node_id is not None and hasattr(obj, '_metrics'): + obj_metrics = {m.name: m for m in obj._metrics} + for reference in self._proto.nodes[metric_list_node_id].children: + metric = obj_metrics.get(reference.local_name) + if metric is not None: + children.append((metric, reference.node_id)) + + for (obj_child, child_id) in children: child_proto = self._proto.nodes[child_id] if not isinstance(obj_child, trackable.Trackable): @@ -263,10 +288,23 @@ class KerasObjectLoader(tf_load.Loader): def _load_layers(self): layers = {} + + # Load metrics after models and layers, since it's likely that models + # and layers will create the metric when initialized (this avoids wasting + # time by creating objects multiple times). + metric_list = [] for node_id, proto in enumerate(self._proto.nodes): - if (proto.WhichOneof('kind') == 'user_object' and - proto.user_object.identifier in KERAS_OBJECT_IDENTIFIERS): - layers[node_id] = self._load_layer(proto.user_object, node_id) + if (proto.WhichOneof('kind') != 'user_object' or + proto.user_object.identifier not in KERAS_OBJECT_IDENTIFIERS): + continue + if proto.user_object.identifier == '_tf_keras_metric': + metric_list.append((node_id, proto)) + continue + + layers[node_id] = self._load_layer(proto.user_object, node_id) + + for node_id, proto in metric_list: + layers[node_id] = self._load_layer(proto.user_object, node_id) return layers def _load_layer(self, proto, node_id): @@ -277,8 +315,6 @@ class KerasObjectLoader(tf_load.Loader): if node_id in self._nodes_recreated_from_config: node, setter = self._nodes_recreated_from_config[node_id] - self._try_build_layer(node, node_id, metadata.get('build_input_shape')) - # Revive setter requires the object to have a `_serialized_attributes` # property. Add it here. _maybe_add_serialized_attributes(node, metadata) @@ -291,7 +327,7 @@ class KerasObjectLoader(tf_load.Loader): # Detect whether this object can be revived from the config. If not, then # revive from the SavedModel instead. 
- obj, setter = self._revive_from_config(metadata, node_id) + obj, setter = self._revive_from_config(proto.identifier, metadata, node_id) if obj is None: obj, setter = revive_custom_object(proto.identifier, metadata) @@ -302,10 +338,15 @@ class KerasObjectLoader(tf_load.Loader): _maybe_add_serialized_attributes(obj, metadata) return obj, setter - def _revive_from_config(self, metadata, node_id): + def _revive_from_config(self, identifier, metadata, node_id): """Revives a layer/model from config, or returns None.""" - obj = (self._revive_graph_network(metadata, node_id) or - self._revive_layer_from_config(metadata, node_id)) + if identifier == '_tf_keras_metric': + obj = self._revive_metric_from_config(metadata, node_id) + else: + obj = ( + self._revive_graph_network(metadata, node_id) or + self._revive_layer_from_config(metadata, node_id)) + if obj is None: return None, None @@ -382,6 +423,25 @@ class KerasObjectLoader(tf_load.Loader): return obj + def _revive_metric_from_config(self, metadata, node_id): + class_name = compat.as_str(metadata['class_name']) + config = metadata.get('config') + + if not generic_utils.validate_config(config): + return None + + try: + obj = metrics.deserialize( + generic_utils.serialize_keras_class_and_config(class_name, config)) + except ValueError: + return None + + build_input_shape = metadata.get('build_input_shape') + if build_input_shape is not None and hasattr(obj, '_build'): + obj._build(build_input_shape) # pylint: disable=protected-access + + return obj + def _try_build_layer(self, obj, node_id, build_input_shape): """Attempts to build the layer.""" if obj.built or hasattr(obj.build, '_is_default'): @@ -424,16 +484,18 @@ class KerasObjectLoader(tf_load.Loader): node_id in self.model_layer_dependencies): continue - # No need to apply the finalizing steps to input layers. + self._unblock_model_reconstruction(node_id, node) + if isinstance(node, input_layer.InputLayer): - self._unblock_model_reconstruction(node_id, node) + continue + elif isinstance(node, metrics.Metric): continue if node_id in self._nodes_recreated_from_config: layers_revived_from_config.append(node) else: layers_revived_from_saved_model.append(node) - self._unblock_model_reconstruction(node_id, node) + _finalize_saved_model_layers(layers_revived_from_saved_model) _finalize_config_layers(layers_revived_from_config) @@ -503,29 +565,63 @@ class KerasObjectLoader(tf_load.Loader): self._unblock_model_reconstruction(model_id, model) def _get_child_layer_node_ids(self, node_id, name): - # First, retrieve the node.keras_api.layers attribute, which is a list of - # all the layers in the node. - keras_attr = self._search_for_child_node(node_id, constants.KERAS_ATTR, - name) - layers_node = self._search_for_child_node(keras_attr, 'layers', name) - return [node.node_id for node in self._proto.nodes[layers_node].children] + """Returns the node ids of the children layers of a node.""" + # Retrieve the node id of layer.keras_api.layers. + layer_list = self._search_for_child_node( + node_id, [constants.KERAS_ATTR, 'layers'], name) + return [node.node_id for node in self._proto.nodes[layer_list].children] - def _search_for_child_node(self, node_id, child_name, debugging_name): - for child in self._proto.nodes[node_id].children: - if child.local_name == child_name: - return child.node_id - raise ValueError( - 'Error when loading {}: could not find attribute {}.\n' - 'Most likely this object was serialized incorrectly.' 
- .format(debugging_name, child_name)) + def _search_for_child_node( + self, parent_id, path_to_child, debugging_name=None, raise_error=True): + """Returns node id of child node. + + A helper method for traversing the object graph proto. + + As an example, say that the object graph proto in the SavedModel contains an + object with the following child and grandchild attributes: + + `parent.child_a.child_b` + + This method can be used to retrieve the node id of `child_b` using the + parent's node id by calling: + + `_search_for_child_node(parent_id, ['child_a', 'child_b'])`. + + Args: + parent_id: node id of parent node + path_to_child: list of children names. + debugging_name: the name to print out when raising an error. + raise_error: Whether to raise an error if the child isn't found. + + Returns: + node_id of child, or None if child isn't found. + + Raises: + ValueError: if child isn't found and raise_error is True. + """ + if not path_to_child: + return parent_id + + for child in self._proto.nodes[parent_id].children: + if child.local_name == path_to_child[0]: + return self._search_for_child_node(child.node_id, path_to_child[1:], + debugging_name, raise_error) + + if raise_error: + raise ValueError( + 'Error when loading {}: could not find attribute {}.\n' + 'Most likely this object was serialized incorrectly.' + .format(debugging_name or path_to_child[0], path_to_child[0])) + else: + return None def _infer_inputs(self, layer_node_id, convert_to_shapes=False): """Infers input shape of layer from SavedModel functions.""" coder = nested_structure_coder.StructureCoder() - try: - call_fn_id = self._search_for_child_node( - layer_node_id, 'call_and_return_all_conditional_losses', None) - except ValueError: + call_fn_id = self._search_for_child_node( + layer_node_id, ['call_and_return_all_conditional_losses'], None, + raise_error=False) + if call_fn_id is None: return None concrete_functions = ( @@ -579,6 +675,10 @@ def _finalize_saved_model_layers(layers): # 3. Add losses that aren't generated by the layer.call function. _restore_layer_unconditional_losses(layer) _restore_layer_activation_loss(layer) + + # 4. Restore metrics list + _restore_layer_metrics(layer) + # pylint: enable=protected-access @@ -600,6 +700,15 @@ def _finalize_config_layers(layers): # loading behavior between HDF5 and SavedModel. _restore_layer_activation_loss(layer) + # Restore metrics list. + _restore_layer_metrics(layer) + + +def _finalize_metric(metric): + metric.update_state = types.MethodType(metrics_utils.update_state_wrapper( + metric.keras_api.update_state), metric) + metric.result = metric.keras_api.result + def _restore_layer_unconditional_losses(layer): """Restore unconditional losses from SavedModel.""" @@ -641,7 +750,7 @@ def revive_custom_object(identifier, metadata): '_tf_keras_input_layer': (RevivedInputLayer, input_layer.InputLayer), '_tf_keras_network': (RevivedNetwork, network_lib.Network), '_tf_keras_model': (RevivedNetwork, model_class), - '_tf_keras_sequential': (RevivedNetwork, models_lib.Sequential) + '_tf_keras_sequential': (RevivedNetwork, models_lib.Sequential), } parent_classes = revived_classes.get(identifier, None) @@ -651,6 +760,21 @@ def revive_custom_object(identifier, metadata): revived_cls = type( compat.as_str(metadata['class_name']), parent_classes, {}) return revived_cls._init_from_metadata(metadata) # pylint: disable=protected-access + else: + raise ValueError('Unable to restore custom object of type {} currently. 
' + 'Please make sure that the layer implements `get_config`' + 'and `from_config` when saving. In addition, please use ' + 'the `custom_objects` arg when calling `load_model()`.' + .format(identifier)) + + +def _restore_layer_metrics(layer): + metrics_list = getattr(_get_keras_attr(layer), 'layer_metrics', {}) + layer_metrics = {m.name: m for m in layer._metrics} # pylint: disable=protected-access + for name, metric in metrics_list.items(): + if name not in layer_metrics: + # Metrics may be added during initialization/building of custom layers. + layer._metrics.append(metric) # pylint: disable=protected-access # TODO(kathywu): Centrally define keys and functions for both serialization and diff --git a/tensorflow/python/keras/saving/saved_model/metric_serialization.py b/tensorflow/python/keras/saving/saved_model/metric_serialization.py new file mode 100644 index 00000000000..efe977ec55f --- /dev/null +++ b/tensorflow/python/keras/saving/saved_model/metric_serialization.py @@ -0,0 +1,45 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Classes and functions implementing Metrics SavedModel serialization.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.keras.saving.saved_model import layer_serialization +from tensorflow.python.training.tracking import data_structures + + +class MetricSavedModelSaver(layer_serialization.LayerSavedModelSaver): + """Metric serialization.""" + + @property + def object_identifier(self): + return '_tf_keras_metric' + + def _python_properties_internal(self): + metadata = dict( + class_name=type(self.obj).__name__, + name=self.obj.name, + dtype=self.obj.dtype) + metadata.update(layer_serialization.get_config(self.obj)) + if self.obj._build_input_shape is not None: # pylint: disable=protected-access + metadata['build_input_shape'] = self.obj._build_input_shape # pylint: disable=protected-access + return metadata + + def _get_serialized_attributes_internal(self, unused_serialization_cache): + return (dict(variables=data_structures.ListWrapper(self.obj.variables)), + dict()) # TODO(b/135550038): save functions to enable saving + # custom metrics. 
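A compact sketch of the save/load round trip that `MetricSavedModelSaver` enables, mirroring the `MetricTest` added to `saved_model_test.py` below; the module paths are the internal ones used in that test, the save directory is illustrative, and eager execution is assumed:

```python
from tensorflow.python.keras import metrics
from tensorflow.python.keras.saving.saved_model import load as keras_load
from tensorflow.python.saved_model import save as tf_save

metric = metrics.Mean(name='mean_metric')
metric.update_state([1.0, 2.0, 3.0])          # builds the total/count variables

tf_save.save(metric, '/tmp/mean_metric')      # saved with identifier '_tf_keras_metric'
loaded = keras_load.load('/tmp/mean_metric')  # revived via metrics.deserialize

assert loaded.name == metric.name
loaded.update_state([4.0])
print(float(loaded.result()))  # (1 + 2 + 3 + 4) / 4 = 2.5
```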
diff --git a/tensorflow/python/keras/saving/saved_model/revive_test.py b/tensorflow/python/keras/saving/saved_model/revive_test.py index ca3ecfc5a77..575bf0fc4bc 100644 --- a/tensorflow/python/keras/saving/saved_model/revive_test.py +++ b/tensorflow/python/keras/saving/saved_model/revive_test.py @@ -90,6 +90,8 @@ class CustomLayerNoConfig(keras.layers.Layer): def a_regularizer(): return self.a * 2 self.add_loss(a_regularizer) + self.sum_metric = keras.metrics.Sum(name='inputs_sum') + self.unused_metric = keras.metrics.Sum(name='not_added_to_metrics') def build(self, input_shape): self.c = variables.Variable( @@ -97,6 +99,9 @@ class CustomLayerNoConfig(keras.layers.Layer): def call(self, inputs): self.add_loss(math_ops.reduce_sum(inputs), inputs) + self.add_metric(self.sum_metric(inputs)) + self.add_metric(inputs, aggregation='mean', name='mean') + return inputs + self.c @@ -143,6 +148,9 @@ class TestModelRevive(keras_parameterized.TestCase): self.assertAllClose(model(input_arr), revived(input_arr)) self.assertAllClose(sum(model.losses), sum(revived.losses)) self.assertAllClose(len(model.losses), len(revived.losses)) + self.assertEqual(len(model.metrics), len(revived.metrics)) + self.assertAllClose([m.result() for m in model.metrics], + [m.result() for m in revived.metrics]) model_layers = {layer.name: layer for layer in model.layers} revived_layers = {layer.name: layer for layer in revived.layers} self.assertAllEqual(model_layers.keys(), revived_layers.keys()) diff --git a/tensorflow/python/keras/saving/saved_model/save_impl.py b/tensorflow/python/keras/saving/saved_model/save_impl.py index 7bd2b52fe84..9d9ff2b0ed2 100644 --- a/tensorflow/python/keras/saving/saved_model/save_impl.py +++ b/tensorflow/python/keras/saving/saved_model/save_impl.py @@ -53,6 +53,8 @@ from tensorflow.python.util.lazy_loader import LazyLoader base_layer = LazyLoader( "base_layer", globals(), "tensorflow.python.keras.engine.base_layer") +metrics = LazyLoader("metrics", globals(), + "tensorflow.python.keras.metrics") input_layer = LazyLoader( "input_layer", globals(), "tensorflow.python.keras.engine.input_layer") @@ -108,6 +110,9 @@ def wrap_layer_objects(layer, serialization_cache): wrapped_loss_functions.append(wrapped_loss) wrapped_layer_losses = [keras_loss_cache[fn] for fn in layer._callable_losses[:]] # pylint: disable=protected-access + + layer_metrics = data_structures._DictWrapper( # pylint: disable=protected-access + {m.name: m for m in layer._metrics}) # pylint: disable=protected-access return dict( variables=data_structures.ListWrapper(layer.variables), trainable_variables=data_structures.ListWrapper( @@ -119,7 +124,9 @@ def wrap_layer_objects(layer, serialization_cache): regularization_losses=data_structures.ListWrapper( wrapped_loss_functions), layer_regularization_losses=data_structures.ListWrapper( - wrapped_layer_losses)) + wrapped_layer_losses), + layer_metrics=layer_metrics) + # pylint: disable=protected-access def wrap_layer_functions(layer, serialization_cache): @@ -218,24 +225,55 @@ def _replace_child_layer_functions(layer, serialization_cache): { Child layer 1: { 'losses': Original losses, 'call': Original call function - 'activity_regularizer': Original activity regularizer}, + '_activity_regularizer': Original activity regularizer}, Child layer 2: ... 
} """ # pylint: disable=protected-access original_fns = {} + + def replace_layer_functions(child_layer, serialized_fns): + """Replaces layer call and activity regularizer with wrapped functions.""" + original_fns[child_layer] = { + 'call': child_layer.call, + '_activity_regularizer': child_layer._activity_regularizer + } + with trackable.no_automatic_dependency_tracking_scope(child_layer): + try: + child_layer._activity_regularizer = serialized_fns.get( + 'activity_regularizer_fn') + except AttributeError: + # Some layers have an unsettable activity regularizer. + pass + child_layer.call = utils.use_wrapped_call( + child_layer, + serialized_fns['call_and_return_conditional_losses'], + default_training_value=False) + + def replace_metric_functions(child_layer, serialized_fns): + """Replaces metric functions with wrapped functions.""" + original_fns[child_layer] = { + '__call__': child_layer.__call__, + 'result': child_layer.result, + 'update_state': child_layer.update_state + } + with trackable.no_automatic_dependency_tracking_scope(child_layer): + child_layer.__call__ = serialized_fns['__call__'] + child_layer.result = serialized_fns['result'] + child_layer.update_state = serialized_fns['update_state'] + for child_layer in utils.list_all_layers(layer): if isinstance(child_layer, input_layer.InputLayer): continue if child_layer not in serialization_cache[constants.KERAS_CACHE_KEY]: - layer_fns = ( + serialized_functions = ( child_layer._trackable_saved_model_saver._get_serialized_attributes( serialization_cache).functions) else: - layer_fns = ( + serialized_functions = ( serialization_cache[constants.KERAS_CACHE_KEY][child_layer].functions) - if not layer_fns: + if not serialized_functions: # This indicates either: # - circular dependency, which means the current layer's functions # should be wrapped first. @@ -243,20 +281,12 @@ def _replace_child_layer_functions(layer, serialization_cache): # wrapped. In this case, no replacement is necessary so move on to the # next child. continue - original_fns[child_layer] = { - 'call': child_layer.call, - 'activity_regularizer': child_layer._activity_regularizer - } - with trackable.no_automatic_dependency_tracking_scope(child_layer): - try: - child_layer._activity_regularizer = layer_fns.get( - 'activity_regularizer_fn') - except AttributeError: - # Some layers have an unsettable activity regularizer. - pass - child_layer.call = utils.use_wrapped_call( - child_layer, layer_fns['call_and_return_conditional_losses'], - default_training_value=False) + + if isinstance(child_layer, metrics.Metric): + replace_metric_functions(child_layer, serialized_functions) + else: + replace_layer_functions(child_layer, serialized_functions) + return original_fns # pylint: enable=protected-access @@ -265,11 +295,12 @@ def _restore_child_layer_functions(original_fns): """Restores attributes replaced with `_replace_child_layer_functions`.""" for child_layer, fns in original_fns.items(): with trackable.no_automatic_dependency_tracking_scope(child_layer): - child_layer.call = fns['call'] - try: - child_layer._activity_regularizer = fns['activity_regularizer'] # pylint: disable=protected-access - except AttributeError: - pass + for fn_name, fn in fns.items(): + try: + setattr(child_layer, fn_name, fn) # pylint: disable=protected-access + except AttributeError: + pass # In the case of _activity_regularizer, setting the attribute + # may be disallowed. 
# pylint: disable=protected-access diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py index da86a7cdac1..63eeceb2a6b 100644 --- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py +++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py @@ -27,6 +27,7 @@ from __future__ import print_function import os import shutil +from absl.testing import parameterized import numpy as np from tensorflow.core.example import example_pb2 @@ -49,6 +50,7 @@ from tensorflow.python.keras import regularizers from tensorflow.python.keras import testing_utils from tensorflow.python.keras.saving.saved_model import load as keras_load from tensorflow.python.keras.saving.saved_model import save_impl as keras_save +from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops @@ -746,5 +748,128 @@ class TestLayerCallTracing(test.TestCase): self.assertAllEqual(previous_losses, layer.losses) +@test_util.run_all_in_graph_and_eager_modes +class MetricTest(test.TestCase, parameterized.TestCase): + + def _save_model_dir(self, dirname='saved_model'): + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) + return os.path.join(temp_dir, dirname) + + def generate_inputs(self, num_tensor_args, shape=(1, 5)): + return [ + np.random.uniform(0, 1, shape).astype('float32') + for _ in range(num_tensor_args) + ] + + def _test_metric_save_and_load(self, + metric, + save_dir, + num_tensor_args, + shape=(1, 5), + test_sample_weight=True): + tf_save.save(metric, save_dir) + loaded = keras_load.load(save_dir) + self.evaluate([v.initializer for v in loaded.variables]) + self.assertEqual(metric.name, loaded.name) + self.assertEqual(metric.dtype, loaded.dtype) + + inputs = self.generate_inputs(num_tensor_args, shape) + actual = self.evaluate(metric(*inputs)) + self.assertAllClose(actual, loaded(*inputs)) + self.assertAllClose(metric.variables, loaded.variables) + + # Test with separate calls to update state and result. + inputs = self.generate_inputs(num_tensor_args, shape) + self.evaluate(metric.update_state(*inputs)) + self.evaluate(loaded.update_state(*inputs)) + actual = self.evaluate(metric.result()) + self.assertAllClose(actual, loaded.result()) + + if test_sample_weight: + # Test with sample weights input. 
+ inputs = self.generate_inputs(num_tensor_args, shape) + sample_weight = self.generate_inputs(1, [])[0] + inputs.append(sample_weight) + + actual = self.evaluate(metric(*inputs)) + self.assertAllClose(actual, loaded(*inputs)) + return loaded + + @parameterized.named_parameters([ + ('mean', keras.metrics.Mean, 1, (1, 5)), + ('false_positives', keras.metrics.FalsePositives, 2, (1, 5)), + ('precision_at_top_k', keras.metrics.Precision, 2, (2, 3, 4), { + 'top_k': 2, + 'class_id': 1 + }), + ('precision_at_recall', keras.metrics.PrecisionAtRecall, 2, (1, 5), { + 'recall': .8 + }), ('auc', keras.metrics.AUC, 2, (1, 5), { + 'multi_label': True + }), ('cosine_similarity', keras.metrics.CosineSimilarity, 2, (2, 3, 1)) + ]) + def test_metric(self, metric_cls, num_tensor_args, shape, init_kwargs=None): + init_kwargs = init_kwargs or {} + metric = metric_cls(**init_kwargs) + metric(*self.generate_inputs(num_tensor_args, shape)) + self.evaluate([v.initializer for v in metric.variables]) + loaded = self._test_metric_save_and_load(metric, self._save_model_dir(), + num_tensor_args, shape) + self.assertEqual(type(loaded), type(metric)) + + @parameterized.named_parameters([ + ('mean', keras.metrics.Mean, 1, False), + ('auc', keras.metrics.AUC, 2, False), + ('mean_tensor', keras.metrics.MeanTensor, 1, True)]) + def test_custom_metric(self, base_cls, num_tensor_args, requires_build): + + class CustomMetric(base_cls): + + def update_state(self, *args): # pylint: disable=useless-super-delegation + # Sometimes built-in metrics return an op in update_state. Custom + # metrics don't support returning ops, so wrap the update_state method + # while returning nothing. + super(CustomMetric, self).update_state(*args) + + metric = CustomMetric() + save_dir = self._save_model_dir('first_save') + + if requires_build: + metric(*self.generate_inputs(num_tensor_args)) # pylint: disable=not-callable + + self.evaluate([v.initializer for v in metric.variables]) + + with self.assertRaisesRegexp(ValueError, 'Unable to restore custom object'): + self._test_metric_save_and_load(metric, save_dir, num_tensor_args) + with generic_utils.CustomObjectScope({'CustomMetric': CustomMetric}): + loaded = self._test_metric_save_and_load( + metric, + save_dir, + num_tensor_args, + test_sample_weight=False) + + self._test_metric_save_and_load( + loaded, + self._save_model_dir('second_save'), + num_tensor_args, + test_sample_weight=False) + + def test_custom_metric_wrapped_call(self): + + class NegativeMean(keras.metrics.Mean): + + @def_function.function( + input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) + def update_state(self, value): + super(NegativeMean, self).update_state(-value) + + metric = NegativeMean() + self.evaluate([v.initializer for v in metric.variables]) + with generic_utils.CustomObjectScope({'NegativeMean': NegativeMean}): + self._test_metric_save_and_load( + metric, self._save_model_dir(), 1, test_sample_weight=False) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/saving/saved_model/serialized_attributes.py b/tensorflow/python/keras/saving/saved_model/serialized_attributes.py index c62dbb697a5..c87a0d64cf7 100644 --- a/tensorflow/python/keras/saving/saved_model/serialized_attributes.py +++ b/tensorflow/python/keras/saving/saved_model/serialized_attributes.py @@ -34,6 +34,8 @@ base_layer = LazyLoader( training_lib = LazyLoader( "training_lib", globals(), "tensorflow.python.keras.engine.training") +metrics = LazyLoader("metrics", globals(), + "tensorflow.python.keras.metrics") # 
pylint:enable=g-inconsistent-quotes @@ -138,6 +140,8 @@ class SerializedAttributes(object): def new(obj): if isinstance(obj, training_lib.Model): return ModelAttributes() + elif isinstance(obj, metrics.Metric): + return MetricAttributes() elif isinstance(obj, base_layer.Layer): return LayerAttributes() else: @@ -203,7 +207,8 @@ class SerializedAttributes(object): self._object_dict[key] = object_dict[key] setattr(self._keras_trackable, key, object_dict[key]) else: - raise ValueError('Object {} missing from serialized object dict.') + raise ValueError( + 'Object {} missing from serialized object dict.'.format(key)) return self.checkpointable_objects @@ -233,7 +238,7 @@ class CommonEndpoints(SerializedAttributes.with_attributes( class LayerAttributes(SerializedAttributes.with_attributes( 'LayerAttributes', checkpointable_objects=['non_trainable_variables', 'layers', 'metrics', - 'layer_regularization_losses'], + 'layer_regularization_losses', 'layer_metrics'], functions=['call_and_return_conditional_losses', 'activity_regularizer_fn'], copy_from=[CommonEndpoints] )): @@ -252,6 +257,7 @@ class LayerAttributes(SerializedAttributes.with_attributes( activity regularizer. activity_regularizer_fn: Callable that returns the activity regularizer loss layer_regularization_losses: List of losses owned only by this layer. + layer_metrics: List of metrics owned by this layer. """ @@ -265,3 +271,17 @@ class ModelAttributes(SerializedAttributes.with_attributes( """ # TODO(kathywu): Add attributes `compile_losses` and `compile_metrics`, which # list all losses and metrics defined by `model.compile`. + + +class MetricAttributes( + SerializedAttributes.with_attributes( + 'MetricAttributes', + checkpointable_objects=['variables'], + functions=[], + )): + """Attributes that are added to Metric objects when saved to SavedModel. 
+ + List of all attributes: + variables: list of all variables + """ + pass diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index 60d34a0e299..94a0e73e64f 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -780,21 +780,29 @@ class CondV2Test(test.TestCase): self.evaluate(variables.global_variables_initializer()) + def update_v1(): + v1.assign(v1) + return v1 + + def update_v2(): + v2.assign(v2) + return v2 + @def_function.function def fn_with_cond(): cond_v2.cond_v2( constant_op.constant(True), - lambda: v1, + update_v1, lambda: constant_op.constant(0.), name="cond_1") cond_2 = cond_v2.cond_v2( constant_op.constant(False), lambda: constant_op.constant(0.), - lambda: v1, + update_v1, name="cond_2") cond_v2.cond_v2( constant_op.constant(True), - lambda: v2, + update_v2, lambda: constant_op.constant(0.), name="cond_3") cond_4 = cond_v2.cond_v2( @@ -841,24 +849,34 @@ class CondV2Test(test.TestCase): @def_function.function def fn_with_cond(): + + def update_v1(): + v1.assign(v1) + return v1 + + def update_v2(): + v2.assign(v2) + return v2 + cond_v2.cond_v2( constant_op.constant(True), - lambda: v1, + update_v1, lambda: constant_op.constant(0.), name="cond_1") cond_2 = cond_v2.cond_v2( constant_op.constant(False), lambda: constant_op.constant(0.), - lambda: v1, + update_v1, name="cond_2") cond_v2.cond_v2( constant_op.constant(True), - lambda: v2, + update_v2, lambda: constant_op.constant(0.), name="cond_3") @def_function.function def cond_4_false_branch(): + v2.assign(v2) return v2 cond_4 = cond_v2.cond_v2( diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py index 15a389294e5..8f131723cd2 100644 --- a/tensorflow/python/kernel_tests/while_v2_test.py +++ b/tensorflow/python/kernel_tests/while_v2_test.py @@ -340,14 +340,24 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): @def_function.function def Fn(): + + def Body1(v): + x.assign(x) + return v * x + ret1 = while_loop_v2( lambda v: v < 4., - lambda v: v * x, [c], + Body1, [c], return_same_structure=False, name="while_1") # 2x + + def Body2(v): + x.assign(x) + return v * x * x + ret2 = while_loop_v2( lambda v: v < 16., - lambda v: v * x * x, [c], + Body2, [c], return_same_structure=False, name="while_2") # 4x return ret1, ret2 @@ -368,24 +378,44 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): @def_function.function def Fn(): + + def Body1(v): + x1.assign(x1) + return v * x1 + ret1 = while_loop_v2( lambda v: v < 4., - lambda v: v * x1, [c], + Body1, [c], return_same_structure=False, name="while_1") # 2x + + def Body2(v): + x1.assign(x1) + return v * x1 * x1 + ret2 = while_loop_v2( lambda v: v < 16., - lambda v: v * x1 * x1, [c], + Body2, [c], return_same_structure=False, name="while_2") # 4x + + def Body3(v): + x2.assign(x2) + return v * x2 + ret3 = while_loop_v2( lambda v: v < 4., - lambda v: v * x2, [c], + Body3, [c], return_same_structure=False, name="while_3") # 3x + + def Body4(v): + x2.assign(x2) + return v * x2 * x2 + ret4 = while_loop_v2( lambda v: v < 16., - lambda v: v * x2 * x2, [c], + Body4, [c], return_same_structure=False, name="while_4") # 9x ret5 = while_loop_v2( diff --git a/tensorflow/python/module/BUILD b/tensorflow/python/module/BUILD index 4585d39e592..fea6fe123ad 100644 --- a/tensorflow/python/module/BUILD +++ b/tensorflow/python/module/BUILD @@ -28,6 +28,7 @@ tf_py_test( "//tensorflow/python:client_testlib", 
"//tensorflow/python:variables", "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/distribute:tpu_values", "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/python/module/module_test.py b/tensorflow/python/module/module_test.py index 7fa4fc14d7f..b2fc4ff9645 100644 --- a/tensorflow/python/module/module_test.py +++ b/tensorflow/python/module/module_test.py @@ -26,6 +26,7 @@ from absl.testing import parameterized import six from tensorflow.python import tf2 +from tensorflow.python.distribute import tpu_values from tensorflow.python.distribute import values as distributed_values from tensorflow.python.eager import context from tensorflow.python.eager import def_function @@ -249,10 +250,8 @@ class VariableTrackingTest(test_util.TensorFlowTestCase): def test_supports_distributed_variables(self): mirrored = distributed_values.MirroredVariable( None, [variables.Variable(1.)], variables.VariableAggregation.SUM) - tpu = distributed_values.TPUMirroredVariable( - strategy=None, - values=[variables.Variable(42.)], - aggregation=None) + tpu = tpu_values.TPUMirroredVariable( + strategy=None, values=[variables.Variable(42.)], aggregation=None) aggregating = distributed_values.AggregatingVariable( strategy=None, v=variables.Variable(1.), aggregation=None) diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py index fd0102328d1..19f777115db 100644 --- a/tensorflow/python/ops/cond_v2.py +++ b/tensorflow/python/ops/cond_v2.py @@ -26,6 +26,7 @@ from __future__ import print_function import collections from tensorflow.python.eager import backprop_util +from tensorflow.python.framework import auto_control_deps_utils as acd from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import func_graph as func_graph_module @@ -279,6 +280,7 @@ def _build_cond(pred, if_op._false_graph = false_graph util.maybe_set_lowering_attr(if_op) util.maybe_propagate_compile_time_consts_in_xla(if_op) + _set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph]) # Prevent fetching since the variant outputs can't be fetched directly. if_op.graph.prevent_fetching(if_op) @@ -1105,6 +1107,7 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None): if case_op is not None: util.maybe_set_lowering_attr(case_op) util.maybe_propagate_compile_time_consts_in_xla(case_op) + _set_read_only_resource_inputs_attr(case_op, branch_graphs) # Prevent fetching since the variant outputs can't be fetched directly. case_op.graph.prevent_fetching(case_op) @@ -1119,3 +1122,28 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None): tensors = [array_ops.identity(t) for t in tensors] return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors) + + +def _set_read_only_resource_inputs_attr(op, branch_graphs): + """Sets the list of resource inputs which are read-only. + + This is used by AutomaticControlDependencies. + + Args: + op: If or Case Operation. + branch_graphs: List of branch FuncGraphs. 
+ """ + read_only_indices = [] + for i in range(1, len(op.inputs)): + if op.inputs[i].dtype != dtypes.resource: + continue + has_write = False + for branch_graph in branch_graphs: + handle = branch_graph.inputs[i - 1] + if acd.resource_has_writes(handle): + has_write = True + break + if not has_write: + read_only_indices.append(i) + ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR, + read_only_indices) diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py index 4698f870785..a43cfc0be65 100644 --- a/tensorflow/python/ops/functional_ops.py +++ b/tensorflow/python/ops/functional_ops.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.eager import context +from tensorflow.python.framework import auto_control_deps_utils as acd from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function @@ -1178,4 +1179,25 @@ def partitioned_call(args, op_attrs[xla_compile_attr] = f.definition.attr[xla_compile_attr] op = graph.create_op(op_name, args, tout, name=op_name, attrs=op_attrs) outputs = op.outputs + if hasattr(f, "graph"): + _set_read_only_resource_inputs_attr(op, f.graph) return outputs if outputs else op + + +def _set_read_only_resource_inputs_attr(op, func_graph): + """Sets the list of resource inputs which are read-only. + + This is used by AutomaticControlDependencies. + + Args: + op: PartitionedCall Operation. + func_graph: FuncGraph. + """ + read_only_indices = [] + for i in range(len(op.inputs)): + handle = func_graph.inputs[i] + if handle.dtype != dtypes.resource or acd.resource_has_writes(handle): + continue + read_only_indices.append(i) + ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR, + read_only_indices) diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py index cab7573256b..e6a67efa301 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py @@ -1735,7 +1735,8 @@ class SpectralTest(PForTestCase, parameterized.TestCase): (fft_ops.irfft2d,), (fft_ops.irfft3d,), ) - def test_irfft(self, op_func): + # TODO(agarwal): Reenable this once the test flaky is fixed. + def disabled_test_irfft(self, op_func): for dtype in (dtypes.complex64, dtypes.complex128): shape = [2, 3, 4, 3, 4] x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape) diff --git a/tensorflow/python/ops/ragged/ragged_string_ops.py b/tensorflow/python/ops/ragged/ragged_string_ops.py index d5a508a08be..493e5b97cd6 100755 --- a/tensorflow/python/ops/ragged/ragged_string_ops.py +++ b/tensorflow/python/ops/ragged/ragged_string_ops.py @@ -457,8 +457,8 @@ def string_split_v2(input, sep=None, maxsplit=-1, name=None): # pylint: disable """Split elements of `input` based on `sep` into a `RaggedTensor`. Let N be the size of `input` (typically N will be the batch size). Split each - element of `input` based on `sep` and return a `SparseTensor` or - `RaggedTensor` containing the split tokens. Empty tokens are ignored. + element of `input` based on `sep` and return a `RaggedTensor` containing the + split tokens. Empty tokens are ignored. 
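The docstring fix just above narrows string_split_v2 to what it actually does: it always returns a RaggedTensor, never a SparseTensor. A short usage sketch, assuming the op is reached through the public tf.strings.split API in TF 2.x:

    import tensorflow as tf

    tokens = tf.strings.split(["hello world", "a b c"])
    # `tokens` is a tf.RaggedTensor with one row per input element:
    # [[b'hello', b'world'], [b'a', b'b', b'c']]
    print(tokens.to_list())

    # Callers that still need the old sparse layout can convert explicitly.
    sparse = tokens.to_sparse()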
Example: diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 6e02531fccb..20e2b87c69b 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -29,6 +29,7 @@ from tensorflow.python import _pywrap_utils from tensorflow.python.client import pywrap_tf_session from tensorflow.python.eager import context from tensorflow.python.eager import tape +from tensorflow.python.framework import auto_control_deps_utils as acd from tensorflow.python.framework import constant_op from tensorflow.python.framework import cpp_shape_inference_pb2 from tensorflow.python.framework import dtypes @@ -53,6 +54,13 @@ from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.deprecation import deprecated_args +acd.register_read_only_resource_op("ReadVariableOp") +acd.register_read_only_resource_op("VariableShape") +acd.register_read_only_resource_op("ResourceGather") +acd.register_read_only_resource_op("ResourceGatherNd") +acd.register_read_only_resource_op("_ReadVariablesOp") + + def get_resource_handle_data(graph_op): assert type(graph_op) == ops.Tensor # pylint: disable=unidiomatic-typecheck diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py index 4b8f45e2617..c7ec0668ef0 100644 --- a/tensorflow/python/ops/while_v2.py +++ b/tensorflow/python/ops/while_v2.py @@ -26,6 +26,7 @@ from __future__ import print_function from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.client import pywrap_tf_session as c_api from tensorflow.python.eager import backprop_util +from tensorflow.python.framework import auto_control_deps_utils as acd from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import func_graph as func_graph_module @@ -275,13 +276,8 @@ def while_loop(cond, # This is needed so we do not compute derivative wrt these extra outputs. outputs[0].op._set_attr("_num_original_outputs", attr_value_pb2.AttrValue(i=num_original_outputs)) - outputs[0].op._cond_graph = cond_graph outputs[0].op._body_graph = body_graph - _copy_handle_data(body_graph.outputs, outputs) - util.maybe_set_lowering_attr(outputs[0].op) - util.maybe_propagate_compile_time_consts_in_xla(outputs[0].op) - if not ops.get_default_graph().building_function: # In V1 graph mode, return identities for each output of the While op, # rather than the output of the While op directly. This makes pruning work @@ -401,11 +397,6 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name output_shapes=[t.shape for t in body_grad_graph.outputs], parallel_iterations=parallel_iterations, name="%s_grad" % while_op.name) - grad_op = outputs[0].op - - _copy_handle_data(body_grad_graph.outputs, outputs) - util.maybe_set_lowering_attr(grad_op) - util.maybe_propagate_compile_time_consts_in_xla(grad_op) # See comment in while_loop. 
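The register_read_only_resource_op calls added to resource_variable_ops above whitelist ops that never mutate the resource handle they consume, so automatic control dependencies can skip write-ordering edges for them. The registration pattern is one call per op name; sketched here with a hypothetical op name rather than the five real ones listed above:

    from tensorflow.python.framework import auto_control_deps_utils as acd

    # Any op registered here is treated as a pure read of its resource inputs.
    acd.register_read_only_resource_op("MyReadOnlyLookupOp")  # hypothetical op name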
outputs = [array_ops.identity(t) for t in outputs] @@ -426,13 +417,19 @@ def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes, else: op_fn = gen_functional_ops.stateless_while - return op_fn( + outputs = op_fn( loop_vars, util.create_new_tf_function(cond_graph), util.create_new_tf_function(body_graph), output_shapes=output_shapes, parallel_iterations=parallel_iterations, name=name) + while_op = outputs[0].op + _copy_handle_data(body_graph.outputs, outputs) + util.maybe_set_lowering_attr(while_op) + util.maybe_propagate_compile_time_consts_in_xla(while_op) + _set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph]) + return outputs def _get_intermediates(func_graph): @@ -1294,4 +1291,28 @@ class _OperationWithOutputs(ops.Operation): self._is_stateful = False +def _set_read_only_resource_inputs_attr(op, branch_graphs): + """Sets the list of resource inputs which are read-only. + + This is used by AutomaticControlDependencies. + + Args: + op: While Operation. + branch_graphs: List of branch FuncGraphs. + """ + read_only_indices = [] + for i in range(len(op.inputs)): + if op.inputs[i].dtype != dtypes.resource: + continue + has_write = False + for branch_graph in branch_graphs: + handle = branch_graph.inputs[i] + if acd.resource_has_writes(handle): + has_write = True + break + if not has_write: + read_only_indices.append(i) + ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR, + read_only_indices) + # pylint: enable=protected-access diff --git a/tensorflow/python/profiler/internal/profiler_wrapper.cc b/tensorflow/python/profiler/internal/profiler_wrapper.cc index 5c11fbb1cff..ef1a1c032f7 100644 --- a/tensorflow/python/profiler/internal/profiler_wrapper.cc +++ b/tensorflow/python/profiler/internal/profiler_wrapper.cc @@ -40,6 +40,7 @@ tensorflow::ProfileRequest MakeProfileRequest() { tensorflow::ProfileRequest request; request.add_tools("overview_page"); request.add_tools("input_pipeline"); + request.add_tools("kernel_stats"); request.add_tools("tensorflow_stats"); return request; } diff --git a/tensorflow/python/profiler/profiler_v2.py b/tensorflow/python/profiler/profiler_v2.py index afbe1ec5881..ba65aea7621 100644 --- a/tensorflow/python/profiler/profiler_v2.py +++ b/tensorflow/python/profiler/profiler_v2.py @@ -106,6 +106,18 @@ def stop(save=True): _profiler = None +def warmup(): + """Warm-up the profiler session. + + The profiler session will set up profiling context, including loading CUPTI + library for GPU profiling. This is used for improving the accuracy of + the profiling results. + + """ + start('') + stop(save=False) + + @tf_export('profiler.experimental.server.start', v1=[]) def start_server(port): """Start a profiler grpc server that listens to given port. diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py index 96789d2cea5..7c4f40c6b66 100644 --- a/tensorflow/python/tpu/tpu.py +++ b/tensorflow/python/tpu/tpu.py @@ -208,13 +208,14 @@ def _enclosing_tpu_device_assignment(): @auto_control_deps.register_acd_resource_resolver -def tpu_replicated_input_resolver(op, resource_inputs): +def tpu_replicated_input_resolver(op, resource_reads, resource_writes): """Replaces TPUReplicatedInput outputs with its inputs in resource_inputs.""" # Ignore TPUReplicatedInput for ACD purposes since we will be directly adding # control deps on the replicated inputs. 
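The warmup() helper added to profiler_v2 above simply runs one throwaway start('')/stop(save=False) cycle so that one-time setup cost, such as loading CUPTI for GPU tracing, is paid before the first real trace. A sketch of the intended call pattern, assuming it is used through the internal profiler_v2 module (the helper is not tf_export'ed in this change) and a hypothetical log directory:

    from tensorflow.python.profiler import profiler_v2 as profiler

    profiler.warmup()                  # throwaway session: loads CUPTI, sets up context
    profiler.start('/tmp/profile_logdir')
    # ... run the training or inference steps to be profiled ...
    profiler.stop()                    # saves the collected trace to the logdir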
if op.type == "TPUReplicatedInput": - if resource_inputs: - resource_inputs.clear() + if resource_reads or resource_writes: + resource_reads.clear() + resource_writes.clear() return True else: return False @@ -222,18 +223,21 @@ def tpu_replicated_input_resolver(op, resource_inputs): # with the actual replicated inputs. This allows ACD to correct add control # deps when there are multiple calls to `experimental_run_v2` in a # `tf.function`. - to_remove = [] - to_add = [] - for resource in resource_inputs: - if resource.op.type == "TPUReplicatedInput": - to_remove.append(resource) - to_add.extend(resource.op.inputs) - if not to_add and not to_remove: - return False - for t in to_remove: - resource_inputs.discard(t) - resource_inputs.update(to_add) - return True + def replace_with_unreplicated_resources(resource_inputs): + """Replaces handles in `resource_inputs` with their unreplicated inputs.""" + to_remove = [] + to_add = [] + for resource in resource_inputs: + if resource.op.type == "TPUReplicatedInput": + to_remove.append(resource) + to_add.extend(resource.op.inputs) + for t in to_remove: + resource_inputs.discard(t) + resource_inputs.update(to_add) + return to_add or to_remove + + return (replace_with_unreplicated_resources(resource_reads) or + replace_with_unreplicated_resources(resource_writes)) class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): diff --git a/tensorflow/python/util/deprecation_test.py b/tensorflow/python/util/deprecation_test.py index 035c416d793..a6ca3c6fda8 100644 --- a/tensorflow/python/util/deprecation_test.py +++ b/tensorflow/python/util/deprecation_test.py @@ -761,7 +761,7 @@ class DeprecatedArgValuesTest(test.TestCase): deprecation.deprecated_arg_values(date, None, deprecated=True) with self.assertRaisesRegexp(ValueError, "instructions"): deprecation.deprecated_arg_values(date, "", deprecated=True) - with self.assertRaisesRegexp(ValueError, "argument", deprecated=True): + with self.assertRaisesRegexp(ValueError, "argument"): deprecation.deprecated_arg_values(date, instructions) @test.mock.patch.object(logging, "warning", autospec=True) diff --git a/tensorflow/stream_executor/allocator_stats.cc b/tensorflow/stream_executor/allocator_stats.cc index 440d6f46a3c..8a45efdef83 100644 --- a/tensorflow/stream_executor/allocator_stats.cc +++ b/tensorflow/stream_executor/allocator_stats.cc @@ -18,7 +18,7 @@ limitations under the License. namespace stream_executor { -string AllocatorStats::DebugString() const { +std::string AllocatorStats::DebugString() const { return absl::StrFormat( "Limit: %20lld\n" "InUse: %20lld\n" diff --git a/tensorflow/stream_executor/allocator_stats.h b/tensorflow/stream_executor/allocator_stats.h index 62edfff3c1b..9a99c1099c9 100644 --- a/tensorflow/stream_executor/allocator_stats.h +++ b/tensorflow/stream_executor/allocator_stats.h @@ -51,7 +51,7 @@ struct AllocatorStats { bytes_reserved(0), peak_bytes_reserved(0) {} - string DebugString() const; + std::string DebugString() const; }; } // namespace stream_executor diff --git a/tensorflow/stream_executor/blas.cc b/tensorflow/stream_executor/blas.cc index 9b8fdf1efe8..f499b3003d0 100644 --- a/tensorflow/stream_executor/blas.cc +++ b/tensorflow/stream_executor/blas.cc @@ -20,7 +20,7 @@ limitations under the License. 
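Taken together, the Python changes above (the READ_ONLY_RESOURCE_INPUTS_ATTR setters in cond_v2, functional_ops and while_v2, the read-only op registrations in resource_variable_ops, and the read/write split in the TPU input resolver) let automatic control dependencies tell branches that only read a variable apart from branches that write it. A hedged illustration of that distinction from the user side; the attribute itself stays internal and is not asserted on here:

    import tensorflow as tf

    v = tf.Variable(1.0)

    @tf.function
    def read_only_branches(pred):
      # Both branches only read `v`, so the generated If op can mark its
      # resource input as read-only and ACD does not need to order it against
      # other ops that merely read `v`.
      return tf.cond(pred, lambda: v + 1.0, lambda: v * 2.0)

    @tf.function
    def writing_branch(pred):
      def true_fn():
        v.assign_add(1.0)              # a write: this resource input is not read-only
        return v.read_value()
      return tf.cond(pred, true_fn, lambda: v.read_value())

    print(read_only_branches(tf.constant(True)).numpy())   # 2.0
    print(writing_branch(tf.constant(False)).numpy())      # 1.0, the write never runs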
namespace stream_executor { namespace blas { -string TransposeString(Transpose t) { +std::string TransposeString(Transpose t) { switch (t) { case Transpose::kNoTranspose: return "NoTranspose"; @@ -33,7 +33,7 @@ string TransposeString(Transpose t) { } } -string UpperLowerString(UpperLower ul) { +std::string UpperLowerString(UpperLower ul) { switch (ul) { case UpperLower::kUpper: return "Upper"; @@ -44,7 +44,7 @@ string UpperLowerString(UpperLower ul) { } } -string DiagonalString(Diagonal d) { +std::string DiagonalString(Diagonal d) { switch (d) { case Diagonal::kUnit: return "Unit"; @@ -55,7 +55,7 @@ string DiagonalString(Diagonal d) { } } -string SideString(Side s) { +std::string SideString(Side s) { switch (s) { case Side::kLeft: return "Left"; @@ -68,9 +68,11 @@ string SideString(Side s) { // -- AlgorithmConfig -string AlgorithmConfig::ToString() const { return absl::StrCat(algorithm_); } +std::string AlgorithmConfig::ToString() const { + return absl::StrCat(algorithm_); +} -string ComputationTypeString(ComputationType ty) { +std::string ComputationTypeString(ComputationType ty) { switch (ty) { case ComputationType::kF16: return "f16"; diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index d361343c381..5018d487ed1 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -67,27 +67,27 @@ namespace blas { enum class Transpose { kNoTranspose, kTranspose, kConjugateTranspose }; // Returns a name for t. -string TransposeString(Transpose t); +std::string TransposeString(Transpose t); // Specifies whether the upper or lower triangular part of a // symmetric/Hermitian matrix is used. enum class UpperLower { kUpper, kLower }; // Returns a name for ul. -string UpperLowerString(UpperLower ul); +std::string UpperLowerString(UpperLower ul); // Specifies whether a matrix is unit triangular. enum class Diagonal { kUnit, kNonUnit }; // Returns a name for d. -string DiagonalString(Diagonal d); +std::string DiagonalString(Diagonal d); // Specifies whether a Hermitian matrix appears on the left or right in // operation. enum class Side { kLeft, kRight }; // Returns a name for s. -string SideString(Side s); +std::string SideString(Side s); // Type with which intermediate computations of a blas routine are performed. // @@ -104,7 +104,7 @@ enum class ComputationType { }; // Converts a ComputationType to a string. 
-string ComputationTypeString(ComputationType ty); +std::string ComputationTypeString(ComputationType ty); std::ostream &operator<<(std::ostream &os, ComputationType ty); @@ -157,7 +157,7 @@ class AlgorithmConfig { bool operator!=(const AlgorithmConfig &other) const { return !(*this == other); } - string ToString() const; + std::string ToString() const; private: AlgorithmType algorithm_; @@ -1383,7 +1383,7 @@ class BlasSupport { const DeviceMemory<std::complex<double>> &a, int lda, DeviceMemory<std::complex<double>> *b, int ldb) = 0; - virtual port::Status GetVersion(string *version) = 0; + virtual port::Status GetVersion(std::string *version) = 0; protected: BlasSupport() {} @@ -2196,7 +2196,7 @@ class BlasSupport { uint64 n, std::complex<double> alpha, \ const DeviceMemory<std::complex<double>> &a, int lda, \ DeviceMemory<std::complex<double>> *b, int ldb) override; \ - port::Status GetVersion(string *version) override; + port::Status GetVersion(std::string *version) override; } // namespace blas } // namespace stream_executor diff --git a/tensorflow/stream_executor/device_description.cc b/tensorflow/stream_executor/device_description.cc index 5bdfb7ef1d0..9ee6e6837d7 100644 --- a/tensorflow/stream_executor/device_description.cc +++ b/tensorflow/stream_executor/device_description.cc @@ -55,10 +55,11 @@ DeviceDescription::DeviceDescription() core_count_(-1), ecc_enabled_(false) {} -std::unique_ptr<std::map<string, string>> DeviceDescription::ToMap() const { - std::unique_ptr<std::map<string, string>> owned_result{ - new std::map<string, string>}; - std::map<string, string> &result = *owned_result; +std::unique_ptr<std::map<std::string, std::string>> DeviceDescription::ToMap() + const { + std::unique_ptr<std::map<std::string, std::string>> owned_result{ + new std::map<std::string, std::string>}; + std::map<std::string, std::string> &result = *owned_result; result["Device Vendor"] = device_vendor(); result["Platform Version"] = platform_version(); result["Driver Version"] = driver_version(); diff --git a/tensorflow/stream_executor/device_description.h b/tensorflow/stream_executor/device_description.h index db14b516a19..fa7426eb04b 100644 --- a/tensorflow/stream_executor/device_description.h +++ b/tensorflow/stream_executor/device_description.h @@ -42,22 +42,22 @@ class DeviceDescription { // Returns the platform being run on; this value is primarily intended for // printing, and comes out something like "OpenCL 1.2" or "Compute Capability // 3.5". - const string &platform_version() const { return platform_version_; } + const std::string &platform_version() const { return platform_version_; } // Returns the driver version interfacing with the underlying platform. Vendor // dependent format. - const string &driver_version() const { return driver_version_; } + const std::string &driver_version() const { return driver_version_; } // Return the runtime version, if one is provided by the underlying platform. // Vendor dependent format / usefulness. - const string &runtime_version() const { return runtime_version_; } + const std::string &runtime_version() const { return runtime_version_; } // Returns the name that the device reports. Vendor dependent. 
- const string &name() const { return name_; } + const std::string &name() const { return name_; } // Returns the PCI bus identifier for this device, of the form // [domain]:[bus]:[device].[function] - const string &pci_bus_id() const { return pci_bus_id_; } + const std::string &pci_bus_id() const { return pci_bus_id_; } // Returns the NUMA node associated with this device, for use in // determining socket locality. If the NUMA node could not be determined, -1 @@ -126,7 +126,7 @@ class DeviceDescription { // Returns the device vendor string, e.g., "NVIDIA Corporation", "Advanced // Micro Devices, Inc.", or "GenuineIntel". - const string &device_vendor() const { return device_vendor_; } + const std::string &device_vendor() const { return device_vendor_; } // Returns the CUDA compute capability if we're running on the CUDA platform. // If a CUDA compute capability is not available, the major version will be @@ -150,7 +150,7 @@ class DeviceDescription { // TODO(leary): resident blocks per core will be useful. // Convenience typedef for the string-based DeviceDescription mapping. - typedef std::map<string, string> Map; + typedef std::map<std::string, std::string> Map; // Returns a mapping from readable names to readable values that describe the // device. This is useful for things like printing. @@ -169,12 +169,12 @@ class DeviceDescription { // above. // // N.B. If another field is added, update ToMap() above. - string device_vendor_; - string platform_version_; - string driver_version_; - string runtime_version_; - string pci_bus_id_; - string name_; + std::string device_vendor_; + std::string platform_version_; + std::string driver_version_; + std::string runtime_version_; + std::string pci_bus_id_; + std::string name_; ThreadDim thread_dim_limit_; BlockDim block_dim_limit_; @@ -221,22 +221,24 @@ class DeviceDescriptionBuilder { // For descriptions of the following fields, see comments on the corresponding // DeviceDescription::* accessors above. - void set_device_vendor(const string &value) { + void set_device_vendor(const std::string &value) { device_description_->device_vendor_ = value; } - void set_platform_version(const string &value) { + void set_platform_version(const std::string &value) { device_description_->platform_version_ = value; } - void set_driver_version(const string &value) { + void set_driver_version(const std::string &value) { device_description_->driver_version_ = value; } - void set_runtime_version(const string &value) { + void set_runtime_version(const std::string &value) { device_description_->runtime_version_ = value; } - void set_pci_bus_id(const string &value) { + void set_pci_bus_id(const std::string &value) { device_description_->pci_bus_id_ = value; } - void set_name(const string &value) { device_description_->name_ = value; } + void set_name(const std::string &value) { + device_description_->name_ = value; + } void set_thread_dim_limit(const ThreadDim &value) { device_description_->thread_dim_limit_ = value; diff --git a/tensorflow/stream_executor/device_options.h b/tensorflow/stream_executor/device_options.h index 2646950f42e..b195bc84e14 100644 --- a/tensorflow/stream_executor/device_options.h +++ b/tensorflow/stream_executor/device_options.h @@ -71,13 +71,13 @@ struct DeviceOptions { return !(*this == other); } - string ToString() { + std::string ToString() { return flags_ == 0 ? "none" : "kDoNotReclaimStackAllocation"; } // Platform-specific device options. Expressed as key-value pairs to avoid // DeviceOptions subclass proliferation. 
- std::map<string, string> non_portable_tags; + std::map<std::string, std::string> non_portable_tags; private: unsigned flags_; diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc index e1289165252..567cca5f6a2 100644 --- a/tensorflow/stream_executor/dnn.cc +++ b/tensorflow/stream_executor/dnn.cc @@ -33,7 +33,7 @@ uint64 AlgorithmDesc::hash() const { return absl::Hash<decltype(p)>()(p); } -string AlgorithmDesc::ToString() const { +std::string AlgorithmDesc::ToString() const { if (tensor_ops_enabled()) { return absl::StrCat(algo_id(), "#TC"); } else { @@ -74,7 +74,7 @@ bool DnnSupport::GetConvolveBackwardFilterAlgorithms( return false; } -string QuantizedActivationModeString(QuantizedActivationMode mode) { +std::string QuantizedActivationModeString(QuantizedActivationMode mode) { switch (mode) { case dnn::QuantizedActivationMode::k8Bit: return "uint8"; @@ -89,7 +89,7 @@ string QuantizedActivationModeString(QuantizedActivationMode mode) { return "unknown quantized_activation_mode"; } -string ActivationModeString(ActivationMode mode) { +std::string ActivationModeString(ActivationMode mode) { switch (mode) { case ActivationMode::kSigmoid: return "sigmoid"; @@ -109,7 +109,7 @@ string ActivationModeString(ActivationMode mode) { return "unknown activation_mode"; } -string ElementwiseOperationString(ElementwiseOperation op) { +std::string ElementwiseOperationString(ElementwiseOperation op) { switch (op) { case ElementwiseOperation::kAdd: return "add"; @@ -121,7 +121,7 @@ string ElementwiseOperationString(ElementwiseOperation op) { return "unknown element wise op"; } -string DataLayoutString(DataLayout layout) { +std::string DataLayoutString(DataLayout layout) { switch (layout) { case DataLayout::kYXDepthBatch: return "YXDepthBatch"; @@ -139,7 +139,7 @@ string DataLayoutString(DataLayout layout) { return "unknown data layout"; } -string FilterLayoutString(FilterLayout layout) { +std::string FilterLayoutString(FilterLayout layout) { switch (layout) { case FilterLayout::kOutputInputYX: return "OutputInputYX"; @@ -157,7 +157,7 @@ string FilterLayoutString(FilterLayout layout) { return "unknown filter layout"; } -string PadAlignmentString(PadAlignment alignment) { +std::string PadAlignmentString(PadAlignment alignment) { switch (alignment) { case PadAlignment::kDefault: return "default"; @@ -173,7 +173,7 @@ std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment) { return str << PadAlignmentString(alignment); } -string ShortPoolingModeString(PoolingMode mode) { +std::string ShortPoolingModeString(PoolingMode mode) { switch (mode) { case PoolingMode::kMaximum: return "Max"; @@ -247,12 +247,12 @@ std::vector<int64> ReorderDims(const std::vector<int64>& input, // -- AlgorithmConfig -string AlgorithmConfig::ToString() const { - string algo = "none"; +std::string AlgorithmConfig::ToString() const { + std::string algo = "none"; if (algorithm().has_value()) { algo = algorithm()->ToString(); } - string algo_no_scratch = "none"; + std::string algo_no_scratch = "none"; if (algorithm_no_scratch().has_value()) { algo_no_scratch = algorithm_no_scratch()->ToString(); } @@ -306,8 +306,8 @@ void BatchDescriptor::CloneFrom(const BatchDescriptor& other) { quantized_activation_mode_ = other.quantized_activation_mode_; } -string BatchDescriptor::ToString() const { - string spatial; +std::string BatchDescriptor::ToString() const { + std::string spatial; for (int i = 0; i < ndims(); i++) { absl::StrAppendFormat(&spatial, "%d ", spatial_size()[i]); } @@ -318,19 +318,19 @@ string 
BatchDescriptor::ToString() const { DataLayoutString(layout())); } -string BatchDescriptor::ToShortString() const { +std::string BatchDescriptor::ToShortString() const { // All the constituent strings are less than 15 characters, so the // small string optimization ensures that there will be at most one // heap memory allocation. - string depth = absl::StrCat("d", feature_map_count()); - string batch = absl::StrCat("b", count()); + std::string depth = absl::StrCat("d", feature_map_count()); + std::string batch = absl::StrCat("b", count()); - string spatial = "s"; + std::string spatial = "s"; for (int i = 0; i < ndims(); i++) { absl::StrAppendFormat(&spatial, "%d ", spatial_size()[i]); } - string suffix; + std::string suffix; if (value_min() != value_max()) { absl::StrAppend(&suffix, "[", value_min(), ";", value_max(), "]"); } @@ -419,8 +419,8 @@ void FilterDescriptor::CloneFrom(const FilterDescriptor& other) { tensor_ = other.tensor_; } -string FilterDescriptor::ToString() const { - string desc = absl::StrFormat( +std::string FilterDescriptor::ToString() const { + std::string desc = absl::StrFormat( "{output_feature_map_count: %d input_feature_map_count: %d " "layout: %s shape: ", output_feature_map_count(), input_feature_map_count(), @@ -433,14 +433,14 @@ string FilterDescriptor::ToString() const { return desc; } -string FilterDescriptor::ToShortString() const { +std::string FilterDescriptor::ToShortString() const { // All the constituent strings are less than 15 characters, so the // small string optimization ensures that there will be at most one // heap memory allocation. - string od = absl::StrCat("od", output_feature_map_count()); - string id = absl::StrCat("id", input_feature_map_count()); + std::string od = absl::StrCat("od", output_feature_map_count()); + std::string id = absl::StrCat("id", input_feature_map_count()); - string spatial = "s"; + std::string spatial = "s"; for (int i = 0; i < ndims(); i++) { absl::StrAppendFormat(&spatial, "%d ", input_filter_dims()[i]); } @@ -491,10 +491,10 @@ ConvolutionDescriptor::ConvolutionDescriptor() ConvolutionDescriptor::~ConvolutionDescriptor() {} -string ConvolutionDescriptor::ToString() const { - string padding; - string strides; - string dilations; +std::string ConvolutionDescriptor::ToString() const { + std::string padding; + std::string strides; + std::string dilations; for (int i = 0; i < ndims(); i++) { absl::StrAppendFormat(&padding, "%d ", this->padding()[i]); absl::StrAppendFormat(&strides, "%d ", this->strides()[i]); @@ -507,8 +507,8 @@ string ConvolutionDescriptor::ToString() const { padding, PadAlignmentString(pad_alignment()), strides, dilations); } -string ConvolutionDescriptor::ToShortString() const { - string desc; +std::string ConvolutionDescriptor::ToShortString() const { + std::string desc; for (int i = 0; i < ndims(); i++) { if (i > 0) absl::StrAppend(&desc, "_"); absl::StrAppendFormat(&desc, "p%d:%d", i, padding()[i]); @@ -543,11 +543,11 @@ void PoolingDescriptor::CloneFrom(const PoolingDescriptor& other) { propagate_nans_ = other.propagate_nans_; } -string PoolingDescriptor::ToString() const { +std::string PoolingDescriptor::ToString() const { const char* mode_string = mode_ == dnn::PoolingMode::kMaximum ? 
"kMaximum" : "kAverage"; - string window, strides, padding; + std::string window, strides, padding; for (int i = 0; i < ndims_; i++) { absl::StrAppendFormat(&window, "%d ", window_[i]); absl::StrAppendFormat(&strides, "%d ", strides_[i]); @@ -561,8 +561,8 @@ string PoolingDescriptor::ToString() const { mode_string, window, strides, padding, propagate_string); } -string PoolingDescriptor::ToShortString() const { - string window, strides, padding; +std::string PoolingDescriptor::ToShortString() const { + std::string window, strides, padding; for (int i = 0; i < ndims_; i++) { absl::StrAppendFormat(&window, "_w%d:%d", i, window_[i]); absl::StrAppendFormat(&strides, "_s%d:%d", i, strides_[i]); @@ -592,14 +592,14 @@ void NormalizeDescriptor::CloneFrom(const NormalizeDescriptor& other) { segment_size_ = other.segment_size_; } -string NormalizeDescriptor::ToString() const { +std::string NormalizeDescriptor::ToString() const { return absl::StrFormat( "{bias: %f range: %d alpha: %f beta: %f wrap_around: %d " "segment_size: %d}", bias_, range_, alpha_, beta_, wrap_around_, segment_size_); } -string NormalizeDescriptor::ToShortString() const { +std::string NormalizeDescriptor::ToShortString() const { return absl::StrCat("bias:", bias_, "_range:", range_, "_alpha:", alpha_, "_beta:", beta_, "_wrap:", wrap_around_, "_size:", segment_size_); diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 8885e568ed1..771dc908a12 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -101,7 +101,7 @@ inline absl::Span<int64> AsInt64Slice(T* repeated_field) { } // Returns a string representation of the given data layout. -string DataLayoutString(DataLayout layout); +std::string DataLayoutString(DataLayout layout); // Specifies a quantization for activations in a given BatchDescriptor. enum class QuantizedActivationMode { @@ -209,7 +209,7 @@ class RnnStateTensorDescriptor { }; // Returns a string representation of the given quantization mode. -string QuantizedActivationModeString(QuantizedActivationMode mode); +std::string QuantizedActivationModeString(QuantizedActivationMode mode); // Describes the dimensions that a layer consumes/produces. // @@ -260,8 +260,8 @@ class BatchDescriptor { // Clones values from 'other' for initialization. void CloneFrom(const BatchDescriptor& other); - string ToString() const; - string ToShortString() const; + std::string ToString() const; + std::string ToShortString() const; // Pre-condition: // value_max_ == 0 @@ -374,7 +374,7 @@ class BatchDescriptor { }; // Returns a string representation of the given filter layout. -string FilterLayoutString(FilterLayout layout); +std::string FilterLayoutString(FilterLayout layout); // Describes a filter for the convolution. This is the "window" from // height-by-width patches of each of the feature maps in the input layer to the @@ -439,8 +439,8 @@ class FilterDescriptor { void CloneFrom(const FilterDescriptor& other); - string ToString() const; - string ToShortString() const; + std::string ToString() const; + std::string ToShortString() const; TensorDescriptorProto ToProto(DataType data_type) const; // Returns the number of weights required as parameters for a convolution @@ -486,7 +486,7 @@ enum class PadAlignment : int64 { }; // Returns a string representation of the given padding alignment. -string PadAlignmentString(PadAlignment alignment); +std::string PadAlignmentString(PadAlignment alignment); // Print alignment to str. Needed to use CHECK_EQ between two PadAlignments. 
std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment); @@ -529,8 +529,8 @@ class ConvolutionDescriptor { explicit ConvolutionDescriptor(int ndims); ~ConvolutionDescriptor(); - string ToString() const; - string ToShortString() const; + std::string ToString() const; + std::string ToShortString() const; ConvolutionDescriptorProto ToProto() const { return proto_; } ConvolutionDescriptor& set_zero_padding_height(int64 value) { @@ -578,7 +578,7 @@ class ConvolutionDescriptor { : ConvolutionMode::CROSS_CORRELATION); return *this; } - ConvolutionDescriptor& set_name(const string& name) { + ConvolutionDescriptor& set_name(const std::string& name) { proto_.set_name(name); return *this; } @@ -621,7 +621,7 @@ class ConvolutionDescriptor { return AsInt64Slice(proto_.paddings()); } - string name() const { return proto_.name(); } + std::string name() const { return proto_.name(); } private: absl::Span<int64> strides() { return AsInt64Slice(proto_.mutable_strides()); } @@ -658,7 +658,7 @@ enum class SpaceConcatenateMode : int64 { }; // Returns a short name for the pooling mode, e.g. "Avg". -string ShortPoolingModeString(PoolingMode mode); +std::string ShortPoolingModeString(PoolingMode mode); // Describes a pooling operation to be enqueued onto a stream via a platform's // DnnSupport. @@ -722,7 +722,7 @@ class PoolingDescriptor { propagate_nans_ = value; return *this; } - PoolingDescriptor& set_name(const string& name) { + PoolingDescriptor& set_name(const std::string& name) { name_ = name; return *this; } @@ -730,8 +730,8 @@ class PoolingDescriptor { int ndims() const { return ndims_; } void CloneFrom(const PoolingDescriptor& other); - string ToString() const; - string ToShortString() const; + std::string ToString() const; + std::string ToShortString() const; PoolingMode mode() const { return mode_; } int64 window_height() const { return GetDim(window_, DimIndex::Y); } @@ -747,13 +747,13 @@ class PoolingDescriptor { absl::Span<const int64> padding() const { return padding_; } absl::Span<const int64> strides() const { return strides_; } bool propagate_nans() const { return propagate_nans_; } - string name() const { return name_; } + std::string name() const { return name_; } private: PoolingMode mode_; int ndims_; bool propagate_nans_; - string name_; // Name as in Tensorflow NodeDef, for debugging purposes. + std::string name_; // Name as in Tensorflow NodeDef, for debugging purposes. // Stored as: ..., y, x. std::vector<int64> window_; @@ -783,7 +783,7 @@ class AlgorithmDesc { AlgorithmProto ToProto() const { return proto_; } - string ToString() const; + std::string ToString() const; private: AlgorithmProto proto_; @@ -860,7 +860,7 @@ class AlgorithmConfig { bool operator!=(const AlgorithmConfig& other) const { return !(*this == other); } - string ToString() const; + std::string ToString() const; private: absl::optional<AlgorithmDesc> algorithm_; @@ -927,8 +927,8 @@ class NormalizeDescriptor { void CloneFrom(const NormalizeDescriptor& other); - string ToString() const; - string ToShortString() const; + std::string ToString() const; + std::string ToShortString() const; float bias() const { return bias_; } int32 range() const { return range_; } @@ -947,13 +947,13 @@ class NormalizeDescriptor { }; // Returns a string representation of the given activation mode. -string ActivationModeString(ActivationMode mode); +std::string ActivationModeString(ActivationMode mode); // Describes the operation that DoElementwiseOperation should perform on its // inputs. 
enum class ElementwiseOperation { kAdd, kMultiply }; -string ElementwiseOperationString(ElementwiseOperation op); +std::string ElementwiseOperationString(ElementwiseOperation op); // A simple class representing the version of the backing library, to // workaround the "too perfect forwarding" issue in gcc6+ compilers. diff --git a/tensorflow/stream_executor/kernel.cc b/tensorflow/stream_executor/kernel.cc index 2aee9617bff..ec78c120965 100644 --- a/tensorflow/stream_executor/kernel.cc +++ b/tensorflow/stream_executor/kernel.cc @@ -91,7 +91,7 @@ KernelCacheConfig KernelBase::GetPreferredCacheConfig() const { } void KernelBase::set_name(absl::string_view name) { - name_ = string(name); + name_ = std::string(name); // CUDA splitter prefixes stub functions with __device_stub_. demangled_name_ = diff --git a/tensorflow/stream_executor/kernel.h b/tensorflow/stream_executor/kernel.h index 1e4f375073e..d3d386bca4a 100644 --- a/tensorflow/stream_executor/kernel.h +++ b/tensorflow/stream_executor/kernel.h @@ -178,8 +178,8 @@ class KernelBase { KernelCacheConfig GetPreferredCacheConfig() const; void set_name(absl::string_view name); - const string &name() const { return name_; } - const string &demangled_name() const { return demangled_name_; } + const std::string &name() const { return name_; } + const std::string &demangled_name() const { return demangled_name_; } private: // The StreamExecutor that loads this kernel object. @@ -188,8 +188,8 @@ class KernelBase { // Implementation delegated to for platform-specific functionality. std::unique_ptr<internal::KernelInterface> implementation_; - string name_; - string demangled_name_; + std::string name_; + std::string demangled_name_; KernelMetadata metadata_; diff --git a/tensorflow/stream_executor/kernel_spec.cc b/tensorflow/stream_executor/kernel_spec.cc index d7e00205103..de10d8ec9f1 100644 --- a/tensorflow/stream_executor/kernel_spec.cc +++ b/tensorflow/stream_executor/kernel_spec.cc @@ -19,11 +19,11 @@ limitations under the License. namespace stream_executor { KernelLoaderSpec::KernelLoaderSpec(absl::string_view kernelname) - : kernelname_(string(kernelname)) {} + : kernelname_(std::string(kernelname)) {} OnDiskKernelLoaderSpec::OnDiskKernelLoaderSpec(absl::string_view filename, absl::string_view kernelname) - : KernelLoaderSpec(kernelname), filename_(string(filename)) {} + : KernelLoaderSpec(kernelname), filename_(std::string(filename)) {} CudaPtxOnDisk::CudaPtxOnDisk(absl::string_view filename, absl::string_view kernelname) @@ -77,13 +77,13 @@ CudaPtxInMemory::CudaPtxInMemory( } } -string CudaPtxInMemory::DecompressPtx(const char *ptx) { +std::string CudaPtxInMemory::DecompressPtx(const char *ptx) { // Get the length of the PTX string from the beginning of the buffer. uint64 ptx_length = *reinterpret_cast<const uint64 *>(ptx); // Get the PTX string from the buffer with offset and length. - string compressed_ptx(ptx + sizeof(uint64), - ptx + sizeof(uint64) + ptx_length); - string decompressed_ptx; + std::string compressed_ptx(ptx + sizeof(uint64), + ptx + sizeof(uint64) + ptx_length); + std::string decompressed_ptx; // Decompress the PTX string with bzip2. 
LOG(FATAL) << "bzip2 decompression is not supported yet."; return decompressed_ptx; diff --git a/tensorflow/stream_executor/kernel_spec.h b/tensorflow/stream_executor/kernel_spec.h index 7199f60e4ca..5138e3b3a2c 100644 --- a/tensorflow/stream_executor/kernel_spec.h +++ b/tensorflow/stream_executor/kernel_spec.h @@ -73,7 +73,7 @@ class KernelLoaderSpec { virtual ~KernelLoaderSpec() {} // Returns the kernel name to load out of the program. - const string &kernelname() const { return kernelname_; } + const std::string &kernelname() const { return kernelname_; } protected: explicit KernelLoaderSpec(absl::string_view kernelname); @@ -81,7 +81,7 @@ class KernelLoaderSpec { private: // The kernel name that should be loaded out of the program description given // above. - string kernelname_; + std::string kernelname_; SE_DISALLOW_COPY_AND_ASSIGN(KernelLoaderSpec); }; @@ -94,7 +94,7 @@ class OnDiskKernelLoaderSpec : public KernelLoaderSpec { ~OnDiskKernelLoaderSpec() override {} // Returns the path to the on-disk loadable kernel file. - const string &filename() const { return filename_; } + const std::string &filename() const { return filename_; } // Returns the canonical suffix for this on-disk kernel loader spec format; // e.g. PTX files on disk have a canonical suffix of ".ptx". @@ -104,7 +104,7 @@ class OnDiskKernelLoaderSpec : public KernelLoaderSpec { OnDiskKernelLoaderSpec(absl::string_view filename, absl::string_view kernelname); - string filename_; + std::string filename_; private: SE_DISALLOW_COPY_AND_ASSIGN(OnDiskKernelLoaderSpec); @@ -128,12 +128,12 @@ class CudaCubinOnDisk : public OnDiskKernelLoaderSpec { CudaCubinOnDisk(absl::string_view filename, absl::string_view kernelname); ~CudaCubinOnDisk() override {} - const string &filename() const { return filename_; } + const std::string &filename() const { return filename_; } const char *CanonicalSuffix() const override { return ".cubin"; } private: - string filename_; + std::string filename_; SE_DISALLOW_COPY_AND_ASSIGN(CudaCubinOnDisk); }; @@ -192,7 +192,7 @@ class CudaPtxInMemory : public KernelLoaderSpec { int compute_capability_minor) const; // Decompresses the PTX string using bzip2. - static string DecompressPtx(const char *ptx); + static std::string DecompressPtx(const char *ptx); private: // PTX translation unit text contents in memory. The key is of as a tuple @@ -205,7 +205,7 @@ class CudaPtxInMemory : public KernelLoaderSpec { // Stores all decompressed ptx strings, with original ptx string as keys. // It is marked as mutable for lazy decompression. - mutable std::map<const char *, string> decompressed_ptx_; + mutable std::map<const char *, std::string> decompressed_ptx_; mutable absl::Mutex mu_; // Defines the minimum compute capability possible. Used when PTX has no @@ -246,11 +246,11 @@ class OpenCLTextInMemory : public KernelLoaderSpec { ~OpenCLTextInMemory() override {} // Returns the OpenCL text contents. - const string &text() const { return text_; } + const std::string &text() const { return text_; } private: // OpenCL translation unit text contents in memory. - string text_; + std::string text_; SE_DISALLOW_COPY_AND_ASSIGN(OpenCLTextInMemory); }; diff --git a/tensorflow/stream_executor/launch_dim.h b/tensorflow/stream_executor/launch_dim.h index 4a3c882d9f7..95643ecf9bd 100644 --- a/tensorflow/stream_executor/launch_dim.h +++ b/tensorflow/stream_executor/launch_dim.h @@ -56,7 +56,7 @@ struct ThreadDim : public Dim3D { : Dim3D(x, y, z) {} // Returns a string representation of the thread dimensionality. 
- string ToString() const { + std::string ToString() const { return absl::StrCat("ThreadDim{", x, ", ", y, ", ", z, "}"); } }; @@ -68,7 +68,7 @@ struct BlockDim : public Dim3D { : Dim3D(x, y, z) {} // Returns a string representation of the block dimensionality. - string ToString() const { + std::string ToString() const { return absl::StrCat("BlockDim{", x, ", ", y, ", ", z, "}"); } }; diff --git a/tensorflow/stream_executor/lib/demangle.cc b/tensorflow/stream_executor/lib/demangle.cc index adb6b4f2d11..8fb5db0777e 100644 --- a/tensorflow/stream_executor/lib/demangle.cc +++ b/tensorflow/stream_executor/lib/demangle.cc @@ -33,8 +33,8 @@ namespace port { // The API reference of abi::__cxa_demangle() can be found in // libstdc++'s manual. // https://gcc.gnu.org/onlinedocs/libstdc++/libstdc++-html-USERS-4.3/a01696.html -string Demangle(const char *mangled) { - string demangled; +std::string Demangle(const char *mangled) { + std::string demangled; int status = 0; char *result = nullptr; #if HAS_CXA_DEMANGLE diff --git a/tensorflow/stream_executor/lib/demangle.h b/tensorflow/stream_executor/lib/demangle.h index af16fa7d8cb..4c1007a1947 100644 --- a/tensorflow/stream_executor/lib/demangle.h +++ b/tensorflow/stream_executor/lib/demangle.h @@ -21,7 +21,7 @@ limitations under the License. namespace stream_executor { namespace port { -string Demangle(const char* mangled); +std::string Demangle(const char* mangled); } // namespace port } // namespace stream_executor diff --git a/tensorflow/stream_executor/lib/env.h b/tensorflow/stream_executor/lib/env.h index a5eb8ef1d43..f8ccac75e98 100644 --- a/tensorflow/stream_executor/lib/env.h +++ b/tensorflow/stream_executor/lib/env.h @@ -27,12 +27,12 @@ namespace port { using tensorflow::Env; using tensorflow::Thread; -inline Status FileExists(const string& filename) { +inline Status FileExists(const std::string& filename) { return Env::Default()->FileExists(filename); } inline Status FileExists(const absl::string_view& filename) { - return Env::Default()->FileExists(string(filename)); + return Env::Default()->FileExists(std::string(filename)); } } // namespace port diff --git a/tensorflow/stream_executor/lib/human_readable.h b/tensorflow/stream_executor/lib/human_readable.h index 5e5525e6b5b..a8ede952fd1 100644 --- a/tensorflow/stream_executor/lib/human_readable.h +++ b/tensorflow/stream_executor/lib/human_readable.h @@ -28,7 +28,7 @@ namespace port { class HumanReadableNumBytes { public: - static string ToString(int64 num_bytes) { + static std::string ToString(int64 num_bytes) { if (num_bytes == std::numeric_limits<int64>::min()) { // Special case for number with not representable nagation. return "-8E"; diff --git a/tensorflow/stream_executor/lib/numbers.cc b/tensorflow/stream_executor/lib/numbers.cc index b670c42ec84..96bd9c3868e 100644 --- a/tensorflow/stream_executor/lib/numbers.cc +++ b/tensorflow/stream_executor/lib/numbers.cc @@ -32,7 +32,7 @@ bool safe_strto32(const char* str, int32* value) { // Convert strings to floating point values. // Leading and trailing spaces are allowed. // Values may be rounded on over- and underflow. 
-bool safe_strto32(const string& str, int32* value) { +bool safe_strto32(const std::string& str, int32* value) { return port::safe_strto32(str.c_str(), value); } diff --git a/tensorflow/stream_executor/lib/numbers.h b/tensorflow/stream_executor/lib/numbers.h index 2f48281d2d6..15fecfbfca9 100644 --- a/tensorflow/stream_executor/lib/numbers.h +++ b/tensorflow/stream_executor/lib/numbers.h @@ -24,7 +24,7 @@ namespace port { // Convert strings to floating point values. // Leading and trailing spaces are allowed. // Values may be rounded on over- and underflow. -bool safe_strto32(const string& str, int32* value); +bool safe_strto32(const std::string& str, int32* value); } // namespace port } // namespace stream_executor diff --git a/tensorflow/stream_executor/lib/path.cc b/tensorflow/stream_executor/lib/path.cc index 47eedbc6a16..3388843cfa2 100644 --- a/tensorflow/stream_executor/lib/path.cc +++ b/tensorflow/stream_executor/lib/path.cc @@ -27,14 +27,14 @@ static bool IsAbsolutePath(absl::string_view path) { // For an array of paths of length count, append them all together, // ensuring that the proper path separators are inserted between them. -string JoinPathImpl(std::initializer_list<absl::string_view> paths) { - string result; +std::string JoinPathImpl(std::initializer_list<absl::string_view> paths) { + std::string result; for (absl::string_view path : paths) { if (path.empty()) continue; if (result.empty()) { - result = string(path); + result = std::string(path); continue; } diff --git a/tensorflow/stream_executor/lib/path.h b/tensorflow/stream_executor/lib/path.h index 902331b4273..17ede80d25c 100644 --- a/tensorflow/stream_executor/lib/path.h +++ b/tensorflow/stream_executor/lib/path.h @@ -25,7 +25,7 @@ namespace port { namespace internal { // TODO(rspringer): Move to cc/implementation file. // Not part of the public API. -string JoinPathImpl(std::initializer_list<absl::string_view> paths); +std::string JoinPathImpl(std::initializer_list<absl::string_view> paths); } // namespace internal // Join multiple paths together. @@ -47,7 +47,7 @@ string JoinPathImpl(std::initializer_list<absl::string_view> paths); // string path = file::JoinPath("/var/log", dirname, filename); // string path = file::JoinPath(FLAGS_test_srcdir, filename); template <typename... T> -inline string JoinPath(const T&... args) { +inline std::string JoinPath(const T&... args) { return internal::JoinPathImpl({args...}); } diff --git a/tensorflow/stream_executor/lib/process_state.cc b/tensorflow/stream_executor/lib/process_state.cc index 5a351e7a8d5..3bf4c5aecf4 100644 --- a/tensorflow/stream_executor/lib/process_state.cc +++ b/tensorflow/stream_executor/lib/process_state.cc @@ -29,14 +29,14 @@ limitations under the License. namespace stream_executor { namespace port { -string Hostname() { +std::string Hostname() { char hostname[1024]; gethostname(hostname, sizeof hostname); hostname[sizeof hostname - 1] = 0; return std::string(hostname); } -bool GetCurrentDirectory(string* dir) { +bool GetCurrentDirectory(std::string* dir) { size_t len = 128; std::unique_ptr<char[]> a(new char[len]); for (;;) { diff --git a/tensorflow/stream_executor/lib/process_state.h b/tensorflow/stream_executor/lib/process_state.h index 248218c759e..68f06f969e8 100644 --- a/tensorflow/stream_executor/lib/process_state.h +++ b/tensorflow/stream_executor/lib/process_state.h @@ -21,8 +21,8 @@ limitations under the License. 
namespace stream_executor { namespace port { -string Hostname(); -bool GetCurrentDirectory(string* dir); +std::string Hostname(); +bool GetCurrentDirectory(std::string* dir); } // namespace port } // namespace stream_executor diff --git a/tensorflow/stream_executor/lib/statusor_test.cc b/tensorflow/stream_executor/lib/statusor_test.cc index 56584e18920..16480b30789 100644 --- a/tensorflow/stream_executor/lib/statusor_test.cc +++ b/tensorflow/stream_executor/lib/statusor_test.cc @@ -147,19 +147,19 @@ TEST(StatusOr, TestMoveOnlyVector) { } TEST(StatusOr, TestMoveWithValuesAndErrors) { - StatusOr<string> status_or(string(1000, '0')); - StatusOr<string> value1(string(1000, '1')); - StatusOr<string> value2(string(1000, '2')); - StatusOr<string> error1(Status(tensorflow::error::UNKNOWN, "error1")); - StatusOr<string> error2(Status(tensorflow::error::UNKNOWN, "error2")); + StatusOr<std::string> status_or(std::string(1000, '0')); + StatusOr<std::string> value1(std::string(1000, '1')); + StatusOr<std::string> value2(std::string(1000, '2')); + StatusOr<std::string> error1(Status(tensorflow::error::UNKNOWN, "error1")); + StatusOr<std::string> error2(Status(tensorflow::error::UNKNOWN, "error2")); ASSERT_TRUE(status_or.ok()); - EXPECT_EQ(string(1000, '0'), status_or.ValueOrDie()); + EXPECT_EQ(std::string(1000, '0'), status_or.ValueOrDie()); // Overwrite the value in status_or with another value. status_or = std::move(value1); ASSERT_TRUE(status_or.ok()); - EXPECT_EQ(string(1000, '1'), status_or.ValueOrDie()); + EXPECT_EQ(std::string(1000, '1'), status_or.ValueOrDie()); // Overwrite the value in status_or with an error. status_or = std::move(error1); @@ -174,23 +174,23 @@ TEST(StatusOr, TestMoveWithValuesAndErrors) { // Overwrite the error with a value. status_or = std::move(value2); ASSERT_TRUE(status_or.ok()); - EXPECT_EQ(string(1000, '2'), status_or.ValueOrDie()); + EXPECT_EQ(std::string(1000, '2'), status_or.ValueOrDie()); } TEST(StatusOr, TestCopyWithValuesAndErrors) { - StatusOr<string> status_or(string(1000, '0')); - StatusOr<string> value1(string(1000, '1')); - StatusOr<string> value2(string(1000, '2')); - StatusOr<string> error1(Status(tensorflow::error::UNKNOWN, "error1")); - StatusOr<string> error2(Status(tensorflow::error::UNKNOWN, "error2")); + StatusOr<std::string> status_or(std::string(1000, '0')); + StatusOr<std::string> value1(std::string(1000, '1')); + StatusOr<std::string> value2(std::string(1000, '2')); + StatusOr<std::string> error1(Status(tensorflow::error::UNKNOWN, "error1")); + StatusOr<std::string> error2(Status(tensorflow::error::UNKNOWN, "error2")); ASSERT_TRUE(status_or.ok()); - EXPECT_EQ(string(1000, '0'), status_or.ValueOrDie()); + EXPECT_EQ(std::string(1000, '0'), status_or.ValueOrDie()); // Overwrite the value in status_or with another value. status_or = value1; ASSERT_TRUE(status_or.ok()); - EXPECT_EQ(string(1000, '1'), status_or.ValueOrDie()); + EXPECT_EQ(std::string(1000, '1'), status_or.ValueOrDie()); // Overwrite the value in status_or with an error. status_or = error1; @@ -205,13 +205,13 @@ TEST(StatusOr, TestCopyWithValuesAndErrors) { // Overwrite the error with a value. status_or = value2; ASSERT_TRUE(status_or.ok()); - EXPECT_EQ(string(1000, '2'), status_or.ValueOrDie()); + EXPECT_EQ(std::string(1000, '2'), status_or.ValueOrDie()); // Verify original values unchanged. 
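The two statusor_test.cc cases above exercise the difference between move- and copy-assignment of StatusOr<std::string>: after a copy the source keeps its value ("Verify original values unchanged"), while a moved-from source should no longer be relied upon. A minimal sketch of that distinction, assuming the same StatusOr type and an assumed header path; illustrative only, not part of the patch.

#include <string>
#include <utility>

#include "tensorflow/stream_executor/lib/statusor.h"  // assumed header path

using stream_executor::port::StatusOr;  // assumed namespace, as in the tests

void CopyVersusMove() {
  StatusOr<std::string> source(std::string(1000, '1'));
  StatusOr<std::string> copied = source;            // copy: source keeps the '1's
  StatusOr<std::string> moved = std::move(source);  // move: source is moved-from
  // copied.ValueOrDie() and moved.ValueOrDie() both hold the original value;
  // only `source` should not be used afterwards.
}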
- EXPECT_EQ(string(1000, '1'), value1.ValueOrDie()); + EXPECT_EQ(std::string(1000, '1'), value1.ValueOrDie()); EXPECT_EQ("error1", error1.status().error_message()); EXPECT_EQ("error2", error2.status().error_message()); - EXPECT_EQ(string(1000, '2'), value2.ValueOrDie()); + EXPECT_EQ(std::string(1000, '2'), value2.ValueOrDie()); } TEST(StatusOr, TestDefaultCtor) { diff --git a/tensorflow/stream_executor/multi_platform_manager.cc b/tensorflow/stream_executor/multi_platform_manager.cc index ad6437709c4..cbbe0654953 100644 --- a/tensorflow/stream_executor/multi_platform_manager.cc +++ b/tensorflow/stream_executor/multi_platform_manager.cc @@ -39,10 +39,10 @@ class MultiPlatformManagerImpl { LOCKS_EXCLUDED(mu_); port::StatusOr<Platform*> InitializePlatformWithName( - absl::string_view target, const std::map<string, string>& options) - LOCKS_EXCLUDED(mu_); + absl::string_view target, + const std::map<std::string, std::string>& options) LOCKS_EXCLUDED(mu_); port::StatusOr<Platform*> InitializePlatformWithId( - const Platform::Id& id, const std::map<string, string>& options) + const Platform::Id& id, const std::map<std::string, std::string>& options) LOCKS_EXCLUDED(mu_); port::StatusOr<std::vector<Platform*>> PlatformsWithFilter( @@ -66,13 +66,13 @@ class MultiPlatformManagerImpl { absl::Mutex mu_; std::vector<std::unique_ptr<Listener>> listeners_ GUARDED_BY(mu_); absl::flat_hash_map<Platform::Id, Platform*> id_map_ GUARDED_BY(mu_); - absl::flat_hash_map<string, Platform*> name_map_ GUARDED_BY(mu_); + absl::flat_hash_map<std::string, Platform*> name_map_ GUARDED_BY(mu_); }; port::Status MultiPlatformManagerImpl::RegisterPlatform( std::unique_ptr<Platform> platform) { CHECK(platform != nullptr); - string key = absl::AsciiStrToLower(platform->Name()); + std::string key = absl::AsciiStrToLower(platform->Name()); absl::MutexLock lock(&mu_); if (name_map_.find(key) != name_map_.end()) { return port::Status(port::error::INTERNAL, @@ -118,7 +118,8 @@ port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithId( } port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithName( - absl::string_view target, const std::map<string, string>& options) { + absl::string_view target, + const std::map<std::string, std::string>& options) { absl::MutexLock lock(&mu_); SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target)); @@ -134,7 +135,7 @@ port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithName( } port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithId( - const Platform::Id& id, const std::map<string, string>& options) { + const Platform::Id& id, const std::map<std::string, std::string>& options) { absl::MutexLock lock(&mu_); SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id)); @@ -224,13 +225,14 @@ MultiPlatformManagerImpl& Impl() { /*static*/ port::StatusOr<Platform*> MultiPlatformManager::InitializePlatformWithName( - absl::string_view target, const std::map<string, string>& options) { + absl::string_view target, + const std::map<std::string, std::string>& options) { return Impl().InitializePlatformWithName(target, options); } /*static*/ port::StatusOr<Platform*> MultiPlatformManager::InitializePlatformWithId( - const Platform::Id& id, const std::map<string, string>& options) { + const Platform::Id& id, const std::map<std::string, std::string>& options) { return Impl().InitializePlatformWithId(id, options); } diff --git a/tensorflow/stream_executor/multi_platform_manager.h 
b/tensorflow/stream_executor/multi_platform_manager.h index 6e6617a6da9..556015de790 100644 --- a/tensorflow/stream_executor/multi_platform_manager.h +++ b/tensorflow/stream_executor/multi_platform_manager.h @@ -111,10 +111,12 @@ class MultiPlatformManager { // Ownership of the platform is NOT transferred to the caller -- // the MultiPlatformManager owns the platforms in a singleton-like fashion. static port::StatusOr<Platform*> InitializePlatformWithName( - absl::string_view target, const std::map<string, string>& options); + absl::string_view target, + const std::map<std::string, std::string>& options); static port::StatusOr<Platform*> InitializePlatformWithId( - const Platform::Id& id, const std::map<string, string>& options); + const Platform::Id& id, + const std::map<std::string, std::string>& options); // Retrieves the platforms satisfying the given filter, i.e. returns true. // Returned Platforms are always initialized. diff --git a/tensorflow/stream_executor/platform.cc b/tensorflow/stream_executor/platform.cc index 9c995814386..fce9a1c0cd2 100644 --- a/tensorflow/stream_executor/platform.cc +++ b/tensorflow/stream_executor/platform.cc @@ -24,7 +24,7 @@ limitations under the License. namespace stream_executor { -string PlatformKindString(PlatformKind kind) { +std::string PlatformKindString(PlatformKind kind) { switch (kind) { case PlatformKind::kCuda: return "CUDA"; @@ -41,7 +41,7 @@ string PlatformKindString(PlatformKind kind) { } } -PlatformKind PlatformKindFromString(string kind) { +PlatformKind PlatformKindFromString(std::string kind) { for (int i = 0; i < static_cast<int>(PlatformKind::kSize); ++i) { if (kind == PlatformKindString(static_cast<PlatformKind>(i))) { return static_cast<PlatformKind>(i); @@ -91,7 +91,7 @@ Platform::~Platform() {} bool Platform::Initialized() const { return true; } port::Status Platform::Initialize( - const std::map<string, string> &platform_options) { + const std::map<std::string, std::string> &platform_options) { if (!platform_options.empty()) { return port::Status(port::error::UNIMPLEMENTED, "this platform does not support custom initialization"); diff --git a/tensorflow/stream_executor/platform.h b/tensorflow/stream_executor/platform.h index aefb94f6192..8f80ee837d0 100644 --- a/tensorflow/stream_executor/platform.h +++ b/tensorflow/stream_executor/platform.h @@ -60,11 +60,11 @@ bool PlatformIsRunnable(PlatformKind kind); bool PlatformIsRunnableOnDevice(PlatformKind kind); // Returns a printable description of a PlatformKind. -string PlatformKindString(PlatformKind kind); +std::string PlatformKindString(PlatformKind kind); // Returns the PlatformKind corresponding to the input string; returns kInvalid // in the case of no match. -PlatformKind PlatformKindFromString(string platform_string); +PlatformKind PlatformKindFromString(std::string platform_string); // Checks that kind takes on a valid value. void CheckPlatformKindIsValid(PlatformKind kind); @@ -114,7 +114,7 @@ class Platform { virtual Id id() const = 0; // Name of this platform. - virtual const string& Name() const = 0; + virtual const std::string& Name() const = 0; // Returns the number of devices accessible on this platform. // @@ -133,7 +133,7 @@ class Platform { // MultiPlatformManager, this method will be called automatically by // InitializePlatformWithId/InitializePlatformWithName. 
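For context on the MultiPlatformManager signature change above, the sketch below shows how a caller would pass the options map to the updated InitializePlatformWithName. The "cuda" target name and the empty options map are placeholders; this is illustrative only and not part of the patch.

#include <map>
#include <string>

#include "tensorflow/stream_executor/multi_platform_manager.h"

void InitializeNamedPlatform() {
  std::map<std::string, std::string> options;  // platform-specific key/values
  auto platform_or =
      stream_executor::MultiPlatformManager::InitializePlatformWithName(
          "cuda", options);
  if (platform_or.ok()) {
    stream_executor::Platform* platform = platform_or.ValueOrDie();
    (void)platform;  // Owned by MultiPlatformManager, not by the caller.
  }
}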
virtual port::Status Initialize( - const std::map<string, string>& platform_options); + const std::map<std::string, std::string>& platform_options); // Returns a populated DeviceDescription for the device at the given ordinal. // This should not require device initialization. Note that not all platforms diff --git a/tensorflow/stream_executor/plugin_registry.cc b/tensorflow/stream_executor/plugin_registry.cc index 1e6a2d4f2a9..4f85f92ca1c 100644 --- a/tensorflow/stream_executor/plugin_registry.cc +++ b/tensorflow/stream_executor/plugin_registry.cc @@ -27,7 +27,7 @@ namespace stream_executor { const PluginId kNullPlugin = nullptr; // Returns the string representation of the specified PluginKind. -string PluginKindString(PluginKind plugin_kind) { +std::string PluginKindString(PluginKind plugin_kind) { switch (plugin_kind) { case PluginKind::kBlas: return "BLAS"; @@ -70,7 +70,7 @@ void PluginRegistry::MapPlatformKindToId(PlatformKind platform_kind, template <typename FACTORY_TYPE> port::Status PluginRegistry::RegisterFactoryInternal( - PluginId plugin_id, const string& plugin_name, FACTORY_TYPE factory, + PluginId plugin_id, const std::string& plugin_name, FACTORY_TYPE factory, std::map<PluginId, FACTORY_TYPE>* factories) { absl::MutexLock lock{&GetPluginRegistryMutex()}; @@ -110,7 +110,7 @@ bool PluginRegistry::SetDefaultFactory(Platform::Id platform_id, if (!HasFactory(platform_id, plugin_kind, plugin_id)) { port::StatusOr<Platform*> status = MultiPlatformManager::PlatformWithId(platform_id); - string platform_name = "<unregistered platform>"; + std::string platform_name = "<unregistered platform>"; if (status.ok()) { platform_name = status.ValueOrDie()->Name(); } @@ -194,7 +194,7 @@ bool PluginRegistry::HasFactory(Platform::Id platform_id, \ template <> \ port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \ - Platform::Id platform_id, PluginId plugin_id, const string& name, \ + Platform::Id platform_id, PluginId plugin_id, const std::string& name, \ PluginRegistry::FACTORY_TYPE factory) { \ return RegisterFactoryInternal(plugin_id, name, factory, \ &factories_[platform_id].FACTORY_VAR); \ @@ -202,7 +202,8 @@ bool PluginRegistry::HasFactory(Platform::Id platform_id, \ template <> \ port::Status PluginRegistry::RegisterFactoryForAllPlatforms< \ - PluginRegistry::FACTORY_TYPE>(PluginId plugin_id, const string& name, \ + PluginRegistry::FACTORY_TYPE>(PluginId plugin_id, \ + const std::string& name, \ PluginRegistry::FACTORY_TYPE factory) { \ return RegisterFactoryInternal(plugin_id, name, factory, \ &generic_factories_.FACTORY_VAR); \ diff --git a/tensorflow/stream_executor/plugin_registry.h b/tensorflow/stream_executor/plugin_registry.h index 7e4407d4b27..2ea39c9cafc 100644 --- a/tensorflow/stream_executor/plugin_registry.h +++ b/tensorflow/stream_executor/plugin_registry.h @@ -62,13 +62,13 @@ class PluginRegistry { // with that platform (but execution should be otherwise unaffected). template <typename FactoryT> port::Status RegisterFactory(Platform::Id platform_id, PluginId plugin_id, - const string& name, FactoryT factory); + const std::string& name, FactoryT factory); // Registers the specified factory as usable by _all_ platform types. // Reports errors just as RegisterFactory. 
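The plugin_registry.cc hunk above shows the shape of RegisterFactoryInternal: take the registry mutex, reject a duplicate PluginId, otherwise record the factory and its name. Below is a simplified sketch of that guarded-map pattern under assumed stand-in types (PluginId as an opaque pointer, a generic FactoryT); it is not the actual implementation.

#include <map>
#include <string>

#include "absl/synchronization/mutex.h"

using PluginId = void*;  // assumption: an opaque identifier, as kNullPlugin suggests

template <typename FactoryT>
bool RegisterFactorySketch(absl::Mutex* mu, PluginId plugin_id,
                           const std::string& plugin_name, FactoryT factory,
                           std::map<PluginId, FactoryT>* factories,
                           std::map<PluginId, std::string>* plugin_names) {
  absl::MutexLock lock(mu);
  if (factories->find(plugin_id) != factories->end()) {
    return false;  // A factory is already registered for this plugin.
  }
  (*factories)[plugin_id] = factory;
  (*plugin_names)[plugin_id] = plugin_name;  // mirrors plugin_names_ above
  return true;
}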
template <typename FactoryT> port::Status RegisterFactoryForAllPlatforms(PluginId plugin_id, - const string& name, + const std::string& name, FactoryT factory); // TODO(b/22689637): Setter for temporary mapping until all users are using @@ -122,7 +122,7 @@ class PluginRegistry { // Actually performs the work of registration. template <typename FactoryT> port::Status RegisterFactoryInternal(PluginId plugin_id, - const string& plugin_name, + const std::string& plugin_name, FactoryT factory, std::map<PluginId, FactoryT>* factories); @@ -155,7 +155,7 @@ class PluginRegistry { std::map<Platform::Id, DefaultFactories> default_factories_; // Lookup table for plugin names. - std::map<PluginId, string> plugin_names_; + std::map<PluginId, std::string> plugin_names_; SE_DISALLOW_COPY_AND_ASSIGN(PluginRegistry); }; @@ -164,7 +164,7 @@ class PluginRegistry { #define DECLARE_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE) \ template <> \ port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \ - Platform::Id platform_id, PluginId plugin_id, const string& name, \ + Platform::Id platform_id, PluginId plugin_id, const std::string& name, \ PluginRegistry::FACTORY_TYPE factory); \ template <> \ port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory( \ diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 4e430b11af5..c4564a613e1 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -35,55 +35,57 @@ namespace { // will be VLOG'ed. We need overloads, instead of // e.g. BatchDescriptorToVlogString(), as the code that calls these // functions does not know what the type of the parameter is. -string ToVlogString(const dnn::BatchDescriptor &descriptor) { +std::string ToVlogString(const dnn::BatchDescriptor &descriptor) { return descriptor.ToShortString(); } -string ToVlogString(const dnn::FilterDescriptor &descriptor) { +std::string ToVlogString(const dnn::FilterDescriptor &descriptor) { return descriptor.ToShortString(); } -string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) { +std::string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) { return descriptor.ToShortString(); } -string ToVlogString(const dnn::PoolingDescriptor &descriptor) { +std::string ToVlogString(const dnn::PoolingDescriptor &descriptor) { return descriptor.ToShortString(); } -string ToVlogString(const dnn::NormalizeDescriptor &descriptor) { +std::string ToVlogString(const dnn::NormalizeDescriptor &descriptor) { return descriptor.ToShortString(); } -string ToVlogString(dnn::ActivationMode mode) { +std::string ToVlogString(dnn::ActivationMode mode) { return dnn::ActivationModeString(mode); } -string ToVlogString(const dnn::AlgorithmConfig &algo_config) { +std::string ToVlogString(const dnn::AlgorithmConfig &algo_config) { return algo_config.ToString(); } -string ToVlogString(dnn::ElementwiseOperation op) { +std::string ToVlogString(dnn::ElementwiseOperation op) { return dnn::ElementwiseOperationString(op); } -string ToVlogString(dnn::QuantizedActivationMode mode) { +std::string ToVlogString(dnn::QuantizedActivationMode mode) { return dnn::QuantizedActivationModeString(mode); } -string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); } +std::string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); } -string ToVlogString(blas::UpperLower ul) { return blas::UpperLowerString(ul); } +std::string ToVlogString(blas::UpperLower ul) { + return blas::UpperLowerString(ul); +} -string 
ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); } +std::string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); } -string ToVlogString(blas::Side s) { return blas::SideString(s); } +std::string ToVlogString(blas::Side s) { return blas::SideString(s); } -string ToVlogString(blas::ComputationType ty) { +std::string ToVlogString(blas::ComputationType ty) { return blas::ComputationTypeString(ty); } -string ToVlogString(const void *ptr) { +std::string ToVlogString(const void *ptr) { if (ptr == nullptr) { return "null"; } @@ -95,7 +97,7 @@ string ToVlogString(const void *ptr) { } template <class T> -string ToVlogString(const std::complex<T> &c) { +std::string ToVlogString(const std::complex<T> &c) { // StrCat does not convert std::complex to text. std::ostringstream out; out << c; @@ -103,36 +105,36 @@ string ToVlogString(const std::complex<T> &c) { } template <class T> -string ToVlogString(const std::function<T> &f) { +std::string ToVlogString(const std::function<T> &f) { return f == nullptr ? "null" : "<non-null function>"; } -string ToVlogString(const DeviceMemoryBase &memory) { +std::string ToVlogString(const DeviceMemoryBase &memory) { return ToVlogString(memory.opaque()); } -string ToVlogString(const DeviceMemoryBase *memory) { +std::string ToVlogString(const DeviceMemoryBase *memory) { return memory == nullptr ? "null" : ToVlogString(*memory); } -string ToVlogString(const Eigen::half &h) { +std::string ToVlogString(const Eigen::half &h) { return absl::StrCat(static_cast<float>(h)); } -string ToVlogString(int i) { return absl::StrCat(i); } +std::string ToVlogString(int i) { return absl::StrCat(i); } -string ToVlogString(uint32 i) { return absl::StrCat(i); } +std::string ToVlogString(uint32 i) { return absl::StrCat(i); } -string ToVlogString(uint64 i) { return absl::StrCat(i); } +std::string ToVlogString(uint64 i) { return absl::StrCat(i); } -string ToVlogString(int64 i) { return absl::StrCat(i); } +std::string ToVlogString(int64 i) { return absl::StrCat(i); } -string ToVlogString(float f) { return absl::StrCat(f); } +std::string ToVlogString(float f) { return absl::StrCat(f); } -string ToVlogString(double d) { return absl::StrCat(d); } +std::string ToVlogString(double d) { return absl::StrCat(d); } template <typename T> -string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) { +std::string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) { if (memory_or_constant.is_pointer()) { return ToVlogString(memory_or_constant.pointer()); } @@ -140,8 +142,8 @@ string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) { } template <class T> -string ToVlogString(port::ArraySlice<T> elements) { - string str = absl::StrCat( +std::string ToVlogString(port::ArraySlice<T> elements) { + std::string str = absl::StrCat( ToVlogString(reinterpret_cast<const void *>(elements.data())), "[", elements.size(), "]{"); const char *separator = ""; @@ -166,11 +168,11 @@ string ToVlogString(port::ArraySlice<T> elements) { } template <class T> -string ToVlogString(port::MutableArraySlice<T> elements) { +std::string ToVlogString(port::MutableArraySlice<T> elements) { return ToVlogString(port::ArraySlice<T>(elements)); } -string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) { +std::string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) { switch (depth_to_space_layout) { case dnn::DepthToSpaceLayout::DepthHeightWidth: return "DepthToSpaceLayout::DepthHeightWidth"; @@ -178,7 +180,7 @@ string ToVlogString(dnn::DepthToSpaceLayout 
depth_to_space_layout) { return "unknown DepthToSpaceLayout"; } -string ToVlogString(dnn::DataType data_type) { +std::string ToVlogString(dnn::DataType data_type) { switch (data_type) { case dnn::DataType::kFloat: return "dnn::DataType::kFloat"; @@ -205,14 +207,14 @@ string ToVlogString(dnn::DataType data_type) { // See VLOG_CALL for a short-hand for this. This way of doing it saves // a tremendous amount of boilerplate code given how many functions // there are on Stream and how many parameters they each have. -string CallStr(const char *function_name, Stream *stream, - std::vector<std::pair<const char *, string>> params) { +std::string CallStr(const char *function_name, Stream *stream, + std::vector<std::pair<const char *, std::string>> params) { // Do not call this function unless VLOG is on since just // constructing all the strings in params is expensive. CHECK(VLOG_IS_ON(1)); - string str = absl::StrCat(stream->DebugStreamPointers(), - " Called Stream::", function_name, "("); + std::string str = absl::StrCat(stream->DebugStreamPointers(), + " Called Stream::", function_name, "("); const char *separator = ""; for (const auto &param : params) { absl::StrAppend(&str, separator, param.first, "=", param.second); @@ -5470,7 +5472,7 @@ void Stream::RunAfterBlockHostUntilDoneCallbacks() { } } -string Stream::DebugStreamPointers() const { +std::string Stream::DebugStreamPointers() const { // Relies on the ToVlogString(const void*) overload above. return absl::StrCat("[stream=", ToVlogString(this), ",impl=", ToVlogString(implementation_.get()), "]"); diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index d5ff7ea206c..7e4f2e627ee 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -2014,7 +2014,7 @@ class Stream { internal::TemporaryMemoryManager *temporary_memory_manager(); // Returns a debugging string "[stream=0x...,impl=0x...]". - string DebugStreamPointers() const; + std::string DebugStreamPointers() const; private: friend class host::HostBlas; // for parent_. diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h index 793fdeb6d56..408b4fc8207 100644 --- a/tensorflow/stream_executor/stream_executor_internal.h +++ b/tensorflow/stream_executor/stream_executor_internal.h @@ -286,8 +286,9 @@ class StreamExecutorInterface { // If ModuleHandle is set then we search for `symbol_name` only within the // module corresponding to `module_handle`. Otherwise all loaded modules are // searched.
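Referring back to the CallStr helper in the stream.cc hunk above: it exists so every Stream method can emit one uniform VLOG line of the form "[stream=0x...,impl=0x...] Called Stream::Foo(a=..., b=...)" without per-method boilerplate. Below is a stand-alone sketch of that string assembly only (the parameter names and values are placeholders); it is not the VLOG_CALL macro machinery itself.

#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"

// Builds the kind of line CallStr produces, given already-stringified params.
std::string CallStrSketch(
    const char* function_name, const std::string& stream_pointers,
    const std::vector<std::pair<const char*, std::string>>& params) {
  std::string str =
      absl::StrCat(stream_pointers, " Called Stream::", function_name, "(");
  const char* separator = "";
  for (const auto& param : params) {
    absl::StrAppend(&str, separator, param.first, "=", param.second);
    separator = ", ";
  }
  absl::StrAppend(&str, ")");
  return str;
}
// e.g. CallStrSketch("ThenMemcpy", "[stream=0x1,impl=0x2]", {{"size", "16"}})
//      returns "[stream=0x1,impl=0x2] Called Stream::ThenMemcpy(size=16)".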
- virtual bool GetSymbol(const string &symbol_name, ModuleHandle module_handle, - void **mem, size_t *bytes) { + virtual bool GetSymbol(const std::string &symbol_name, + ModuleHandle module_handle, void **mem, + size_t *bytes) { return false; } diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index 3b7f2b9760e..dc777c7a2d4 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -47,7 +47,7 @@ bool FLAGS_check_device_leaks = false; namespace stream_executor { namespace { -string StackTraceIfVLOG10() { +std::string StackTraceIfVLOG10() { if (VLOG_IS_ON(10)) { return absl::StrCat(" ", port::CurrentStackTrace(), "\n"); } else { @@ -149,7 +149,7 @@ StreamExecutor::StreamExecutor( mem_alloc_bytes_(0), memory_limit_bytes_(GetMemoryLimitBytes()), allocator_(this) { - string name = absl::AsciiStrToLower(platform_->Name()); + std::string name = absl::AsciiStrToLower(platform_->Name()); if (name == "cuda") { platform_kind_ = PlatformKind::kCuda; } else if (name == "rocm") { @@ -239,7 +239,7 @@ port::Status StreamExecutor::SetDeviceSharedMemoryConfig( if (config != SharedMemoryConfig::kDefault && config != SharedMemoryConfig::kFourByte && config != SharedMemoryConfig::kEightByte) { - string error_msg = absl::StrFormat( + std::string error_msg = absl::StrFormat( "Invalid shared memory config specified: %d", static_cast<int>(config)); LOG(ERROR) << error_msg; return port::Status(port::error::INVALID_ARGUMENT, error_msg); @@ -492,7 +492,7 @@ DeviceMemoryBase StreamExecutor::Allocate(uint64 size, int64 memory_space) { } port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol( - const string &symbol_name, ModuleHandle module_handle) { + const std::string &symbol_name, ModuleHandle module_handle) { // If failed to get the symbol, opaque/bytes are unchanged. Initialize them to // be nullptr/0 for consistency with DeviceMemory semantics. void *opaque = nullptr; @@ -515,7 +515,7 @@ port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol( } } -bool StreamExecutor::GetSymbol(const string &symbol_name, +bool StreamExecutor::GetSymbol(const std::string &symbol_name, ModuleHandle module_handle, void **mem, size_t *bytes) { return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes); diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index 9fd78c54ebb..391ae52ebfd 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -50,7 +50,7 @@ struct AllocRecord { // Holds a representation of the stack at the time the associated buffer was // allocated. Produced in a form described in // //util/symbolize/symbolized_stacktrace.h. - string stack_trace; + std::string stack_trace; }; // Forward declaration of private friend class. @@ -175,12 +175,12 @@ class StreamExecutor { // If `module_handle` is set then searches only within the module // corresponding to `module_handle`. template <typename T> - port::StatusOr<DeviceMemory<T>> GetSymbol(const string &symbol_name, + port::StatusOr<DeviceMemory<T>> GetSymbol(const std::string &symbol_name, ModuleHandle module_handle = {}); // An untyped version of GetSymbol. 
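The typed GetSymbol<T> declared above wraps the untyped lookup that follows; when no ModuleHandle is supplied, all loaded modules are searched, as the comment earlier notes. A hedged usage sketch, assuming an already-created StreamExecutor and a placeholder symbol name; illustrative only, not part of the patch.

#include <string>

#include "tensorflow/stream_executor/stream_executor_pimpl.h"

// `executor` is assumed to be a valid StreamExecutor*; "my_device_constant"
// is a placeholder name for a symbol compiled into a loaded module.
void LookUpSymbol(stream_executor::StreamExecutor* executor) {
  auto symbol_or =
      executor->GetSymbol<float>(std::string("my_device_constant"));
  if (symbol_or.ok()) {
    stream_executor::DeviceMemory<float> mem = symbol_or.ValueOrDie();
    (void)mem;  // Device memory backing the symbol.
  }
}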
port::StatusOr<DeviceMemoryBase> GetUntypedSymbol( - const string &symbol_name, ModuleHandle module_handle = {}); + const std::string &symbol_name, ModuleHandle module_handle = {}); // Deallocate the DeviceMemory previously allocated via this interface. // Deallocation of a nullptr-representative value is permitted. @@ -554,7 +554,7 @@ class StreamExecutor { // Finds and retrieves device memory for the symbol on the underlying // platform. - bool GetSymbol(const string &symbol_name, ModuleHandle module_handle, + bool GetSymbol(const std::string &symbol_name, ModuleHandle module_handle, void **mem, size_t *bytes); // Entrains a memcpy operation onto stream, with a host destination location @@ -805,7 +805,7 @@ inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64 element_count, template <typename T> inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol( - const string &symbol_name, ModuleHandle module_handle) { + const std::string &symbol_name, ModuleHandle module_handle) { port::StatusOr<DeviceMemoryBase> untyped_symbol = GetUntypedSymbol(symbol_name, module_handle); if (!untyped_symbol.ok()) { diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 5671fdba5ac..22709278d0c 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -969,7 +969,7 @@ def tf_gen_op_wrapper_py( name = name + "_pygenrule", outs = [out], srcs = api_def_srcs + [hidden_file], - exec_tools = [tool_name] + tf_binary_additional_srcs(), + tools = [tool_name] + tf_binary_additional_srcs(), cmd = ("$(location " + tool_name + ") " + api_def_args_str + " @$(location " + hidden_file + ") > $@"), ) @@ -978,7 +978,7 @@ def tf_gen_op_wrapper_py( name = name + "_pygenrule", outs = [out], srcs = api_def_srcs, - exec_tools = [tool_name] + tf_binary_additional_srcs(), + tools = [tool_name] + tf_binary_additional_srcs(), cmd = ("$(location " + tool_name + ") " + api_def_args_str + " " + op_list_arg + " " + ("1" if op_list_is_whitelist else "0") + " > $@"), @@ -2691,6 +2691,9 @@ def if_cuda_or_rocm(if_true, if_false = []): "//conditions:default": if_false, }) +def tf_monitoring_deps(): + return [] + def tf_jit_compilation_passes_extra_deps(): return [] diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-cross-device-ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-cross-device-ops.pbtxt index a2ea2343241..0a8e0b4421a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-cross-device-ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-cross-device-ops.pbtxt @@ -8,11 +8,11 @@ tf_class { } member_method { name: "batch_reduce" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "batch_reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" } member_method { name: "broadcast" @@ -24,10 +24,10 @@ tf_class { } member_method { name: "reduce" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], 
varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt index a38c4b21d56..b5ccb39d075 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt @@ -10,11 +10,11 @@ tf_class { } member_method { name: "batch_reduce" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "batch_reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" } member_method { name: "broadcast" @@ -26,10 +26,10 @@ tf_class { } member_method { name: "reduce" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-nccl-all-reduce.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-nccl-all-reduce.pbtxt index bdc09bcd84b..1a039b10501 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-nccl-all-reduce.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-nccl-all-reduce.pbtxt @@ -10,11 +10,11 @@ tf_class { } member_method { name: "batch_reduce" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "batch_reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" } member_method { name: "broadcast" @@ -26,10 +26,10 @@ tf_class { } member_method { name: "reduce" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, 
keywords=None, defaults=[\'None\'], " } member_method { name: "reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduction-to-one-device.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduction-to-one-device.pbtxt index f5ade9f86ba..7876166dc40 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduction-to-one-device.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduction-to-one-device.pbtxt @@ -9,11 +9,11 @@ tf_class { } member_method { name: "batch_reduce" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "batch_reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" } member_method { name: "broadcast" @@ -25,10 +25,10 @@ tf_class { } member_method { name: "reduce" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt index c3b79911757..0101212e4cc 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt @@ -24,7 +24,7 @@ tf_class { } member_method { name: "all_reduce" - argspec: "args=[\'self\', \'reduce_op\', \'value\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "merge_call" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt index 8d5f6daff47..b2d9d4ee2cb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt @@ -37,7 +37,7 @@ tf_class { } member_method { name: "batch_reduce_to" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, 
defaults=[\'None\'], " } member_method { name: "broadcast_to" @@ -69,7 +69,7 @@ tf_class { } member_method { name: "reduce_to" - argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "update" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-collective-hints.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-collective-hints.pbtxt new file mode 100644 index 00000000000..c010134466c --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-collective-hints.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.distribute.experimental.CollectiveHints" +tf_class { + is_instance: "<class \'tensorflow.python.distribute.collective_util.Hints\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'bytes_per_pack\'], varargs=None, keywords=None, defaults=[\'0\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.pbtxt index 879cceec3ac..9247db37925 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.pbtxt @@ -8,6 +8,10 @@ tf_module { name: "CollectiveCommunication" mtype: "<class \'enum.EnumMeta\'>" } + member { + name: "CollectiveHints" + mtype: "<type \'type\'>" + } member { name: "MultiWorkerMirroredStrategy" mtype: "<type \'type\'>" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-cross-device-ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-cross-device-ops.pbtxt index a2ea2343241..0a8e0b4421a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-cross-device-ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-cross-device-ops.pbtxt @@ -8,11 +8,11 @@ tf_class { } member_method { name: "batch_reduce" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "batch_reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" } member_method { name: "broadcast" @@ -24,10 +24,10 @@ tf_class { } member_method { name: "reduce" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" } } diff --git 
a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt index a38c4b21d56..b5ccb39d075 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt @@ -10,11 +10,11 @@ tf_class { } member_method { name: "batch_reduce" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "batch_reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" } member_method { name: "broadcast" @@ -26,10 +26,10 @@ tf_class { } member_method { name: "reduce" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-nccl-all-reduce.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-nccl-all-reduce.pbtxt index bdc09bcd84b..1a039b10501 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-nccl-all-reduce.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-nccl-all-reduce.pbtxt @@ -10,11 +10,11 @@ tf_class { } member_method { name: "batch_reduce" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "batch_reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" } member_method { name: "broadcast" @@ -26,10 +26,10 @@ tf_class { } member_method { name: "reduce" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" } } diff --git 
a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduction-to-one-device.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduction-to-one-device.pbtxt index f5ade9f86ba..7876166dc40 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduction-to-one-device.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduction-to-one-device.pbtxt @@ -9,11 +9,11 @@ tf_class { } member_method { name: "batch_reduce" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "batch_reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" } member_method { name: "broadcast" @@ -25,10 +25,10 @@ tf_class { } member_method { name: "reduce" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "reduce_implementation" - argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None" } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt index c3b79911757..0101212e4cc 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt @@ -24,7 +24,7 @@ tf_class { } member_method { name: "all_reduce" - argspec: "args=[\'self\', \'reduce_op\', \'value\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "merge_call" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt index c5e87b33070..f3fa80427a4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt @@ -20,7 +20,7 @@ tf_class { } member_method { name: "batch_reduce_to" - argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "colocate_vars_with" @@ -32,7 +32,7 @@ tf_class { } member_method { name: "reduce_to" - argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], " } 
member_method { name: "update" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-collective-hints.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-collective-hints.pbtxt new file mode 100644 index 00000000000..c010134466c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-collective-hints.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.distribute.experimental.CollectiveHints" +tf_class { + is_instance: "<class \'tensorflow.python.distribute.collective_util.Hints\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'bytes_per_pack\'], varargs=None, keywords=None, defaults=[\'0\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.pbtxt index 879cceec3ac..9247db37925 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.pbtxt @@ -8,6 +8,10 @@ tf_module { name: "CollectiveCommunication" mtype: "<class \'enum.EnumMeta\'>" } + member { + name: "CollectiveHints" + mtype: "<type \'type\'>" + } member { name: "MultiWorkerMirroredStrategy" mtype: "<type \'type\'>" diff --git a/tensorflow/tools/ci_build/release/common.sh b/tensorflow/tools/ci_build/release/common.sh index dcac9688bf0..9e4a59ccab3 100644 --- a/tensorflow/tools/ci_build/release/common.sh +++ b/tensorflow/tools/ci_build/release/common.sh @@ -103,25 +103,6 @@ function update_bazel_linux { # LINT.ThenChange( # //tensorflow_estimator/google/kokoro/common.sh) -# Install the given bazel version on macos -function update_bazel_macos { - if [[ -z "$1" ]]; then - BAZEL_VERSION=${LATEST_BAZEL_VERSION} - else - BAZEL_VERSION=$1 - fi - BAZEL_COMMAND="curl -L https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-darwin-x86_64.sh -O && \ - chmod +x bazel-*.sh && ./bazel-${BAZEL_VERSION}-installer-darwin-x86_64.sh --user && \ - rm -f bazel-${BAZEL_VERSION}-installer-darwin-x86_64.sh" - # If the bazel update fails retry again in 60 seconds. 
-  run_with_retry "${BAZEL_COMMAND}"
-  # Add new bazel installation to path
-  PATH="/Users/kbuilder/bin:$PATH"
-  set_bazel_outdir
-  which bazel
-  bazel version
-}
-
 function install_pip2 {
   curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
   sudo python2 get-pip.py
@@ -153,6 +134,7 @@ function install_pip_deps {
   ${SUDO_CMD} ${PIP_CMD} install astunparse==1.6.3
   ${SUDO_CMD} ${PIP_CMD} install keras_preprocessing==1.1.0 --no-deps
   "${PIP_CMD}" install numpy==1.16.0 --user
+  "${PIP_CMD}" install PyYAML==3.12 --user
   ${SUDO_CMD} ${PIP_CMD} install gast==0.3.3
   ${SUDO_CMD} ${PIP_CMD} install h5py==2.10.0
   ${SUDO_CMD} ${PIP_CMD} install six==1.12.0
@@ -193,6 +175,7 @@ function install_ubuntu_16_pip_deps {
   "${PIP_CMD}" install portpicker --user
   "${PIP_CMD}" install scipy --user
   "${PIP_CMD}" install scikit-learn --user
+  "${PIP_CMD}" install PyYAML==3.12 --user
   "${PIP_CMD}" install --user --upgrade tf-estimator-nightly
   "${PIP_CMD}" install --user --upgrade tb-nightly
   # LINT.ThenChange(:ubuntu_pip_installations)
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_safety.py b/tensorflow/tools/compatibility/tf_upgrade_v2_safety.py
index ec2914f9638..7a53fd87169 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2_safety.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2_safety.py
@@ -47,19 +47,12 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
     # List module renames. If changed, please update max_submodule_depth.
     self.import_renames = {
-        "tensorflow.google":
-            ast_edits.ImportRename(
-                "tensorflow.google.compat.v1",
-                excluded_prefixes=[
-                    "tensorflow.google.compat.v1",
-                    "tensorflow.google.compat.v2",
-                ],
-            ),
         "tensorflow":
             ast_edits.ImportRename(
                 "tensorflow.compat.v1",
                 excluded_prefixes=[
                     "tensorflow.contrib", "tensorflow.flags",
+                    "tensorflow.compat",
                     "tensorflow.compat.v1", "tensorflow.compat.v2",
                     "tensorflow.google"
                 ],
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_safety_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_safety_test.py
index b974a6963ae..eee946b0cb2 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2_safety_test.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2_safety_test.py
@@ -70,14 +70,12 @@ class TfUpgradeV2SafetyTest(test_util.TensorFlowTestCase):

   def testTensorFlowGoogleImport(self):
     text = "import tensorflow.google as tf"
-    expected_text = "import tensorflow.google.compat.v1 as tf"
     _, _, _, new_text = self._upgrade(text)
-    self.assertEqual(expected_text, new_text)
+    self.assertEqual(text, new_text)

     text = "import tensorflow.google"
-    expected_text = "import tensorflow.google.compat.v1"
     _, _, _, new_text = self._upgrade(text)
-    self.assertEqual(expected_text, new_text)
+    self.assertEqual(text, new_text)

     text = "import tensorflow.google.compat.v1 as tf"
     expected_text = "import tensorflow.google.compat.v1 as tf"
diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD
index d8a45098b78..5d349da84bf 100644
--- a/tensorflow/tools/docs/BUILD
+++ b/tensorflow/tools/docs/BUILD
@@ -34,6 +34,7 @@ py_test(
     deps = [
         ":tf_doctest_lib",
         "//tensorflow:tensorflow_py",
+        "//tensorflow/python/keras/preprocessing",
         "//third_party/py/numpy",
     ],
 )
diff --git a/tensorflow/tools/docs/tf_doctest.py b/tensorflow/tools/docs/tf_doctest.py
index 2dce63541f5..19624659e37 100644
--- a/tensorflow/tools/docs/tf_doctest.py
+++ b/tensorflow/tools/docs/tf_doctest.py
@@ -28,6 +28,7 @@ import numpy as np

 import tensorflow.compat.v2 as tf

+from tensorflow.python.keras import preprocessing
 from tensorflow.tools.docs import tf_doctest_lib

 # We put doctest after absltest so that it picks up the unittest monkeypatch.
@@ -36,6 +37,9 @@ import doctest  # pylint: disable=g-bad-import-order

 tf.compat.v1.enable_v2_behavior()

+# Inject keras.preprocessing files into `tf.keras.preprocessing` namespace.
+tf.keras.preprocessing = preprocessing
+
 FLAGS = flags.FLAGS

 flags.DEFINE_string('module', None, 'A specific module to run doctest on.')
diff --git a/tensorflow/tools/git/BUILD b/tensorflow/tools/git/BUILD
index 8a47f4c4c2d..c1f0577f33b 100644
--- a/tensorflow/tools/git/BUILD
+++ b/tensorflow/tools/git/BUILD
@@ -13,5 +13,4 @@ py_binary(
     srcs = ["gen_git_source.py"],
     python_version = "PY3",
     srcs_version = "PY2AND3",
-    deps = ["@six_archive//:six"],
 )
diff --git a/tensorflow/tools/git/gen_git_source.py b/tensorflow/tools/git/gen_git_source.py
index 011406e2288..0cb1d006142 100755
--- a/tensorflow/tools/git/gen_git_source.py
+++ b/tensorflow/tools/git/gen_git_source.py
@@ -35,8 +35,6 @@ import os
 import shutil
 import subprocess

-import six
-

 def parse_branch_ref(filename):
   """Given a filename of a .git/HEAD file return ref path.
@@ -169,8 +167,8 @@ def get_git_version(git_base_path, git_tag_override):
         subprocess.check_output([
             "git",
             str("--git-dir=%s/.git" % git_base_path),
-            str("--work-tree=" + six.ensure_str(git_base_path)), "describe",
-            "--long", "--tags"
+            str("--work-tree=%s" % git_base_path), "describe", "--long",
+            "--tags"
         ]).strip())
     version_separator = b"-"
     if git_tag_override and val:
diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD
index 0a9e5d151e0..1479c0177e7 100644
--- a/tensorflow/tools/lib_package/BUILD
+++ b/tensorflow/tools/lib_package/BUILD
@@ -161,7 +161,6 @@ genrule(
         "@nasm//:LICENSE",
         "@nsync//:LICENSE",
         "@png//:LICENSE",
-        "@six_archive//:LICENSE",
         "@snappy//:COPYING",
         "@sobol_data//:LICENSE",
         "@zlib//:zlib.h",
@@ -244,7 +243,6 @@ genrule(
         "@nasm//:LICENSE",
         "@nsync//:LICENSE",
         "@png//:LICENSE",
-        "@six_archive//:LICENSE",
         "@snappy//:COPYING",
         "@sobol_data//:LICENSE",
         "@zlib//:zlib.h",
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 8b8248399d5..682d022d691 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -201,11 +201,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
         name = "eigen_archive",
         build_file = clean_dep("//third_party:eigen.BUILD"),
         patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"),
-        sha256 = "71905cca5553804beee85e9ab8b254931d3cbeda8df1a40e5af3773f5b657179",  # SHARED_EIGEN_SHA
-        strip_prefix = "eigen-3fda850c46e5e589668a85d89299433e0686eec9",
+        sha256 = "88e95180a7eae9acd3e79d2efeea1026eefad9f515a44418b63b189a1887108c",  # SHARED_EIGEN_SHA
+        strip_prefix = "eigen-52a2fbbb008a47c5e3fb8ac1c65c2feecb0c511c",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/3fda850c46e5e589668a85d89299433e0686eec9/eigen-3fda850c46e5e589668a85d89299433e0686eec9.tar.gz",
-            "https://gitlab.com/libeigen/eigen/-/archive/3fda850c46e5e589668a85d89299433e0686eec9/eigen-3fda850c46e5e589668a85d89299433e0686eec9.tar.gz",
+            "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/52a2fbbb008a47c5e3fb8ac1c65c2feecb0c511c/eigen-52a2fbbb008a47c5e3fb8ac1c65c2feecb0c511c.tar.gz",
+            "https://gitlab.com/libeigen/eigen/-/archive/52a2fbbb008a47c5e3fb8ac1c65c2feecb0c511c/eigen-52a2fbbb008a47c5e3fb8ac1c65c2feecb0c511c.tar.gz",
         ],
     )

@@ -305,12 +305,12 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
     tf_http_archive(
         name = "org_sqlite",
         build_file = clean_dep("//third_party:sqlite.BUILD"),
-        sha256 = "adf051d4c10781ea5cfabbbc4a2577b6ceca68590d23b58b8260a8e24cc5f081",
-        strip_prefix = "sqlite-amalgamation-3300100",
+        sha256 = "f3c79bc9f4162d0b06fa9fe09ee6ccd23bb99ce310b792c5145f87fbcc30efca",
+        strip_prefix = "sqlite-amalgamation-3310100",
         system_build_file = clean_dep("//third_party/systemlibs:sqlite.BUILD"),
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/www.sqlite.org/2019/sqlite-amalgamation-3300100.zip",
-            "https://www.sqlite.org/2019/sqlite-amalgamation-3300100.zip",
+            "https://storage.googleapis.com/mirror.tensorflow.org/www.sqlite.org/2020/sqlite-amalgamation-3310100.zip",
+            "https://www.sqlite.org/2020/sqlite-amalgamation-3310100.zip",
         ],
     )

@@ -576,12 +576,12 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):

     tf_http_archive(
         name = "com_github_nanopb_nanopb",
-        sha256 = "8bbbb1e78d4ddb0a1919276924ab10d11b631df48b657d960e0c795a25515735",
+        sha256 = "18234d9f01b57248472a9bfa65c3379352b5d66c15b0ef1c2b4feece4b5670fe",
         build_file = "@com_github_grpc_grpc//third_party:nanopb.BUILD",
-        strip_prefix = "nanopb-f8ac463766281625ad710900479130c7fcb4d63b",
+        strip_prefix = "nanopb-0.4.1",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/nanopb/nanopb/archive/f8ac463766281625ad710900479130c7fcb4d63b.tar.gz",
-            "https://github.com/nanopb/nanopb/archive/f8ac463766281625ad710900479130c7fcb4d63b.tar.gz",
+            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/nanopb/nanopb/archive/0.4.1.tar.gz",
+            "https://github.com/nanopb/nanopb/archive/0.4.1.tar.gz",
         ],
     )

@@ -597,8 +597,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
     )

     # Check out LLVM and MLIR from llvm-project.
-    LLVM_COMMIT = "78be61871704a451a5d9462d7e96ed6c982746d4"
-    LLVM_SHA256 = "35ce5950da1c83c91ab10ea5ac6e88c46fca160044cfd9dc6e50c83879d963ef"
+    LLVM_COMMIT = "fee41517fe0f7ff9f0e204dd9200ebf32ca03cb8"
+    LLVM_SHA256 = "dceb84396e8c30348dbd426c53eeae6657f5c67a24830c9a610a037fffcbe5cf"
     LLVM_URLS = [
         "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
         "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
diff --git a/third_party/hexagon/workspace.bzl b/third_party/hexagon/workspace.bzl
index 79529f3fb9c..d60fcaf82ac 100644
--- a/third_party/hexagon/workspace.bzl
+++ b/third_party/hexagon/workspace.bzl
@@ -5,9 +5,9 @@ load("//third_party:repo.bzl", "third_party_http_archive")
 def repo():
     third_party_http_archive(
         name = "hexagon_nn",
-        sha256 = "4cbf3c18834e24b1f64cc507f9c2f22b4fe576c6ff938d55faced5d8f1bddf62",
+        sha256 = "281d46b47f7191f03a8a4071c4c8d2af9409bb9d59573dc2e42f04c4fd61f1fd",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/storage.cloud.google.com/download.tensorflow.org/tflite/hexagon_nn_headers_v1.10.3.1.2.tgz",
+            "https://storage.googleapis.com/mirror.tensorflow.org/storage.cloud.google.com/download.tensorflow.org/tflite/hexagon_nn_headers_v1.10.3.1.3.tgz",
         ],
         build_file = "//third_party/hexagon:BUILD",
     )
diff --git a/third_party/llvm/llvm.autogenerated.BUILD b/third_party/llvm/llvm.autogenerated.BUILD
index f8bc4cda891..a89838ebac9 100644
--- a/third_party/llvm/llvm.autogenerated.BUILD
+++ b/third_party/llvm/llvm.autogenerated.BUILD
@@ -575,7 +575,8 @@ gentbl(
     name = "amdgpu_isel_target_gen",
     tbl_outs = [
         ("-gen-global-isel", "lib/Target/AMDGPU/AMDGPUGenGlobalISel.inc"),
-        ("-gen-global-isel-combiner -combiners=AMDGPUPreLegalizerCombinerHelper", "lib/Target/AMDGPU/AMDGPUGenGICombiner.inc"),
+        ("-gen-global-isel-combiner -combiners=AMDGPUPreLegalizerCombinerHelper", "lib/Target/AMDGPU/AMDGPUGenPreLegalizeGICombiner.inc"),
+        ("-gen-global-isel-combiner -combiners=AMDGPUPostLegalizerCombinerHelper", "lib/Target/AMDGPU/AMDGPUGenPostLegalizeGICombiner.inc"),
     ],
     tblgen = ":llvm-tblgen",
     td_file = "lib/Target/AMDGPU/AMDGPUGISel.td",
diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD
index cd213f4046a..5729afac05d 100644
--- a/third_party/mlir/BUILD
+++ b/third_party/mlir/BUILD
@@ -199,15 +199,17 @@ gentbl(

 cc_library(
     name = "LoopOpsTransforms",
-    srcs = ["lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp"],
+    srcs = glob(["lib/Dialect/LoopOps/Transforms/*.cpp"]),
     hdrs = ["include/mlir/Dialect/LoopOps/Passes.h"],
     includes = ["include"],
     deps = [
+        ":AffineOps",
         ":IR",
         ":LoopOps",
         ":Pass",
         ":StandardOps",
         ":Transforms",
+        "@llvm-project//llvm:support",
     ],
     alwayslink = 1,
 )
@@ -216,7 +218,7 @@ filegroup(
     name = "StdOpsTdFiles",
     srcs = [
         "include/mlir/Analysis/CallInterfaces.td",
-        "include/mlir/Dialect/StandardOps/Ops.td",
+        "include/mlir/Dialect/StandardOps/IR/Ops.td",
         "include/mlir/IR/OpAsmInterface.td",
         ":OpBaseTdFiles",
     ],
@@ -228,23 +230,23 @@ gentbl(
     tbl_outs = [
         (
             "-gen-op-decls",
-            "include/mlir/Dialect/StandardOps/Ops.h.inc",
+            "include/mlir/Dialect/StandardOps/IR/Ops.h.inc",
         ),
         (
             "-gen-op-defs",
-            "include/mlir/Dialect/StandardOps/Ops.cpp.inc",
+            "include/mlir/Dialect/StandardOps/IR/Ops.cpp.inc",
         ),
         (
             "-gen-enum-decls",
-            "include/mlir/Dialect/StandardOps/OpsEnums.h.inc",
+            "include/mlir/Dialect/StandardOps/IR/OpsEnums.h.inc",
         ),
         (
             "-gen-enum-defs",
-            "include/mlir/Dialect/StandardOps/OpsEnums.cpp.inc",
+            "include/mlir/Dialect/StandardOps/IR/OpsEnums.cpp.inc",
         ),
     ],
     tblgen = ":mlir-tblgen",
-    td_file = "include/mlir/Dialect/StandardOps/Ops.td",
+    td_file = "include/mlir/Dialect/StandardOps/IR/Ops.td",
     td_srcs = [
         ":StdOpsTdFiles",
     ],
@@ -384,13 +386,13 @@ cc_library(
     name = "StandardOps",
     srcs = glob(
         [
-            "lib/Dialect/StandardOps/*.cpp",
-            "lib/Dialect/StandardOps/*.h",
+            "lib/Dialect/StandardOps/IR/*.cpp",
+            "lib/Dialect/StandardOps/IR/*.h",
             "lib/Dialect/StandardOps/EDSC/*.cpp",
         ],
     ),
     hdrs = glob([
-        "include/mlir/Dialect/StandardOps/*.h",
+        "include/mlir/Dialect/StandardOps/IR/*.h",
         "include/mlir/Dialect/StandardOps/EDSC/*.h",
     ]) + [
         "include/mlir/Analysis/CallInterfaces.h",
@@ -1371,11 +1373,13 @@ cc_library(
     includes = ["include"],
     deps = [
         ":AffineOps",
+        ":GPUDialect",
         ":LoopOps",
         ":LoopsToGPU",
         ":Pass",
         ":StandardOps",
         ":Support",
+        ":Transforms",
         "@llvm-project//llvm:support",
     ],
 )
@@ -1693,6 +1697,7 @@ cc_library(
         "@llvm-project//mlir/test:TestDialect",
         "@llvm-project//mlir/test:TestIR",
         "@llvm-project//mlir/test:TestPass",
+        "@llvm-project//mlir/test:TestSPIRV",
         "@llvm-project//mlir/test:TestTransforms",
     ],
 )
@@ -1820,6 +1825,7 @@ cc_binary(
         "@llvm-project//mlir/test:TestDialect",
         "@llvm-project//mlir/test:TestIR",
         "@llvm-project//mlir/test:TestPass",
+        "@llvm-project//mlir/test:TestSPIRV",
         "@llvm-project//mlir/test:TestTransforms",
     ],
 )
@@ -2519,7 +2525,7 @@ exports_files(
         "include/mlir/Analysis/CallInterfaces.h",
         "include/mlir/Analysis/CallInterfaces.td",
         "include/mlir/Dialect/LLVMIR/LLVMOpBase.td",
-        "include/mlir/Dialect/StandardOps/Ops.td",
+        "include/mlir/Dialect/StandardOps/IR/Ops.td",
         "include/mlir/IR/OpAsmInterface.td",
         "include/mlir/IR/OpBase.td",
         "include/mlir/Transforms/InliningUtils.h",
diff --git a/third_party/mlir/test.BUILD b/third_party/mlir/test.BUILD
index 1e89b553ac4..657d011254b 100644
--- a/third_party/mlir/test.BUILD
+++ b/third_party/mlir/test.BUILD
@@ -166,3 +166,16 @@ cc_library(
         "@llvm-project//mlir:VectorToLoops",
     ],
 )
+
+cc_library(
+    name = "TestSPIRV",
+    srcs = glob([
+        "lib/Dialect/SPIRV/*.cpp",
+    ]),
+    deps = [
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SPIRVDialect",
+        "@llvm-project//mlir:SPIRVLowering",
+    ],
+)
diff --git a/third_party/snappy.BUILD b/third_party/snappy.BUILD
index d93f0307690..a2ab4924f29 100644
--- a/third_party/snappy.BUILD
+++ b/third_party/snappy.BUILD
@@ -27,6 +27,10 @@ cc_library(
             "-Wno-implicit-function-declaration",
         ],
     }),
+    defines = select({
+        "@org_tensorflow//tensorflow:windows": [],
+        "//conditions:default": ["HAVE_SYS_UIO_H"],
+    }),
 )

 genrule(
diff --git a/third_party/toolchains/remote_config/configs.bzl b/third_party/toolchains/remote_config/configs.bzl
index 4945db280b6..973efb40af1 100644
--- a/third_party/toolchains/remote_config/configs.bzl
+++ b/third_party/toolchains/remote_config/configs.bzl
@@ -22,6 +22,18 @@ def initialize_rbe_configs():
         tensorrt_version = "5.1",
     )

+    tensorflow_rbe_config(
+        name = "ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0",
+        compiler = "/dt7/usr/bin/gcc",
+        compiler_prefix = "/usr/bin",
+        cuda_version = "10.1",
+        cudnn_version = "7",
+        os = "ubuntu16.04-manylinux2010",
+        python_version = "3",
+        tensorrt_install_path = "/usr",
+        tensorrt_version = "6.0",
+    )
+
     tensorflow_rbe_config(
         name = "ubuntu16.04-py3_opt-gcc5-rocm",
         compiler = "gcc",
diff --git a/third_party/toolchains/remote_config/containers.bzl b/third_party/toolchains/remote_config/containers.bzl
index 27e0cceee1c..26cb3ea3367 100644
--- a/third_party/toolchains/remote_config/containers.bzl
+++ b/third_party/toolchains/remote_config/containers.bzl
@@ -17,6 +17,13 @@ containers = {
         "digest": container_digests["cuda10.0-cudnn7-ubuntu16.04-manylinux2010"],
     },

+    # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010.
+    "cuda10.1-cudnn7-ubuntu16.04-manylinux2010": {
+        "registry": "gcr.io",
+        "repository": "tensorflow-testing/nosla-cuda10.1-cudnn7-ubuntu16.04-manylinux2010",
+        "digest": container_digests["cuda10.1-cudnn7-ubuntu16.04-manylinux2010"],
+    },
+
     # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.rocm-ubuntu16.04
     "rocm-ubuntu16.04": {
         "registry": "gcr.io",