Merge branch 'master' of https://github.com/tensorflow/tensorflow
commit 8a13ac7b5e

.bazelrc (2 changes)
@@ -323,8 +323,6 @@ build:windows --copt=/experimental:preprocessor
build:windows --host_copt=/experimental:preprocessor

# Misc build options we need for windows.
build:windows --linkopt=/DEBUG
build:windows --host_linkopt=/DEBUG
build:windows --linkopt=/OPT:REF
build:windows --host_linkopt=/OPT:REF
build:windows --linkopt=/OPT:ICF
@@ -206,6 +206,9 @@
    `fit()`. Running multiple batches inside a single `tf.function` call can
    greatly improve performance on TPUs or small models with a large Python
    overhead.
*   Improvements to Keras preprocessing layers:
    *   TextVectorization can now accept a vocabulary list or file as an
        init arg.
*   `tf.function` / AutoGraph:

    *   Added `experimental_follow_type_hints` argument for `tf.function`. When
@ -3,7 +3,7 @@
|
||||
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"if_tpu",
|
||||
"if_libtpu",
|
||||
"tf_cc_test",
|
||||
"tf_copts",
|
||||
"tf_cuda_cc_test",
|
||||
@ -289,7 +289,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
] + if_tpu(
|
||||
] + if_libtpu(
|
||||
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
|
||||
if_true = [],
|
||||
),
|
||||
@ -354,7 +354,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
] + if_tpu(
|
||||
] + if_libtpu(
|
||||
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
|
||||
if_true = [],
|
||||
),
|
||||
|
@@ -39,7 +39,7 @@ limitations under the License.
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_tensor_internal.h"
#if defined(PLATFORM_GOOGLE) && !defined(LIBTFTPU)
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
#endif
#include "tensorflow/core/common_runtime/device.h"
@@ -729,7 +729,7 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }

TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
  if (opts->use_tfrt) {
#if defined(PLATFORM_GOOGLE) && !defined(LIBTFTPU)
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
    return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
#else
    status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
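The hunk above only renames the guard macro (LIBTFTPU to LIBTPU_ON_GCE) around the TFRT path; the public behavior of `TFE_NewContext` is unchanged. As a rough illustration of the call site this guard protects, the following sketch creates and destroys an eager context through the C API and checks for the Unimplemented status that is reported when TFRT is requested but not compiled in. This is a minimal sketch, not part of the commit; the error handling is assumed boilerplate.

```cpp
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_status.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, /*enable=*/0);

  // If the options request TFRT on a build without it, TFE_NewContext
  // reports an Unimplemented error through `status`.
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);
  if (TF_GetCode(status) != TF_OK) {
    // e.g. "TFRT is not supported": fall back or abort here.
    TF_DeleteStatus(status);
    return 1;
  }

  // ... use the eager context ...

  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
  return 0;
}
```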
@ -42,13 +42,15 @@ cc_library(
|
||||
name = "reader",
|
||||
srcs = ["reader.cc"],
|
||||
hdrs = ["reader.h"],
|
||||
deps = [":constants"] + if_not_mobile([
|
||||
deps = [
|
||||
":constants",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
] + if_not_mobile([
|
||||
# TODO(b/111634734): :lib and :protos_all contain dependencies that
|
||||
# cannot be built on mobile platforms. Instead, include the appropriate
|
||||
# tf_lib depending on the build platform.
|
||||
"@com_google_absl//absl/memory:memory",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
]),
|
||||
)
|
||||
|
||||
|
@ -4,7 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_test")
|
||||
|
||||
# buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "if_tpu", "tf_copts")
|
||||
load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts")
|
||||
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
|
||||
|
||||
# buildifier: disable=same-origin-load
|
||||
@ -77,7 +77,7 @@ cc_library(
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
] + if_tpu(
|
||||
] + if_libtpu(
|
||||
if_false = ["//tensorflow/compiler/xla/service:cpu_plugin"],
|
||||
if_true = [],
|
||||
),
|
||||
@ -114,7 +114,7 @@ cc_library(
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:lib",
|
||||
] + if_tpu(
|
||||
] + if_libtpu(
|
||||
if_false = [
|
||||
"//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep
|
||||
],
|
||||
@ -141,7 +141,7 @@ cc_library(
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/common_runtime/gpu:gpu_init",
|
||||
] + if_tpu(
|
||||
] + if_libtpu(
|
||||
if_false = [
|
||||
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
|
||||
],
|
||||
@ -375,7 +375,7 @@ cc_library(
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:logging",
|
||||
] + if_tpu(
|
||||
] + if_libtpu(
|
||||
if_false = [
|
||||
"//tensorflow/compiler/mlir:array_container_utils",
|
||||
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
|
||||
|
@ -47,7 +47,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/public/version.h"
|
||||
#include "tensorflow/core/util/dump_graph.h"
|
||||
|
||||
#if !defined(LIBTFTPU)
|
||||
#if !defined(LIBTPU_ON_GCE)
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
||||
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
|
||||
#endif
|
||||
@ -289,7 +289,7 @@ Status XlaCompilationCache::CompileSingleOp(
|
||||
});
|
||||
const ConfigProto* config = ctx->function_library()->config_proto();
|
||||
bool use_mlir = config && config->experimental().enable_mlir_bridge();
|
||||
#ifdef LIBTFTPU
|
||||
#ifdef LIBTPU_ON_GCE
|
||||
if (use_mlir && has_tensor_list_arg) {
|
||||
LOG(WARNING) << "MLIR is not supported in this environment.";
|
||||
}
|
||||
|
@ -157,6 +157,9 @@ def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
|
||||
>];
|
||||
}
|
||||
|
||||
def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt",
|
||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CbrtOp;
|
||||
|
||||
def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil",
|
||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CeilOp;
|
||||
|
||||
@ -1423,4 +1426,21 @@ def HLO_FusionOp : HLO_Op<"fusion", []> {
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
// This is an op for purposes internal to XLA/GPU.
|
||||
def HLO_BitcastOp : HLO_Op<"bitcast", [NoSideEffect]>, BASE_HLO_BitcastOp {
|
||||
let arguments = (ins HLO_Tensor:$operand);
|
||||
let results = (outs HLO_Tensor);
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
def HLO_ReducePrecisionOp: HLO_Op<"reduce_precision", [SameOperandsAndResultShape]>,
|
||||
BASE_HLO_ReducePrecisionOp {
|
||||
let arguments = (ins
|
||||
HLO_FpTensor:$operand,
|
||||
I32Attr:$exponent_bits,
|
||||
I32Attr:$mantissa_bits
|
||||
);
|
||||
let results = (outs HLO_FpTensor:$output);
|
||||
}
|
||||
|
||||
#endif // HLO_OPS
|
||||
|
@@ -127,6 +127,17 @@ class BASE_HLO_AbsOp {
  }];
}

class BASE_HLO_CbrtOp {
  string summary = "Cubic root operator";

  string description = [{
    Returns element-wise cubic root of the operand.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

class BASE_HLO_CeilOp {
  string summary = "Ceil operator";

@@ -1336,4 +1347,17 @@ class BASE_HLO_WhileOp {
  }];
}

class BASE_HLO_BitcastOp {
  string summary = "Bitcast operator";

  string description = [{
    This op changes the shape of the input in a way that the physical
    arrangement of elements is unchanged.

    However, the op needs layout information to make sense of "physical
    arrangement of elements". Layout support in MHLO is currently under
    exploration.
  }];
}

#endif // HLO_OPS_BASE
@ -1193,3 +1193,24 @@ func @incompatible_shapes(%arg0: tensor<?xf32>, %shape: tensor<2xindex>) -> tens
|
||||
%0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cbrt(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
|
||||
%0 = "mhlo.cbrt"(%arg) : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
return %0 : tensor<2x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @bitcast(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
|
||||
%0 = "mhlo.bitcast"(%arg) : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
return %0 : tensor<2x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @bitcast(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
|
||||
%0 = "mhlo.reduce_precision"(%arg) {exponent_bits=2 : i32, mantissa_bits=3 : i32} : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
return %0 : tensor<2x4xf32>
|
||||
}
|
||||
|
@ -74,8 +74,8 @@ tool_names = [
|
||||
'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
|
||||
'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
|
||||
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-mlir-gpu-opt', 'xla-opt',
|
||||
'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_kernel', 'tf_to_gpu_binary',
|
||||
'xla-thunks-opt', 'tfjs-opt'
|
||||
'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_gpu_binary', 'xla-thunks-opt',
|
||||
'tfjs-opt'
|
||||
]
|
||||
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
|
||||
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
||||
|
@ -1,4 +1,5 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
|
||||
load(
|
||||
"//third_party/mlir:tblgen.bzl",
|
||||
"gentbl",
|
||||
@ -226,3 +227,23 @@ cc_library(
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "tfr_wrapper",
|
||||
srcs = ["python/tfr_wrapper.cc"],
|
||||
module_name = "tfr_wrapper",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tfr",
|
||||
"//tensorflow/python:pybind11_lib",
|
||||
"//tensorflow/python:pybind11_status",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Parser",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:Shape",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc (new file, 58 lines)
@@ -0,0 +1,58 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/AsmState.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Verifier.h" // from @llvm-project
|
||||
#include "mlir/Parser.h" // from @llvm-project
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
|
||||
PYBIND11_MODULE(tfr_wrapper, m) {
|
||||
m.def("verify", [](std::string input) {
|
||||
mlir::MLIRContext ctx(/*loadAllDialects=*/true);
|
||||
auto& registry = ctx.getDialectRegistry();
|
||||
registry.insert<mlir::scf::SCFDialect, mlir::TF::TensorFlowDialect,
|
||||
mlir::StandardOpsDialect, mlir::shape::ShapeDialect,
|
||||
mlir::TFR::TFRDialect>();
|
||||
ctx.getDialectRegistry().loadAll(&ctx);
|
||||
|
||||
llvm::SourceMgr source_mgr = llvm::SourceMgr();
|
||||
source_mgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
|
||||
llvm::SMLoc());
|
||||
auto module = mlir::parseSourceFile(source_mgr, &ctx);
|
||||
if (!module) {
|
||||
return false;
|
||||
}
|
||||
|
||||
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &ctx);
|
||||
if (failed(mlir::verify(*module))) {
|
||||
module->emitError("Invalid MLIR module: failed verification.");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
@ -105,10 +105,7 @@ tf_cc_binary(
|
||||
tf_cc_binary(
|
||||
name = "tf_to_kernel",
|
||||
srcs = ["tf_to_kernel.cc"],
|
||||
visibility = [
|
||||
"//tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel:__pkg__",
|
||||
"//tensorflow/core/kernels/mlir_generated:__pkg__",
|
||||
],
|
||||
visibility = ["//tensorflow/core/kernels/mlir_generated:__pkg__"],
|
||||
deps = [
|
||||
":kernel_creator",
|
||||
"//tensorflow/compiler/mlir:init_mlir",
|
||||
@ -162,7 +159,7 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "tf_cuda_runtime_wrappers",
|
||||
srcs = ["tf_cuda_runtime_wrappers.cpp"],
|
||||
srcs = ["tf_cuda_runtime_wrappers.cc"],
|
||||
compatible_with = get_compatible_with_cloud(),
|
||||
deps = [
|
||||
"//tensorflow/core/platform/default/build_config:stream_executor_cuda",
|
||||
|
@ -174,8 +174,7 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
|
||||
Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only,
|
||||
llvm::ArrayRef<uint32_t> same_shape,
|
||||
llvm::StringRef gpu_binary_attr_name,
|
||||
llvm::ArrayRef<uint32_t> architectures,
|
||||
bool generate_fatbin) {
|
||||
int32_t architecture) {
|
||||
mlir::PassManager pm(module.getContext());
|
||||
applyTensorflowAndCLOptions(pm);
|
||||
|
||||
@ -188,7 +187,7 @@ Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only,
|
||||
}
|
||||
kernel_pm.addPass(mlir::createStripDebugInfoPass());
|
||||
kernel_pm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToBlobPass(
|
||||
gpu_binary_attr_name, architectures, generate_fatbin));
|
||||
gpu_binary_attr_name, architecture));
|
||||
|
||||
if (!gpu_binary_only) {
|
||||
pm.addPass(mlir::kernel_gen::transforms::CreateTFKernelToLLVMPass());
|
||||
@ -203,9 +202,9 @@ Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only,
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
|
||||
mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only,
|
||||
llvm::ArrayRef<uint32_t> architectures, llvm::ArrayRef<uint32_t> tile_sizes,
|
||||
int32_t architecture, llvm::ArrayRef<uint32_t> tile_sizes,
|
||||
llvm::ArrayRef<uint32_t> same_shape,
|
||||
llvm::ArrayRef<uint32_t> unroll_factors, bool generate_fatbin) {
|
||||
llvm::ArrayRef<uint32_t> unroll_factors) {
|
||||
mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
|
||||
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -222,8 +221,7 @@ StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
|
||||
TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
|
||||
#endif
|
||||
TF_RETURN_IF_ERROR(LowerGPUToLLVM(module.get(), gpu_binary_only, same_shape,
|
||||
kGpuBinaryAttrName, architectures,
|
||||
generate_fatbin));
|
||||
kGpuBinaryAttrName, architecture));
|
||||
return module;
|
||||
}
|
||||
|
||||
|
@@ -38,10 +38,9 @@ namespace kernel_gen {
// false, lowers the host side to LLVM Dialect.
xla::StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
    mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only,
    llvm::ArrayRef<uint32_t> architectures = {75},
    llvm::ArrayRef<uint32_t> tile_sizes = {16, 64},
    int32_t architecture = 75, llvm::ArrayRef<uint32_t> tile_sizes = {16, 64},
    llvm::ArrayRef<uint32_t> same_shape = {},
    llvm::ArrayRef<uint32_t> unroll_factors = {}, bool generate_fatbin = true);
    llvm::ArrayRef<uint32_t> unroll_factors = {});

// Extracts gpu_binary from the converted module.
xla::StatusOr<std::string> ExtractGpuBinary(mlir::ModuleOp module);
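After this change, `GenerateKernelForTfCode` takes a single `int32_t architecture` instead of an architecture list and drops the `generate_fatbin` flag. Below is a minimal sketch of a caller using the new signature, modeled on how `tf_to_gpu_binary.cc` in this commit invokes it; the header path, helper name, and flag values are illustrative assumptions.

```cpp
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h"

// Sketch only: `tf_code` is the textual TF-dialect module (e.g. read from
// --input in tf_to_gpu_binary).
xla::StatusOr<std::string> CompileTanhKernel(llvm::StringRef tf_code) {
  mlir::MLIRContext context;
  auto module_or = tensorflow::kernel_gen::GenerateKernelForTfCode(
      context, tf_code, /*gpu_binary_only=*/true,
      /*architecture=*/70, /*tile_sizes=*/{256},
      /*same_shape=*/{0, 1}, /*unroll_factors=*/{4});
  if (!module_or.ok()) return module_or.status();
  // Pull the compiled device blob out of the annotated module.
  return tensorflow::kernel_gen::ExtractGpuBinary(*module_or.ValueOrDie());
}
```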
@ -1,5 +1,6 @@
|
||||
// RUN: tf_to_gpu_binary --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=70
|
||||
func @tanh(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%0 = "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tf.Tanh"(%arg0) { }
|
||||
: (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
@ -1,17 +0,0 @@
|
||||
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||
|
||||
package(licenses = ["notice"])
|
||||
|
||||
glob_lit_tests(
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_kernel",
|
||||
"@llvm-project//mlir:run_lit.sh",
|
||||
],
|
||||
default_tags = [
|
||||
# We need access to the CUDA SDK.
|
||||
"gpu",
|
||||
"no_rocm",
|
||||
],
|
||||
driver = "//tensorflow/compiler/mlir:run_lit.sh",
|
||||
test_file_exts = ["mlir"],
|
||||
)
|
@ -1,6 +0,0 @@
|
||||
// RUN: tf_to_kernel --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=70,75
|
||||
|
||||
func @tanh(%arg: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = "tf.Tanh"(%arg) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
@ -20,9 +20,9 @@ limitations under the License.
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/raw_ostream.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/ExecutionEngine/CRunnerUtils.h" // from @llvm-project
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
@ -48,7 +48,7 @@ xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
|
||||
mlir::OwningModuleRef module,
|
||||
GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/true,
|
||||
architecture, tile_sizes, same_shape,
|
||||
unroll_factors, /*generate_fatbin=*/false));
|
||||
unroll_factors));
|
||||
// Extract gpu_binary.
|
||||
TF_ASSIGN_OR_RETURN(std::string gpu_binary, ExtractGpuBinary(*module));
|
||||
|
||||
|
@ -95,8 +95,7 @@ xla::StatusOr<std::string> EmitToBinary(mlir::ModuleOp module) {
|
||||
}
|
||||
|
||||
xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
|
||||
llvm::ArrayRef<uint32_t> architectures,
|
||||
llvm::ArrayRef<uint32_t> tile_sizes,
|
||||
int32_t architecture, llvm::ArrayRef<uint32_t> tile_sizes,
|
||||
llvm::ArrayRef<uint32_t> same_shape,
|
||||
llvm::ArrayRef<uint32_t> unroll_factors) {
|
||||
// Read TF code.
|
||||
@ -108,7 +107,7 @@ xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
mlir::OwningModuleRef module,
|
||||
GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/false,
|
||||
architectures, tile_sizes, same_shape,
|
||||
architecture, tile_sizes, same_shape,
|
||||
unroll_factors));
|
||||
// Get binary.
|
||||
TF_ASSIGN_OR_RETURN(std::string binary, EmitToBinary(*module));
|
||||
@ -130,8 +129,8 @@ int main(int argc, char** argv) {
|
||||
llvm::cl::opt<std::string> output_file(
|
||||
"output", llvm::cl::desc("output file"), llvm::cl::value_desc("filename"),
|
||||
llvm::cl::init("foo.bin"));
|
||||
llvm::cl::list<uint32_t> architectures(
|
||||
"arch", llvm::cl::desc("target architectures (e.g. 50 for sm_50)"),
|
||||
llvm::cl::list<int32_t> architecture(
|
||||
"arch", llvm::cl::desc("target architecture (e.g. 50 for sm_50)"),
|
||||
llvm::cl::OneOrMore, llvm::cl::CommaSeparated);
|
||||
llvm::cl::list<uint32_t> tile_sizes(
|
||||
"tile_sizes", llvm::cl::desc("tile sizes to use"), llvm::cl::ZeroOrMore,
|
||||
@ -152,7 +151,7 @@ int main(int argc, char** argv) {
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n");
|
||||
|
||||
auto status =
|
||||
tensorflow::kernel_gen::Run(input_file, output_file, architectures,
|
||||
tensorflow::kernel_gen::Run(input_file, output_file, architecture.front(),
|
||||
tile_sizes, same_shape, unroll_factors);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << status;
|
||||
|
@ -117,7 +117,6 @@ cc_library(
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
"@llvm-project//llvm:TransformUtils",
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
"//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo",
|
||||
|
@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/Transforms/Utils/Cloning.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/Target/NVVMIR.h" // from @llvm-project
|
||||
#include "mlir/Target/ROCDLIR.h" // from @llvm-project
|
||||
@ -50,12 +49,9 @@ using xla::InternalError;
|
||||
class GpuKernelToBlobPass
|
||||
: public GpuKernelToBlobPassBase<GpuKernelToBlobPass> {
|
||||
public:
|
||||
GpuKernelToBlobPass(mlir::StringRef blob_annotation,
|
||||
llvm::ArrayRef<uint32_t> architectures,
|
||||
bool generate_fatbin) {
|
||||
GpuKernelToBlobPass(mlir::StringRef blob_annotation, int32_t arch) {
|
||||
blob_annotation_ = blob_annotation.str();
|
||||
architectures_ = architectures;
|
||||
generate_fatbin_ = generate_fatbin;
|
||||
arch_ = arch;
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
@ -73,17 +69,7 @@ class GpuKernelToBlobPass
|
||||
|
||||
xla::StatusOr<std::vector<uint8_t>> GetGpuBinaryBlob(
|
||||
mlir::gpu::GPUModuleOp gpu_module) {
|
||||
if (architectures_.empty()) {
|
||||
return InternalError("Expected at least one GPU architecture.");
|
||||
}
|
||||
if (!generate_fatbin_ && architectures_.size() > 1) {
|
||||
return InternalError(
|
||||
"Can only generate machine code for more than one architecture as a "
|
||||
"fatbin.");
|
||||
}
|
||||
|
||||
llvm::LLVMContext llvmContext;
|
||||
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
auto llvmModule = mlir::translateModuleToROCDLIR(gpu_module, llvmContext);
|
||||
if (!llvmModule) {
|
||||
@ -95,14 +81,9 @@ class GpuKernelToBlobPass
|
||||
xla::HloModuleConfig config;
|
||||
config.set_debug_options(xla::GetDebugOptionsFromFlags());
|
||||
|
||||
// TODO(b/169066682): Support fatbin on ROCm.
|
||||
if (generate_fatbin_) {
|
||||
return InternalError("Fatbins are not yet supported for ROCm.");
|
||||
}
|
||||
|
||||
uint32_t arch = architectures_.front();
|
||||
std::string libdevice_dir = tensorflow::RocdlRoot();
|
||||
return xla::gpu::amdgpu::CompileToHsaco(llvmModule.get(), arch, config,
|
||||
|
||||
return xla::gpu::amdgpu::CompileToHsaco(llvmModule.get(), arch_, config,
|
||||
libdevice_dir);
|
||||
|
||||
#elif GOOGLE_CUDA
|
||||
@ -121,42 +102,19 @@ class GpuKernelToBlobPass
|
||||
target->Options.AllowFPOpFusion = llvm::FPOpFusion::FPOpFusionMode::Fast;
|
||||
};
|
||||
|
||||
// Compile and collect requested cubin and PTX images.
|
||||
std::vector<tensorflow::se::CubinOrPTXImage> images;
|
||||
int32_t cc_major = arch_ / 10;
|
||||
int32_t cc_minor = arch_ % 10;
|
||||
TF_ASSIGN_OR_RETURN(std::string libdevice_dir, GetLibdeviceDir(config));
|
||||
auto gpu_asm_opts = xla::gpu::PtxOptsFromConfig(config);
|
||||
for (uint32_t arch : architectures_) {
|
||||
int32_t cc_major = arch / 10;
|
||||
int32_t cc_minor = arch % 10;
|
||||
// Module may be changed by CompileToPtx.
|
||||
auto llvmModuleCopy = llvm::CloneModule(*llvmModule);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::string ptx,
|
||||
xla::gpu::nvptx::CompileToPtx(llvmModuleCopy.get(),
|
||||
std::make_pair(cc_major, cc_minor),
|
||||
config, libdevice_dir, enable_fusion));
|
||||
// TODO(b/169066682): If compute_XX profile, collect PTX image here.
|
||||
VLOG(1) << ptx;
|
||||
TF_ASSIGN_OR_RETURN(std::vector<uint8_t> gpu_asm,
|
||||
tensorflow::se::CompileGpuAsm(
|
||||
cc_major, cc_minor, ptx.c_str(), gpu_asm_opts));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::string ptx,
|
||||
xla::gpu::nvptx::CompileToPtx(llvmModule.get(),
|
||||
std::make_pair(cc_major, cc_minor),
|
||||
config, libdevice_dir, enable_fusion));
|
||||
VLOG(1) << ptx;
|
||||
|
||||
if (!generate_fatbin_) {
|
||||
// Skip fatbin generation and return the first and only GPU machine
|
||||
// code.
|
||||
return gpu_asm;
|
||||
}
|
||||
|
||||
// Collect cubin image.
|
||||
images.push_back({absl::StrCat("sm_", arch), std::move(gpu_asm)});
|
||||
}
|
||||
|
||||
// TODO(b/169870789): Revisit the use of fatbins.
|
||||
// Bundle cubin and PTX images into a single fatbin.
|
||||
return tensorflow::se::BundleGpuAsm(images,
|
||||
gpu_asm_opts.preferred_cuda_dir);
|
||||
return tensorflow::se::CompileGpuAsm(cc_major, cc_minor, ptx.c_str(),
|
||||
xla::gpu::PtxOptsFromConfig(config));
|
||||
#endif
|
||||
|
||||
return InternalError(
|
||||
"Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined."
|
||||
" Did you specify either --config=rocm or --config=cuda ?");
|
||||
@ -183,10 +141,8 @@ class GpuKernelToBlobPass
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<gpu::GPUModuleOp>> CreateGpuKernelToBlobPass(
|
||||
mlir::StringRef blob_annotation, ArrayRef<uint32_t> architectures,
|
||||
bool generate_fatbin) {
|
||||
return std::make_unique<GpuKernelToBlobPass>(blob_annotation, architectures,
|
||||
generate_fatbin);
|
||||
mlir::StringRef blob_annotation, int32_t architecture) {
|
||||
return std::make_unique<GpuKernelToBlobPass>(blob_annotation, architecture);
|
||||
}
|
||||
|
||||
} // namespace transforms
|
||||
|
@ -61,8 +61,7 @@ CreatePropagateTensorFlowABIKnowledgePass(
|
||||
|
||||
// Pass to annotate GPU Module with its PTX.
|
||||
std::unique_ptr<OperationPass<gpu::GPUModuleOp>> CreateGpuKernelToBlobPass(
|
||||
mlir::StringRef blob_annotation = "", ArrayRef<uint32_t> architectures = {},
|
||||
bool generate_fatbin = true);
|
||||
mlir::StringRef blob_annotation = "", int32_t architecture = 0);
|
||||
|
||||
// Pass to unfuse batch norm.
|
||||
std::unique_ptr<FunctionPass> CreateUnfuseBatchNormPass();
|
||||
|
@@ -53,10 +53,7 @@ def GpuKernelToBlobPass : Pass<"gpu-kernel-to-blob", "gpu::GPUModuleOp"> {
  let options = [
    Option<"blob_annotation_", "blob-annotation", "std::string",
           /*default=*/"", "Blob attribute name">,
    ListOption<"architectures_", "arch", "uint32_t", "GPU architectures">,
    Option<"generate_fatbin_", "generate-fatbin", "bool", /*default=*/"true",
           "Bundle machine code for the different architectures in one "
           "fatbin.">,
    Option<"arch_", "arch", "int32_t", /*default=*/"0", "GPU architecture">,
  ];
  let constructor = "transforms::CreateGpuKernelToBlobPass()";
}
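With the fatbin option removed, the pass is constructed with a single target architecture. The sketch below shows how the pass might be scheduled on the gpu.module nest of a pass manager, mirroring the `kernel_creator.cc` change earlier in this diff; the attribute name, architecture value, and header paths are illustrative assumptions.

```cpp
#include "mlir/Dialect/GPU/GPUDialect.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"

// Sketch: compile every gpu.module in `module` to a device blob for sm_75.
mlir::LogicalResult AnnotateGpuModulesWithBlob(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  auto& kernel_pm = pm.nest<mlir::gpu::GPUModuleOp>();
  // "gpu.binary_blob" is a placeholder attribute name for this sketch.
  kernel_pm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToBlobPass(
      /*blob_annotation=*/"gpu.binary_blob", /*architecture=*/75));
  return pm.run(module);
}
```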
@ -681,6 +681,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
|
||||
NoAttributeCase(kAnd, AndOp);
|
||||
NoAttributeCase(kAtan2, Atan2Op);
|
||||
NoAttributeCase(kBitcastConvert, BitcastConvertOp);
|
||||
NoAttributeCase(kCbrt, CbrtOp);
|
||||
NoAttributeCase(kConvert, ConvertOp);
|
||||
NoAttributeCase(kCeil, CeilOp);
|
||||
NoAttributeCase(kClamp, ClampOp);
|
||||
@ -738,6 +739,20 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
|
||||
&fusion.fused_computation()));
|
||||
return fusion.getOperation();
|
||||
}
|
||||
case HloOpcode::kBitcast:
|
||||
return func_builder
|
||||
->create<mlir::mhlo::BitcastOp>(loc, result_type, operands,
|
||||
attributes)
|
||||
.getOperation();
|
||||
case HloOpcode::kReducePrecision: {
|
||||
auto op = func_builder->create<mlir::mhlo::ReducePrecisionOp>(
|
||||
loc, result_type, operands[0], attributes);
|
||||
op.exponent_bitsAttr(func_builder->getIntegerAttr(
|
||||
func_builder->getI32Type(), instruction->exponent_bits()));
|
||||
op.mantissa_bitsAttr(func_builder->getIntegerAttr(
|
||||
func_builder->getI32Type(), instruction->mantissa_bits()));
|
||||
return op.getOperation();
|
||||
}
|
||||
case HloOpcode::kAddDependency:
|
||||
// Arbitrary op code that I suspect we will not implement for quite a
|
||||
// while and allows testing handling of unknown ops. Selected because it
|
||||
@ -762,17 +777,10 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
ImportInstructionImpl(instruction, func_builder));
|
||||
if (op == nullptr) return op;
|
||||
|
||||
// Best-effort propagation of the layouts. These layouts serve as performance
|
||||
// hints to the backend.
|
||||
// See MlirToHloConversionOptions for more about layouts.
|
||||
//
|
||||
// Minor-to-major is a permutation of [0, rank), presenting tensor dimensions
|
||||
// in physical minor-to-major order.
|
||||
//
|
||||
// Note that non-array shapes are not carrying layouts, and users have to
|
||||
// figure out the proper layouts of them through context. This is one of the
|
||||
// reasons why the attribute-based solution is temporary.
|
||||
//
|
||||
// TODO(timshen): Investigate the necessity of having layouts in MHLO.
|
||||
if (instruction->shape().IsArray() &&
|
||||
instruction->shape().layout() !=
|
||||
LayoutUtil::MakeDescendingLayout(
|
||||
|
@ -499,12 +499,14 @@ class ConvertToHloModule {
|
||||
// single value.
|
||||
explicit ConvertToHloModule(
|
||||
mlir::ModuleOp module, bool use_tuple_args, bool return_tuple,
|
||||
tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn)
|
||||
tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
MlirToHloConversionOptions options)
|
||||
: module_(module),
|
||||
module_builder_("main"),
|
||||
use_tuple_args_(use_tuple_args),
|
||||
return_tuple_(return_tuple),
|
||||
shape_representation_fn_(shape_representation_fn) {
|
||||
shape_representation_fn_(shape_representation_fn),
|
||||
options_(options) {
|
||||
if (!shape_representation_fn_)
|
||||
shape_representation_fn_ = tensorflow::IdentityShapeRepresentationFn();
|
||||
}
|
||||
@ -585,6 +587,8 @@ class ConvertToHloModule {
|
||||
|
||||
// Unique suffix to give to the name of the next lowered region.
|
||||
size_t region_id_ = 0;
|
||||
|
||||
MlirToHloConversionOptions options_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@ -1078,6 +1082,15 @@ LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ExportXlaOp(BitcastOp op, OpLoweringContext ctx) {
|
||||
auto& value_map = *ctx.values;
|
||||
xla::XlaOp operand;
|
||||
if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
|
||||
value_map[op] = xla::internal::XlaBuilderFriend::BuildBitcast(
|
||||
ctx.builder, operand, xla::TypeToShape(op.getType()));
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
@ -1087,18 +1100,19 @@ LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) {
|
||||
namespace mlir {
|
||||
namespace {
|
||||
|
||||
StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
|
||||
StatusOr<xla::Literal> CreateArrayLiteralFromAttr(ElementsAttr attr,
|
||||
xla::Layout layout) {
|
||||
if (attr.isa<OpaqueElementsAttr>())
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"Opaque elements attr not supported");
|
||||
|
||||
xla::Shape shape = xla::TypeToShape(attr.getType());
|
||||
|
||||
#define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type) \
|
||||
case xla_type: { \
|
||||
xla::Array<cpp_type> source_data(shape.dimensions()); \
|
||||
source_data.SetValues(attr.getValues<cpp_type>()); \
|
||||
return xla::LiteralUtil::CreateFromArray(source_data); \
|
||||
#define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type) \
|
||||
case xla_type: { \
|
||||
xla::Array<cpp_type> source_data(shape.dimensions()); \
|
||||
source_data.SetValues(attr.getValues<cpp_type>()); \
|
||||
return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout); \
|
||||
}
|
||||
|
||||
switch (shape.element_type()) {
|
||||
@ -1128,7 +1142,7 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
|
||||
}
|
||||
xla::Array<xla::half> source_data(shape.dimensions());
|
||||
source_data.SetValues(values);
|
||||
return xla::LiteralUtil::CreateFromArray(source_data);
|
||||
return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout);
|
||||
}
|
||||
case xla::PrimitiveType::BF16: {
|
||||
xla::Array<double> source_data(shape.dimensions());
|
||||
@ -1145,7 +1159,7 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
|
||||
}
|
||||
source_data.SetValues(values_double);
|
||||
return xla::LiteralUtil::ConvertF64ToBF16(
|
||||
xla::LiteralUtil::CreateFromArray(source_data));
|
||||
xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout));
|
||||
}
|
||||
default:
|
||||
return tensorflow::errors::Internal(absl::StrCat(
|
||||
@ -1154,25 +1168,33 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
|
||||
#undef ELEMENTS_ATTR_TO_LITERAL
|
||||
}
|
||||
|
||||
xla::Layout ExtractLayout(mlir::Operation* op, int rank) {
|
||||
if (auto attr =
|
||||
op->getAttrOfType<mlir::DenseIntElementsAttr>("minor_to_major")) {
|
||||
llvm::SmallVector<int64, 4> minor_to_major;
|
||||
minor_to_major.reserve(attr.size());
|
||||
for (const llvm::APInt& i : attr) {
|
||||
minor_to_major.push_back(i.getZExtValue());
|
||||
}
|
||||
return xla::LayoutUtil::MakeLayout(minor_to_major);
|
||||
}
|
||||
return xla::LayoutUtil::MakeDescendingLayout(rank);
|
||||
}
|
||||
|
||||
LogicalResult ConvertToHloModule::Lower(
|
||||
mlir::Operation* inst, bool is_entry_function,
|
||||
llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
|
||||
xla::XlaBuilder* builder,
|
||||
ConvertToHloModule::ValueLoweringMap* value_lowering,
|
||||
xla::XlaComputation* result) {
|
||||
// See hlo_function_importer.cc for documentation about layouts in MHLO.
|
||||
auto propagate_layouts = [](mlir::Operation* inst, xla::XlaOp xla_op) {
|
||||
auto attr =
|
||||
inst->getAttrOfType<mlir::DenseIntElementsAttr>("minor_to_major");
|
||||
if (!attr) return;
|
||||
|
||||
auto* v = xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
|
||||
->mutable_shape()
|
||||
->mutable_layout()
|
||||
->mutable_minor_to_major();
|
||||
v->Clear();
|
||||
for (const llvm::APInt& i : attr) {
|
||||
*v->Add() = i.getZExtValue();
|
||||
// See MlirToHloConversionOptions for more about layouts.
|
||||
auto propagate_layouts = [this](mlir::Operation* inst, xla::XlaOp xla_op) {
|
||||
if (options_.propagate_layouts) {
|
||||
auto* shape = xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
|
||||
->mutable_shape();
|
||||
if (shape->tuple_shapes().empty())
|
||||
*shape->mutable_layout() =
|
||||
ExtractLayout(inst, shape->dimensions().size()).ToProto();
|
||||
}
|
||||
};
|
||||
|
||||
@ -1216,12 +1238,14 @@ LogicalResult ConvertToHloModule::Lower(
|
||||
}
|
||||
|
||||
if (matchPattern(inst, m_Constant(&const_attr))) {
|
||||
auto literal_or = CreateLiteralFromAttr(const_attr);
|
||||
xla::Layout layout;
|
||||
layout = ExtractLayout(inst, const_attr.getType().getRank());
|
||||
auto literal_or = CreateArrayLiteralFromAttr(const_attr, layout);
|
||||
if (!literal_or.ok())
|
||||
return inst->emitError(literal_or.status().ToString());
|
||||
auto constant = xla::ConstantLiteral(builder, literal_or.ValueOrDie());
|
||||
value_map[inst->getResult(0)] = constant;
|
||||
propagate_layouts(inst, constant);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -1674,22 +1698,24 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
|
||||
} // namespace
|
||||
|
||||
Status ConvertRegionToComputation(mlir::Region* region,
|
||||
xla::XlaComputation* func) {
|
||||
xla::XlaComputation* func,
|
||||
MlirToHloConversionOptions options) {
|
||||
mlir::ModuleOp module;
|
||||
ConvertToHloModule converter(module, true, true, {});
|
||||
ConvertToHloModule converter(module, true, true, {}, options);
|
||||
if (failed(converter.LowerRegionAsComputation(region, func)))
|
||||
return tensorflow::errors::Internal(
|
||||
"failed to convert region to computation");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
|
||||
bool use_tuple_args, bool return_tuple,
|
||||
const tensorflow::XlaHelpers::ShapeRepresentationFn
|
||||
shape_representation_fn) {
|
||||
Status ConvertMlirHloToHlo(
|
||||
mlir::ModuleOp module, xla::HloProto* hlo_proto, bool use_tuple_args,
|
||||
bool return_tuple,
|
||||
const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
MlirToHloConversionOptions options) {
|
||||
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
|
||||
ConvertToHloModule converter(module, use_tuple_args, return_tuple,
|
||||
shape_representation_fn);
|
||||
shape_representation_fn, options);
|
||||
if (failed(converter.Run())) return diag_handler.ConsumeStatus();
|
||||
auto hlo_module = converter.ConsumeMainProto();
|
||||
hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
|
||||
|
@@ -25,6 +25,18 @@ limitations under the License.

namespace mlir {

struct MlirToHloConversionOptions {
  // Best-effort propagation of the layouts. These layouts serve as performance
  // hints to the backend.
  //
  // Note that non-array shapes are not carrying layouts, and users have to
  // figure out the proper layouts of them through context. This is one of the
  // reasons why the attribute-based solution is temporary.
  //
  // TODO(timshen): Investigate the necessity of having layouts in MHLO.
  bool propagate_layouts = false;
};

// Converts a MLIR module in HLO dialect into a HloModuleProto. If
// use_tuple_args is set, then the entry computation's arguments are converted
// to a tuple and passed as a single parameter.
@@ -32,15 +44,19 @@ namespace mlir {
// are converted to a tuple even when there is only a single return value.
// Multiple return values are always converted to a tuple and returned as a
// single value.
//
// TODO(timshen): move other options into `options`.
Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto,
                           bool use_tuple_args, bool return_tuple,
                           const tensorflow::XlaHelpers::ShapeRepresentationFn
                               shape_representation_fn = nullptr);
                               shape_representation_fn = nullptr,
                           MlirToHloConversionOptions options = {});

// Converts a region to a computation. It returns a standalone module that
// contains the converted region as the entry computation.
Status ConvertRegionToComputation(mlir::Region* region,
                                  ::xla::XlaComputation* func);
                                  ::xla::XlaComputation* func,
                                  MlirToHloConversionOptions options = {});

// Creates XlaOp equivalent of a given MLIR operation using the operand info
// from `value_lowering` map.
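The new `MlirToHloConversionOptions` struct is threaded through `ConvertMlirHloToHlo`, so callers can opt in to layout propagation. A minimal sketch of such a caller, modeled on the `tf-mlir-translate` change later in this diff; the wrapper function and error handling are illustrative assumptions.

```cpp
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/core/platform/logging.h"

// Sketch: convert an MHLO module to an HloProto, propagating any
// `minor_to_major` layout attributes as layout hints.
xla::HloProto ConvertWithLayouts(mlir::ModuleOp module) {
  mlir::MlirToHloConversionOptions options;
  options.propagate_layouts = true;

  xla::HloProto hlo_proto;
  tensorflow::Status status = mlir::ConvertMlirHloToHlo(
      module, &hlo_proto, /*use_tuple_args=*/false, /*return_tuple=*/false,
      /*shape_representation_fn=*/nullptr, options);
  if (!status.ok()) {
    LOG(ERROR) << "Module conversion failed: " << status;
  }
  return hlo_proto;
}
```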
@ -1102,3 +1102,33 @@ func @main(%arg: tensor<3xui64>) -> tuple<tensor<3xui64>, tensor<2x2xui32>> {
|
||||
%0 = "mhlo.rng_bit_generator"(%arg) {rng_algorithm = 2 : i32} : (tensor<3xui64>) -> tuple<tensor<3xui64>, tensor<2x2xui32>>
|
||||
return %0 : tuple<tensor<3xui64>, tensor<2x2xui32>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main(%arg: tensor<3x4xf32>) -> tensor<3x4xf32> {
|
||||
// CHECK: %[[ARG0:.*]] = f32[3,4] parameter(0)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[3,4] cbrt(f32[3,4] %[[ARG0]])
|
||||
%0 = "mhlo.cbrt"(%arg) : (tensor<3x4xf32>) -> tensor<3x4xf32>
|
||||
return %0 : tensor<3x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main(%arg: tensor<3x4xf32>) -> tensor<3x4xf32> {
|
||||
// CHECK: %[[ARG0:.*]] = f32[3,4] parameter(0)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[3,4] reduce-precision(f32[3,4] %[[ARG0]]), exponent_bits=8, mantissa_bits=10
|
||||
%0 = "mhlo.reduce_precision"(%arg) {exponent_bits = 8 : i32, mantissa_bits = 10 : i32} : (tensor<3x4xf32>) -> tensor<3x4xf32>
|
||||
return %0 : tensor<3x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main(%arg: tensor<3x4xf32>) -> tensor<3x4x1xf32> {
|
||||
// CHECK: %[[ARG0:.*]] = f32[3,4] parameter(0)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[3,4,1] bitcast(f32[3,4] %[[ARG0]])
|
||||
%0 = "mhlo.bitcast"(%arg) : (tensor<3x4xf32>) -> tensor<3x4x1xf32>
|
||||
return %0 : tensor<3x4x1xf32>
|
||||
}
|
||||
|
@ -1014,3 +1014,26 @@ add {
|
||||
ROOT %rng-bit-generator.2 = (u64[3], u32[2,2]) rng-bit-generator(u64[3] %Arg_0.1), algorithm=rng_philox
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cbrt
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xf32>)
|
||||
%cbrt (Arg_0.1: f32[3,4]) -> f32[3,4] {
|
||||
%Arg_0.1 = f32[3,4] parameter(0)
|
||||
// CHECK: "mhlo.cbrt"(%[[ARG0]]) : (tensor<3x4xf32>) -> tensor<3x4xf32>
|
||||
ROOT %cbrt = f32[3,4] cbrt(f32[3,4] %Arg_0.1)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @bitcast
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xf32>) -> tensor<3x4x1xf32>
|
||||
%bitcast (Arg_0.1: f32[3,4]) -> f32[3,4,1] {
|
||||
%Arg_0.1 = f32[3,4] parameter(0)
|
||||
// CHECK: "mhlo.bitcast"(%[[ARG0]]) : (tensor<3x4xf32>) -> tensor<3x4x1xf32>
|
||||
ROOT %bitcast = f32[3,4,1] bitcast(f32[3,4] %Arg_0.1)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @reduce_precision
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xf32>)
|
||||
%reduce_precision (Arg_0.1: f32[3,4]) -> f32[3,4] {
|
||||
%Arg_0.1 = f32[3,4] parameter(0)
|
||||
// CHECK: "mhlo.reduce_precision"(%[[ARG0]]) {exponent_bits = 8 : i32, mantissa_bits = 10 : i32} : (tensor<3x4xf32>) -> tensor<3x4xf32>
|
||||
ROOT %reduce_precision = f32[3,4] reduce-precision(f32[3,4] %Arg_0.1), exponent_bits=8, mantissa_bits=10
|
||||
}
|
||||
|
@ -26,5 +26,9 @@ func @main(%arg0: tensor<128x224x224x4xf16>, %arg1: tensor<64x7x7x4xf16>) -> ten
|
||||
rhs_dilations = dense<1> : tensor<2xi64>,
|
||||
window_strides = dense<2> : tensor<2xi64>
|
||||
} : (tensor<128x224x224x4xf16>, tensor<64x7x7x4xf16>)-> tensor<128x64x112x112xf16> loc("root.42")
|
||||
|
||||
// CHECK: s32[1,1]{0,1} constant({ {42} })
|
||||
%cst_1 = "std.constant"() {value = dense<[[42]]> : tensor<1x1xi32>, minor_to_major = dense<[0, 1]> : tensor<2xindex>} : () -> tensor<1x1xi32>
|
||||
|
||||
return %0 : tensor<128x64x112x112xf16>
|
||||
}
|
||||
|
@ -129,8 +129,11 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunctionImpl(
|
||||
if (!module) return mlir::failure();
|
||||
|
||||
HloProto hloProto;
|
||||
mlir::MlirToHloConversionOptions options;
|
||||
options.propagate_layouts = with_layouts;
|
||||
Status status = mlir::ConvertMlirHloToHlo(
|
||||
module, &hloProto, emit_use_tuple_arg, emit_return_tuple);
|
||||
module, &hloProto, emit_use_tuple_arg, emit_return_tuple,
|
||||
/*shape_representation_fn=*/nullptr, options);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Module conversion failed: " << status;
|
||||
return mlir::failure();
|
||||
|
@ -1,5 +1,5 @@
|
||||
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
load("//tensorflow:tensorflow.bzl", "if_tpu", "tf_cc_binary", "tf_cc_test", "tf_copts", "tf_cuda_cc_test", "tf_openmp_copts")
|
||||
load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_cc_binary", "tf_cc_test", "tf_copts", "tf_cuda_cc_test", "tf_openmp_copts")
|
||||
load(
|
||||
"//tensorflow/core/platform/default:cuda_build_defs.bzl",
|
||||
"if_cuda_is_configured",
|
||||
@ -298,7 +298,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/stream_executor:platform",
|
||||
] + if_tpu(
|
||||
] + if_libtpu(
|
||||
if_false = [
|
||||
"//tensorflow/compiler/xla/service:cpu_plugin",
|
||||
"//tensorflow/compiler/xla/service/cpu:buffer_info_util",
|
||||
@ -369,7 +369,7 @@ cc_library(
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
] + if_tpu(
|
||||
] + if_libtpu(
|
||||
if_false = [
|
||||
"//tensorflow/compiler/mlir:array_container_utils",
|
||||
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
|
||||
@ -877,13 +877,13 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "mlir_bridge_pass_registration",
|
||||
srcs = if_tpu(
|
||||
srcs = if_libtpu(
|
||||
if_false = [
|
||||
"mlir_bridge_pass_registration.cc",
|
||||
],
|
||||
if_true = [],
|
||||
),
|
||||
deps = if_tpu(
|
||||
deps = if_libtpu(
|
||||
if_false = [
|
||||
":mlir_bridge_pass",
|
||||
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass_registration",
|
||||
|
@ -56,7 +56,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||
#include "tensorflow/core/util/dump_graph.h"
|
||||
|
||||
#ifndef LIBTFTPU
|
||||
#ifndef LIBTPU_ON_GCE
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
||||
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
|
||||
#endif
|
||||
@ -733,7 +733,7 @@ Status XlaCompiler::CompileFunction(
|
||||
}
|
||||
|
||||
VLOG(1) << "====================================================";
|
||||
#ifdef LIBTFTPU
|
||||
#ifdef LIBTPU_ON_GCE
|
||||
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
|
||||
VLOG(1) << "MLIR is not supported in this environment.";
|
||||
}
|
||||
|
@ -149,6 +149,16 @@ XlaOp XlaBuilderFriend::BuildFusion(XlaBuilder* builder,
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp XlaBuilderFriend::BuildBitcast(XlaBuilder* builder, XlaOp operand,
|
||||
const Shape& shape) {
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
return builder->AddInstruction(std::move(instr), HloOpcode::kBitcast,
|
||||
{operand});
|
||||
});
|
||||
}
|
||||
|
||||
HloInstructionProto* XlaBuilderFriend::GetInstruction(XlaOp op) {
|
||||
return &op.builder()
|
||||
->instructions_[op.builder()->handle_to_index_[op.handle_]];
|
||||
|
@@ -57,6 +57,9 @@ struct XlaBuilderFriend {
                            absl::string_view fusion_kind,
                            const XlaComputation& fused_computation);

  static XlaOp BuildBitcast(XlaBuilder* builder, XlaOp operand,
                            const Shape& shape);

  static HloInstructionProto* GetInstruction(XlaOp op);
};
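`BuildBitcast` gives the MHLO exporter a way to emit XLA's internal bitcast instruction directly on an `XlaBuilder`, as used by the new `ExportXlaOp(BitcastOp, ...)` lowering above. A small usage sketch follows; the shapes and function name are illustrative assumptions.

```cpp
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Sketch: reshape f32[3,4] to f32[3,4,1] without moving any data.
xla::XlaOp EmitExampleBitcast(xla::XlaBuilder* builder) {
  xla::XlaOp arg = xla::Parameter(
      builder, /*parameter_number=*/0,
      xla::ShapeUtil::MakeShape(xla::F32, {3, 4}), "arg0");
  return xla::internal::XlaBuilderFriend::BuildBitcast(
      builder, arg, xla::ShapeUtil::MakeShape(xla::F32, {3, 4, 1}));
}
```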
@ -2,7 +2,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency")
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"if_tpu",
|
||||
"if_libtpu",
|
||||
"tf_cc_binary",
|
||||
"tf_cc_test",
|
||||
)
|
||||
@ -57,7 +57,7 @@ cc_library(
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
tf_grpc_cc_dependency(),
|
||||
] + if_tpu(
|
||||
] + if_libtpu(
|
||||
if_false = ["//tensorflow/compiler/xla/service:cpu_plugin"],
|
||||
if_true = [],
|
||||
),
|
||||
|
@ -1708,7 +1708,6 @@ cc_library(
|
||||
srcs = ["hlo_creation_utils.cc"],
|
||||
hdrs = [
|
||||
"hlo_creation_utils.h",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
],
|
||||
deps = [
|
||||
":hlo",
|
||||
|
@ -217,6 +217,7 @@ cc_library(
|
||||
":backend_configs_cc",
|
||||
":buffer_allocations",
|
||||
":gpu_constants",
|
||||
":gpu_conv_runner",
|
||||
":gpu_executable",
|
||||
":ir_emission_utils",
|
||||
":nccl_all_reduce_thunk",
|
||||
|
@ -45,10 +45,7 @@ CholeskyThunk::CholeskyThunk(ThunkInfo thunk_info,
|
||||
info_buffer_(info_buffer),
|
||||
type_(type),
|
||||
batch_size_(batch_size),
|
||||
a_batch_stride_(
|
||||
n * n *
|
||||
ShapeUtil::ByteSizeOfPrimitiveType(
|
||||
thunk_info.hlo_instruction->operand(0)->shape().element_type())),
|
||||
a_batch_stride_(n * n * ShapeUtil::ByteSizeOfPrimitiveType(type)),
|
||||
n_(n) {}
|
||||
|
||||
Status CholeskyThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
|
@ -31,7 +31,8 @@ namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
ConvolutionThunk::ConvolutionThunk(
|
||||
ThunkInfo thunk_info, std::vector<BufferAllocation::Slice> operand_slices,
|
||||
ThunkInfo thunk_info, GpuConvConfig&& config,
|
||||
std::vector<BufferAllocation::Slice> operand_slices,
|
||||
BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice,
|
||||
BufferAllocation::Slice tuple_result_slice)
|
||||
: Thunk(Kind::kConvolution, thunk_info),
|
||||
@ -39,9 +40,7 @@ ConvolutionThunk::ConvolutionThunk(
|
||||
result_buffer_(result_slice),
|
||||
scratch_buffer_(scratch_slice),
|
||||
tuple_result_buffer_(tuple_result_slice),
|
||||
config_(GetGpuConvConfig(
|
||||
Cast<HloCustomCallInstruction>(thunk_info.hlo_instruction))
|
||||
.ValueOrDie()) {}
|
||||
config_(std::move(config)) {}
|
||||
|
||||
Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
const auto& buffer_allocations = *params.buffer_allocations;
|
||||
|
@ -43,7 +43,7 @@ class ConvolutionThunk : public Thunk {
|
||||
// write a tuple (result, scratch_memory) into `tuple_result_buffer`.
|
||||
//
|
||||
// operand_slices should be in the same order as cudnn_call->operands().
|
||||
ConvolutionThunk(ThunkInfo thunk_info,
|
||||
ConvolutionThunk(ThunkInfo thunk_info, GpuConvConfig&& config,
|
||||
std::vector<BufferAllocation::Slice> operand_slices,
|
||||
BufferAllocation::Slice result_slice,
|
||||
BufferAllocation::Slice scratch_slice,
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"
|
||||
@ -238,9 +239,13 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
|
||||
auto conv_result_slice = GetAllocationSlice(*custom_call, {0});
|
||||
auto scratch_slice = GetAllocationSlice(*custom_call, {1});
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
GpuConvConfig config,
|
||||
GetGpuConvConfig(Cast<HloCustomCallInstruction>(custom_call)));
|
||||
AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
|
||||
context_->GetThunkInfo(custom_call), std::move(operand_slices),
|
||||
conv_result_slice, scratch_slice, tuple_result_slice));
|
||||
context_->GetThunkInfo(custom_call), std::move(config),
|
||||
std::move(operand_slices), conv_result_slice, scratch_slice,
|
||||
tuple_result_slice));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -1524,9 +1524,11 @@ StatusOr<int64> CompressInstruction(MemoryUsageTracker* memory_tracker,
|
||||
|
||||
HloInstruction* compressed = computation->AddInstruction(
|
||||
HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best));
|
||||
compressed->SetAndSanitizeName(best->name() + ".remat_compressed");
|
||||
|
||||
HloInstruction* uncompressed = computation->AddInstruction(
|
||||
HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed));
|
||||
uncompressed->SetAndSanitizeName(best->name() + ".remat_uncompressed");
|
||||
|
||||
Item* compressed_item = instruction_list->CreateItem(compressed);
|
||||
compressed_item->placed = true;
|
||||
|
@ -68,9 +68,9 @@ load(
|
||||
"if_chromiumos",
|
||||
"if_cuda_or_rocm",
|
||||
"if_ios",
|
||||
"if_libtpu",
|
||||
"if_mobile",
|
||||
"if_not_windows",
|
||||
"if_tpu",
|
||||
"tf_android_core_proto_headers",
|
||||
"tf_cc_test",
|
||||
"tf_cc_test_mkl",
|
||||
@ -894,8 +894,7 @@ cc_library(
|
||||
"//tensorflow/c/kernels:summary_op_lib",
|
||||
] + if_chromiumos(
|
||||
[],
|
||||
# Non-tpu platforms don't need tpu dependency. It would be best to guard
|
||||
# them by if_tpu. But there is no such flag yet.
|
||||
# Non-tpu platforms don't need tpu dependency.
|
||||
[
|
||||
":tpu_configuration_ops_op_lib",
|
||||
":tpu_cross_replica_ops_op_lib",
|
||||
@ -916,7 +915,7 @@ cc_library(
|
||||
]) + if_tensorrt([
|
||||
"//tensorflow/compiler/tf2tensorrt:trt_engine_resource_ops_op_lib",
|
||||
"//tensorflow/compiler/tf2tensorrt:trt_op_libs",
|
||||
]) + if_tpu(
|
||||
]) + if_libtpu(
|
||||
if_false = ["//tensorflow/compiler/mlir/tensorflow:mlir_passthrough_op"],
|
||||
if_true = [],
|
||||
),
|
||||
|
@ -1,6 +1,6 @@
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"if_tpu",
|
||||
"if_libtpu",
|
||||
"tf_cc_test",
|
||||
"tf_cc_test_mkl",
|
||||
"tf_cc_tests",
|
||||
@ -93,7 +93,7 @@ cc_library(
|
||||
deps = [
|
||||
":core_cpu",
|
||||
"//tensorflow/core/common_runtime/gpu:gpu_runtime",
|
||||
] + if_tpu(["//tensorflow/core/tpu:tpu_runtime"]),
|
||||
] + if_libtpu(["//tensorflow/core/tpu:tpu_runtime"]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
|
@ -151,7 +151,7 @@ void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
|
||||
|
||||
RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous(
|
||||
const DeviceMgr* device_mgr)
|
||||
: device_mgr_(device_mgr) {}
|
||||
: device_mgr_(device_mgr), local_(this) {}
|
||||
|
||||
RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {}
|
||||
|
||||
@ -176,7 +176,7 @@ void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) {
|
||||
|
||||
PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous(
|
||||
const DeviceMgr* device_mgr)
|
||||
: device_mgr_(device_mgr) {}
|
||||
: device_mgr_(device_mgr), local_(nullptr) {}
|
||||
|
||||
PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {}
|
||||
|
||||
|
@ -1121,8 +1121,17 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) {
}

Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
string data_format_str;
TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
TensorFormat data_format;
if (!FormatFromString(data_format_str, &data_format)) {
return errors::InvalidArgument("Invalid data format string: ",
data_format_str);
}
const int rank =
(data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4;
ShapeHandle x;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &x));

bool is_training;
TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
@ -1131,14 +1140,8 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
exponential_avg_factor = 1.0f;  // default value
}
int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5;
string data_format_str;
TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
TensorFormat data_format;
if (!FormatFromString(data_format_str, &data_format)) {
return errors::InvalidArgument("Invalid data format string: ",
data_format_str);
}
int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);

int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
DimensionHandle channel_dim = c->Dim(x, channel_dim_index);

// covers scale, offset, and if is_training is false, mean, variance
@ -1191,13 +1194,6 @@ Status FusedBatchNormExShape(shape_inference::InferenceContext* c) {
}

Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
ShapeHandle y_backprop;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
ShapeHandle x;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));

bool is_training;
TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
string data_format_str;
TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
TensorFormat data_format;
@ -1205,7 +1201,17 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
return errors::InvalidArgument("Invalid data format string: ",
data_format_str);
}
int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
const int rank =
(data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4;
ShapeHandle y_backprop;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop));
ShapeHandle x;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &x));

bool is_training;
TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));

int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
TF_RETURN_IF_ERROR(
    c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));
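
Not part of the commit: the switch from GetTensorFeatureDimIndex(4, data_format) to GetTensorFeatureDimIndex(rank, data_format) matters because the channel axis sits in a different position for 4-D and 5-D layouts. A standalone sketch of that indexing follows; ChannelDimIndex is an illustrative stand-in for TensorFlow's GetTensorFeatureDimIndex helper, under the assumption that channels-last formats keep the feature axis in the final position.

#include <cstdio>
#include <string>

// Channel axis position for the formats the new shape functions accept.
// Channels-first formats (NCHW, NCDHW) keep it at index 1; channels-last
// formats (NHWC, NDHWC) keep it at the final axis, i.e. rank - 1.
static int ChannelDimIndex(const std::string& format, int rank) {
  return (format == "NCHW" || format == "NCDHW") ? 1 : rank - 1;
}

int main() {
  const struct { const char* fmt; int rank; } cases[] = {
      {"NHWC", 4}, {"NCHW", 4}, {"NDHWC", 5}, {"NCDHW", 5}};
  for (const auto& c : cases) {
    std::printf("%-6s rank %d -> channel dim %d\n", c.fmt, c.rank,
                ChannelDimIndex(c.fmt, c.rank));
  }
  return 0;
}
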
@ -187,6 +187,20 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
|
||||
CancellationToken token = CancellationManager::kInvalidToken;
|
||||
bool already_cancelled = false;
|
||||
if (cm != nullptr) {
|
||||
// Increment the refcount when cancellation manager is present, to make
|
||||
// sure the rendezvous outlives the recv and its cancel callbacks.
|
||||
// This refcount is dropped in exactly one of the following cases:
|
||||
// (1) Recv registers cancellation callback to cm, and then cm is
|
||||
// cancelled, unref in the cancellation callback;
|
||||
// (2) Recv registers cancellation callback to cm, but cm is already
|
||||
// cancelled, unref in the already_cancelled check;
|
||||
// (3) Recv is successful, and item done callback finishes deregistering
|
||||
// the cancellation callback, unref in the item done callback;
|
||||
// (4) Recv is successful, but the item done callback fails to deregister
|
||||
// the cancellation callback because cm already StartCancel, in this
|
||||
// case the cancellation callback will be invoked by the cm anyway,
|
||||
// unref in the cancellation callback.
|
||||
if (rc_owner_) rc_owner_->Ref();
|
||||
token = cm->get_cancellation_token();
|
||||
already_cancelled = !cm->RegisterCallback(token, [this, token, key_hash] {
|
||||
Item* item = nullptr;
|
||||
@ -230,10 +244,14 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
|
||||
Rendezvous::Args(), item->args, Tensor(), /*is_dead=*/false);
|
||||
delete item;
|
||||
}
|
||||
// Unref case (1) and (4)
|
||||
if (rc_owner_) rc_owner_->Unref();
|
||||
});
|
||||
}
|
||||
if (already_cancelled) {
|
||||
mu_.unlock();
|
||||
// Unref case (2)
|
||||
if (rc_owner_) rc_owner_->Unref();
|
||||
done(StatusGroup::MakeDerived(
|
||||
errors::Cancelled("RecvAsync is cancelled.")),
|
||||
Rendezvous::Args(), recv_args, Tensor(), /*is_dead=*/false);
|
||||
@ -250,10 +268,17 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
|
||||
// cancellation manager may no longer be live after `done` is called.
|
||||
queue->push_back(new Item(
|
||||
recv_args,
|
||||
[cm, token, done = std::move(done)](
|
||||
[this, cm, token, done = std::move(done)](
|
||||
const Status& s, const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
|
||||
cm->TryDeregisterCallback(token);
|
||||
// TryDeregisterCallback returns true when the cancellation callback
|
||||
// is successfully deregistered. If it fails because the CM already
|
||||
// StartAbort, Unref will happen inside the cancellation callback
|
||||
// when called by the CM.
|
||||
if (cm->TryDeregisterCallback(token)) {
|
||||
// Unref case (3)
|
||||
if (this->rc_owner_) this->rc_owner_->Unref();
|
||||
}
|
||||
done(s, send_args, recv_args, v, dead);
|
||||
},
|
||||
token));
|
||||
|
@ -35,7 +35,11 @@ namespace tensorflow {
// is not expected to be needed.
class LocalRendezvous {
 public:
  LocalRendezvous() = default;
  // If the class wrapping LocalRendezvous is refcounted (i.e., extending
  // Rendezvous), pass in its pointer in constructor so the LocalRendezvous
  // can make sure it outlives the async recv requests.
  // Pass in nullptr if the wrapping class is not refcounted.
  explicit LocalRendezvous(Rendezvous* owner) : rc_owner_(owner) {}
  ~LocalRendezvous();

  Status Send(const Rendezvous::ParsedKey& key,
@ -62,6 +66,9 @@ class LocalRendezvous {

  typedef gtl::FlatMap<uint64, ItemQueue> Table;

  // Pointer to the owner class of this LocalRendezvous if it is refcounted.
  const Rendezvous* rc_owner_;

  // TODO(zhifengc): shard table_.
  mutex mu_;
  Table table_ TF_GUARDED_BY(mu_);
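
Not part of the commit: the rc_owner_ comment above states the ownership rule without showing it, so here is a minimal, self-contained sketch of the pattern under illustrative names (RefCounted, Helper, OwnerWrapper are stand-ins, not TensorFlow classes). The helper takes an optional back-pointer to its refcounted owner and holds a reference for every pending asynchronous receive, which is the same shape as the rendezvous.cc hunk later in this commit that wires LocalRendezvousWrapper() : impl_(this).

#include <atomic>
#include <functional>
#include <utility>

// Stand-in for a refcounted owner; starts with one reference and deletes
// itself when the count drops to zero.
class RefCounted {
 public:
  void Ref() { refs_.fetch_add(1, std::memory_order_relaxed); }
  void Unref() {
    if (refs_.fetch_sub(1, std::memory_order_acq_rel) == 1) delete this;
  }
  virtual ~RefCounted() = default;

 private:
  std::atomic<int> refs_{1};
};

// Stand-in for the helper that must outlive its pending async receives.
class Helper {
 public:
  // Pass nullptr when the wrapping class is not refcounted.
  explicit Helper(RefCounted* owner) : rc_owner_(owner) {}

  void RecvAsync(std::function<void()> done) {
    if (rc_owner_ != nullptr) rc_owner_->Ref();  // keep the owner alive
    done_ = [this, done = std::move(done)] {
      done();
      if (rc_owner_ != nullptr) rc_owner_->Unref();  // release on completion
    };
  }

  void Complete() {
    if (done_) std::exchange(done_, nullptr)();
  }

 private:
  RefCounted* const rc_owner_;
  std::function<void()> done_;
};

// A refcounted wrapper passes `this` so the helper can pin it while work is
// in flight; a non-refcounted user would pass nullptr instead.
class OwnerWrapper : public RefCounted {
 public:
  OwnerWrapper() : impl_(this) {}
  Helper impl_;
};
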
@ -1152,22 +1152,17 @@ TEST(RegisteredKernels, GetRegisteredKernelsForOp) {
|
||||
EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU");
|
||||
}
|
||||
|
||||
#define EXTRACT_KERNEL_NAME_AND_BUILDER_IMPL(kernel_name, kernel_builder, ...) \
|
||||
constexpr char const* kKernelName = kernel_name; \
|
||||
auto builder = []() { \
|
||||
return std::unique_ptr<KernelDef const>(kernel_builder.Build()); \
|
||||
};
|
||||
#define EXTRACT_KERNEL_NAME_AND_BUILDER(kernel_builder) \
|
||||
TF_EXTRACT_KERNEL_NAME(EXTRACT_KERNEL_NAME_AND_BUILDER_IMPL, kernel_builder)
|
||||
// EXTRACT_KERNEL_NAME_TO_STRING wraps TF_EXTRACT_KERNEL_NAME for testing
|
||||
// (it involves quite a bit of macro-magic).
|
||||
#define EXTRACT_KERNEL_NAME_TO_STRING_IMPL(name, kernel_builder, ...) name
|
||||
#define EXTRACT_KERNEL_NAME_TO_STRING(kernel_builder) \
|
||||
TF_EXTRACT_KERNEL_NAME(EXTRACT_KERNEL_NAME_TO_STRING_IMPL, kernel_builder)
|
||||
|
||||
TEST(RegisterKernelMacro, ExtractName) {
|
||||
constexpr char const* kName = "Foo";
|
||||
constexpr char const* kLabel = "Label";
|
||||
EXTRACT_KERNEL_NAME_AND_BUILDER(Name(kName).Label(kLabel));
|
||||
EXPECT_THAT(kKernelName, ::testing::StrEq(kName));
|
||||
std::unique_ptr<KernelDef const> kernel_def = builder();
|
||||
EXPECT_THAT(kernel_def->op(), ::testing::StrEq(kName));
|
||||
EXPECT_THAT(kernel_def->label(), ::testing::StrEq(kLabel));
|
||||
static constexpr char const* kName = "Foo";
|
||||
static constexpr char const* kExtractedName =
|
||||
EXTRACT_KERNEL_NAME_TO_STRING(Name(kName).Label("Label"));
|
||||
EXPECT_THAT(kExtractedName, ::testing::StrEq(kName));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -151,7 +151,7 @@ Status RendezvousInterface::Recv(const ParsedKey& key, const Args& args,
namespace {
class LocalRendezvousWrapper : public Rendezvous {
 public:
  LocalRendezvousWrapper() = default;
  LocalRendezvousWrapper() : impl_(this) {}

  Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val,
              const bool is_dead) override {

@ -670,7 +670,25 @@ Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context,
|
||||
Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
|
||||
TransposeContext* context, utils::MutableNodeView* node) {
|
||||
DCHECK(IsDefaultLayoutSensitiveOp(*node->node()));
|
||||
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
|
||||
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
|
||||
const auto& shape = output_shape_attr->list().shape(0);
|
||||
const int rank = shape.dim_size();
|
||||
std::string src_format = context->src_format;
|
||||
std::string dst_format = context->dst_format;
|
||||
// Update the format from 4D to 5D layout if necessary.
|
||||
bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW");
|
||||
if (allow_5d) {
|
||||
std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
|
||||
std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
|
||||
context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
|
||||
dst_format_3d);
|
||||
}
|
||||
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) {
|
||||
// Change back to the original layout due to early exit.
|
||||
if (allow_5d) {
|
||||
context->AssignDeviceAndDataFormats(context->target_device, src_format,
|
||||
dst_format);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
||||
@ -679,6 +697,11 @@ Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
|
||||
TF_RETURN_IF_ERROR(UpdateNode(context, node));
|
||||
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
|
||||
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
|
||||
// Change back the format from 5D to 4D layout.
|
||||
if (allow_5d) {
|
||||
context->AssignDeviceAndDataFormats(context->target_device, src_format,
|
||||
dst_format);
|
||||
}
|
||||
return context->graph_view->GetMutationBuilder()->Apply();
|
||||
}
|
||||
|
||||
@ -881,8 +904,26 @@ bool FusedBatchNormGradTransposer::IsTraining(
|
||||
Status FusedBatchNormGradTransposer::TransposeNode(
|
||||
TransposeContext* context, utils::MutableNodeView* node) {
|
||||
DCHECK(IsFusedBatchNormGrad(*node->node()));
|
||||
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
|
||||
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
|
||||
const auto& shape = output_shape_attr->list().shape(0);
|
||||
const int rank = shape.dim_size();
|
||||
std::string src_format = context->src_format;
|
||||
std::string dst_format = context->dst_format;
|
||||
// Update the format from 4D to 5D layout if necessary.
|
||||
bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW");
|
||||
if (allow_5d) {
|
||||
std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
|
||||
std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
|
||||
context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
|
||||
dst_format_3d);
|
||||
}
|
||||
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank) ||
|
||||
!IsTraining(*node)) {
|
||||
// Change back to the original layout due to early exit.
|
||||
if (allow_5d) {
|
||||
context->AssignDeviceAndDataFormats(context->target_device, src_format,
|
||||
dst_format);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
||||
@ -892,6 +933,11 @@ Status FusedBatchNormGradTransposer::TransposeNode(
|
||||
TF_RETURN_IF_ERROR(
|
||||
UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose));
|
||||
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
|
||||
// Change back the format from 5D to 4D layout.
|
||||
if (allow_5d) {
|
||||
context->AssignDeviceAndDataFormats(context->target_device, src_format,
|
||||
dst_format);
|
||||
}
|
||||
return context->graph_view->GetMutationBuilder()->Apply();
|
||||
}
|
||||
|
||||
|
@ -1438,29 +1438,41 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
|
||||
utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
|
||||
Status status;
|
||||
|
||||
if (fused_node.attr().at(kDataFormat).s() == "NCHW") {
|
||||
string x_format = fused_node.attr().at(kDataFormat).s();
|
||||
if (x_format == "NCHW" or x_format == "NCDHW") {
|
||||
// Need to reshape the last 4 inputs
|
||||
NodeDef new_shape;
|
||||
const string new_shape_name =
|
||||
AddPrefixToNodeName("NCHWShape", fused_node.name());
|
||||
AddPrefixToNodeName(x_format + "Shape", fused_node.name());
|
||||
new_shape.set_name(new_shape_name);
|
||||
new_shape.set_op("Const");
|
||||
new_shape.set_device(fused_node.device());
|
||||
*new_shape.add_input() = AsControlDependency(scale);
|
||||
(*new_shape.mutable_attr())["dtype"].set_type(DT_INT32);
|
||||
Tensor t(DT_INT32, {4});
|
||||
t.flat<int32>()(0) = 1;
|
||||
t.flat<int32>()(1) = -1;
|
||||
t.flat<int32>()(2) = 1;
|
||||
t.flat<int32>()(3) = 1;
|
||||
t.AsProtoTensorContent(
|
||||
(*new_shape.mutable_attr())["value"].mutable_tensor());
|
||||
if (x_format == "NCHW") {
|
||||
Tensor t(DT_INT32, {4});
|
||||
t.flat<int32>()(0) = 1;
|
||||
t.flat<int32>()(1) = -1;
|
||||
t.flat<int32>()(2) = 1;
|
||||
t.flat<int32>()(3) = 1;
|
||||
t.AsProtoTensorContent(
|
||||
(*new_shape.mutable_attr())["value"].mutable_tensor());
|
||||
} else {
|
||||
Tensor t(DT_INT32, {5});
|
||||
t.flat<int32>()(0) = 1;
|
||||
t.flat<int32>()(1) = -1;
|
||||
t.flat<int32>()(2) = 1;
|
||||
t.flat<int32>()(3) = 1;
|
||||
t.flat<int32>()(4) = 1;
|
||||
t.AsProtoTensorContent(
|
||||
(*new_shape.mutable_attr())["value"].mutable_tensor());
|
||||
}
|
||||
mutation->AddNode(std::move(new_shape), &status);
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
|
||||
NodeDef reshaped_scale;
|
||||
reshaped_scale.set_name(
|
||||
AddPrefixToNodeName("NCHWShapedScale", fused_node.name()));
|
||||
AddPrefixToNodeName(x_format + "ShapedScale", fused_node.name()));
|
||||
reshaped_scale.set_op("Reshape");
|
||||
reshaped_scale.set_device(fused_node.device());
|
||||
*reshaped_scale.add_input() = scale;
|
||||
@ -1473,7 +1485,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
|
||||
|
||||
NodeDef reshaped_offset;
|
||||
reshaped_offset.set_name(
|
||||
AddPrefixToNodeName("NCHWShapedOffset", fused_node.name()));
|
||||
AddPrefixToNodeName(x_format + "ShapedOffset", fused_node.name()));
|
||||
reshaped_offset.set_op("Reshape");
|
||||
reshaped_offset.set_device(fused_node.device());
|
||||
*reshaped_offset.add_input() = offset;
|
||||
@ -1486,7 +1498,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
|
||||
|
||||
NodeDef reshaped_mean;
|
||||
reshaped_mean.set_name(
|
||||
AddPrefixToNodeName("NCHWShapedMean", fused_node.name()));
|
||||
AddPrefixToNodeName(x_format + "ShapedMean", fused_node.name()));
|
||||
reshaped_mean.set_op("Reshape");
|
||||
reshaped_mean.set_device(fused_node.device());
|
||||
*reshaped_mean.add_input() = mean;
|
||||
@ -1499,7 +1511,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
|
||||
|
||||
NodeDef reshaped_variance;
|
||||
reshaped_variance.set_name(
|
||||
AddPrefixToNodeName("NCHWShapedVariance", fused_node.name()));
|
||||
AddPrefixToNodeName(x_format + "ShapedVariance", fused_node.name()));
|
||||
reshaped_variance.set_op("Reshape");
|
||||
reshaped_variance.set_device(fused_node.device());
|
||||
*reshaped_variance.add_input() = variance;
|
||||
|
@ -104,6 +104,37 @@ TF_CALL_GPU_ALL_TYPES(REGISTER);
|
||||
|
||||
#undef REGISTER
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
// Required by MSVC non-release build to ensure the compiler sees all the
|
||||
// template expansions that are needed.
|
||||
#define FORCE_CONCAT(TYPE) \
|
||||
template <> \
|
||||
void ConcatGPU<TYPE>( \
|
||||
OpKernelContext * c, \
|
||||
const std::vector< \
|
||||
std::unique_ptr<typename TTypes<TYPE, 2>::ConstMatrix>>& \
|
||||
inputs_flat, \
|
||||
Tensor* output, typename TTypes<TYPE, 2>::Tensor* output_flat) { \
|
||||
LOG(FATAL) << "Should not be called"; \
|
||||
}
|
||||
|
||||
FORCE_CONCAT(tensorflow::Variant)
|
||||
FORCE_CONCAT(tensorflow::ResourceHandle)
|
||||
FORCE_CONCAT(unsigned short)
|
||||
FORCE_CONCAT(signed char)
|
||||
FORCE_CONCAT(tensorflow::tstring)
|
||||
FORCE_CONCAT(Eigen::QUInt8)
|
||||
FORCE_CONCAT(Eigen::QInt8)
|
||||
FORCE_CONCAT(Eigen::QUInt16)
|
||||
FORCE_CONCAT(Eigen::QInt16)
|
||||
FORCE_CONCAT(Eigen::QInt32)
|
||||
FORCE_CONCAT(unsigned int)
|
||||
FORCE_CONCAT(unsigned __int64)
|
||||
|
||||
#undef FORCE_CONCAT
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
@ -21,6 +21,13 @@ namespace tensorflow {
|
||||
namespace functor {
|
||||
DEFINE_UNARY1(conj, complex64);
|
||||
DEFINE_UNARY1(conj, complex128);
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
// Non-release build with MSVC needs these symbols.
|
||||
DEFINE_UNARY1(conj, float);
|
||||
DEFINE_UNARY1(conj, double);
|
||||
#endif
|
||||
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -155,14 +155,17 @@ class WindowDatasetOp::Dataset : public DatasetBase {
|
||||
std::vector<std::vector<Tensor>> window_elements;
|
||||
Status status = Status::OK();
|
||||
{
|
||||
const size_t target_size = TargetBufferSize(window_size, window_stride);
|
||||
|
||||
mutex_lock l(mu_);
|
||||
if (!input_impl_ && buffer_.empty()) {
|
||||
if (!input_impl_ &&
|
||||
(buffer_.empty() ||
|
||||
(dataset()->drop_remainder_ && buffer_.size() < target_size))) {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Add elements to the buffer.
|
||||
size_t target_size = TargetBufferSize(window_size, window_stride);
|
||||
if (input_impl_) {
|
||||
*end_of_sequence = false;
|
||||
for (size_t i = buffer_.size(); i < target_size && !*end_of_sequence;
|
||||
|
@ -71,6 +71,27 @@ TF_CALL_int8(DEFINE_GPU_KERNELS);
|
||||
TF_CALL_uint32(DEFINE_GPU_KERNELS);
|
||||
#undef DEFINE_GPU_KERNELS
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
|
||||
template <>
|
||||
struct functor::DenseUpdate<GPUDevice, tensorflow::Variant, ASSIGN> {
|
||||
void operator()(const GPUDevice& d,
|
||||
typename TTypes<tensorflow::Variant>::Flat params,
|
||||
typename TTypes<tensorflow::Variant>::ConstFlat update) {
|
||||
LOG(FATAL) << "Not handling type tensorflow::Variant";
|
||||
}
|
||||
};
|
||||
|
||||
// The function is required to force above template specialization. Without it
|
||||
// msvc compiler doesn't include the functor in the object file
|
||||
void _force_instantiation(
|
||||
const GPUDevice& d, typename TTypes<tensorflow::Variant>::Flat params,
|
||||
typename TTypes<tensorflow::Variant>::ConstFlat update) {
|
||||
functor::DenseUpdate<GPUDevice, tensorflow::Variant, ASSIGN> x;
|
||||
x(d, params, update);
|
||||
}
|
||||
#endif // _MSC_VER
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
@ -22,6 +22,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
@ -251,6 +255,62 @@ template struct functor::DepthToSpaceOpFunctor<GPUDevice, Eigen::half,
|
||||
// NCHW_VECT_C with 4 x qint8 can be treated as NCHW int32.
|
||||
template struct functor::DepthToSpaceOpFunctor<GPUDevice, int32, FORMAT_NCHW>;
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#define FORCE_DEPTH(TYPE, NAME, NUM, DEVICE) \
|
||||
template <> \
|
||||
struct functor::DepthToSpaceOpFunctor<DEVICE, TYPE, NUM> { \
|
||||
void operator()(const DEVICE& d, \
|
||||
typename TTypes<TYPE, 4>::ConstTensor input, \
|
||||
int block_size, typename TTypes<TYPE, 4>::Tensor output) { \
|
||||
LOG(FATAL) << "Should not be called."; \
|
||||
} \
|
||||
void operator()(const DEVICE& d, \
|
||||
typename TTypes<TYPE, 5>::ConstTensor input, \
|
||||
int block_size, typename TTypes<TYPE, 5>::Tensor output) { \
|
||||
LOG(FATAL) << "Should not be called."; \
|
||||
} \
|
||||
}; \
|
||||
void _force_DepthToSpaceOpFunctor##NAME( \
|
||||
const DEVICE& d, typename TTypes<TYPE, 4>::ConstTensor input, \
|
||||
int block_size, typename TTypes<TYPE, 4>::Tensor output) { \
|
||||
functor::DepthToSpaceOpFunctor<DEVICE, TYPE, NUM> op; \
|
||||
op(d, input, block_size, output); \
|
||||
} \
|
||||
void _force_DepthToSpaceOpFunctor##NAME##_2( \
|
||||
const DEVICE& d, typename TTypes<TYPE, 5>::ConstTensor input, \
|
||||
int block_size, typename TTypes<TYPE, 5>::Tensor output) { \
|
||||
functor::DepthToSpaceOpFunctor<DEVICE, TYPE, NUM> op; \
|
||||
op(d, input, block_size, output); \
|
||||
}
|
||||
|
||||
FORCE_DEPTH(__int64, int64, FORMAT_NCHW, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH(unsigned __int64, uint64, FORMAT_NCHW, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH(unsigned int, uint, FORMAT_NCHW, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH(int, int, FORMAT_NCHW, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH(unsigned short, ushort, FORMAT_NCHW, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH(short, short, FORMAT_NCHW, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH(unsigned char, uchar, FORMAT_NCHW, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH(signed char, char, FORMAT_NCHW, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH(bfloat16, bfloat16, FORMAT_NCHW, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH(double, double, FORMAT_NCHW, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH(complex64, complex64, FORMAT_NCHW, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH(complex128, complex128, FORMAT_NCHW, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH(bool, bool, FORMAT_NCHW, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH(tensorflow::tstring, tstring, FORMAT_NCHW, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH(tensorflow::ResourceHandle, ResourceHandle, FORMAT_NCHW,
|
||||
Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH(tensorflow::Variant, variant, FORMAT_NCHW, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH(Eigen::QInt8, qint8, FORMAT_NCHW, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH(Eigen::QInt8, qint8_2, FORMAT_NHWC, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH(Eigen::half, half, FORMAT_NCHW, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH(float, float, FORMAT_NCHW, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH(Eigen::QInt8, qint8, FORMAT_NCHW, GPUDevice)
|
||||
FORCE_DEPTH(Eigen::QInt8, qint8_2, FORMAT_NHWC, GPUDevice)
|
||||
|
||||
#undef FORCE_DEPTH
|
||||
|
||||
#endif
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
@ -131,8 +131,8 @@ SpatialMaxPooling(const Input& input, DenseIndex patchRows,
          .extract_image_patches(
              patchRows, patchCols, strideRows, strideCols, in_strideRows,
              in_strideCols, padding_type,
              -Eigen::NumTraits<typename internal::remove_const<
                  typename internal::traits<Input>::Scalar>::type>::highest())
              Eigen::NumTraits<typename internal::remove_const<
                  typename internal::traits<Input>::Scalar>::type>::lowest())
          .maximum(reduction_dims)
          .reshape(post_reduce_dims);
}

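Not part of the commit: a quick way to see why this hunk pads max pooling with NumTraits::lowest() instead of -NumTraits::highest(). For IEEE floating-point scalars the two values coincide, but for signed integer scalars lowest() is one step lower than -highest(), so lowest() is the correct identity for a max reduction. The sketch below uses std::numeric_limits, which Eigen's NumTraits mirrors for builtin types (an assumption worth checking against the Eigen release in use).

#include <cstdio>
#include <limits>

int main() {
  // For float the change is behavior-preserving: lowest() == -max().
  std::printf("float lowest = %g, -max = %g\n",
              std::numeric_limits<float>::lowest(),
              -std::numeric_limits<float>::max());
  // For int, lowest() is -2147483648 while -max() is only -2147483647,
  // so -max() is not the true minimum of the type.
  std::printf("int   lowest = %d, -max = %d\n",
              std::numeric_limits<int>::lowest(),
              -std::numeric_limits<int>::max());
  return 0;
}
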
@ -1241,15 +1241,15 @@ class FusedBatchNormOpBase : public OpKernel {
|
||||
// If use_reserved_space is false, we don't have 5th output.
|
||||
virtual void ComputeWithReservedSpace(OpKernelContext* context,
|
||||
bool use_reserved_space) {
|
||||
const Tensor& x = context->input(0);
|
||||
Tensor x = context->input(0);
|
||||
const Tensor& scale = context->input(1);
|
||||
const Tensor& offset = context->input(2);
|
||||
const Tensor& estimated_mean = context->input(3);
|
||||
const Tensor& estimated_variance = context->input(4);
|
||||
const Tensor* side_input = has_side_input_ ? &context->input(5) : nullptr;
|
||||
|
||||
OP_REQUIRES(context, x.dims() == 4,
|
||||
errors::InvalidArgument("input must be 4-dimensional",
|
||||
OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5,
|
||||
errors::InvalidArgument("input must be 4 or 5-dimensional",
|
||||
x.shape().DebugString()));
|
||||
OP_REQUIRES(context, scale.dims() == 1,
|
||||
errors::InvalidArgument("scale must be 1-dimensional",
|
||||
@ -1264,6 +1264,21 @@ class FusedBatchNormOpBase : public OpKernel {
|
||||
context, estimated_variance.dims() == 1,
|
||||
errors::InvalidArgument("estimated_variance must be 1-dimensional",
|
||||
estimated_variance.shape().DebugString()));
|
||||
bool use_reshape = (x.dims() == 5);
|
||||
auto x_shape = x.shape();
|
||||
TensorShape dest_shape;
|
||||
if (use_reshape) {
|
||||
const int64 in_batch = GetTensorDim(x, tensor_format_, 'N');
|
||||
int64 in_planes = GetTensorDim(x, tensor_format_, '0');
|
||||
int64 in_rows = GetTensorDim(x, tensor_format_, '1');
|
||||
int64 in_cols = GetTensorDim(x, tensor_format_, '2');
|
||||
const int64 in_depth = GetTensorDim(x, tensor_format_, 'C');
|
||||
dest_shape = ShapeFromFormat(tensor_format_, in_batch,
|
||||
{{in_planes, in_rows * in_cols}}, in_depth);
|
||||
OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
|
||||
errors::InvalidArgument("Error during tensor copy."));
|
||||
}
|
||||
|
||||
if (has_side_input_) {
|
||||
OP_REQUIRES(context, side_input->shape() == x.shape(),
|
||||
errors::InvalidArgument(
|
||||
@ -1282,8 +1297,10 @@ class FusedBatchNormOpBase : public OpKernel {
|
||||
}
|
||||
|
||||
Tensor* y = nullptr;
|
||||
auto alloc_shape = use_reshape ? dest_shape : x_shape;
|
||||
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
|
||||
{0}, 0, x.shape(), &y));
|
||||
{0}, 0, alloc_shape, &y));
|
||||
|
||||
Tensor* batch_mean = nullptr;
|
||||
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
|
||||
{3}, 1, scale.shape(), &batch_mean));
|
||||
@ -1310,6 +1327,10 @@ class FusedBatchNormOpBase : public OpKernel {
|
||||
batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
|
||||
tensor_format_, use_reserved_space);
|
||||
}
|
||||
if (use_reshape) {
|
||||
OP_REQUIRES(context, y->CopyFrom(*y, x_shape),
|
||||
errors::InvalidArgument("Error during tensor copy."));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
@ -1375,8 +1396,8 @@ class FusedBatchNormGradOpBase : public OpKernel {
|
||||
|
||||
virtual void ComputeWithReservedSpace(OpKernelContext* context,
|
||||
bool use_reserved_space) {
|
||||
const Tensor& y_backprop = context->input(0);
|
||||
const Tensor& x = context->input(1);
|
||||
Tensor y_backprop = context->input(0);
|
||||
Tensor x = context->input(1);
|
||||
const Tensor& scale = context->input(2);
|
||||
// When is_training=True, batch mean and variance/inverted variance are
|
||||
// saved in the forward pass to be reused here. When is_training=False,
|
||||
@ -1387,11 +1408,11 @@ class FusedBatchNormGradOpBase : public OpKernel {
|
||||
// saves inverted variance.
|
||||
const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);
|
||||
|
||||
OP_REQUIRES(context, y_backprop.dims() == 4,
|
||||
errors::InvalidArgument("input must be 4-dimensional",
|
||||
OP_REQUIRES(context, y_backprop.dims() == 4 or y_backprop.dims() == 5,
|
||||
errors::InvalidArgument("input must be 4 or 5-dimensional",
|
||||
y_backprop.shape().DebugString()));
|
||||
OP_REQUIRES(context, x.dims() == 4,
|
||||
errors::InvalidArgument("input must be 4-dimensional",
|
||||
OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5,
|
||||
errors::InvalidArgument("input must be 4 or 5-dimensional",
|
||||
x.shape().DebugString()));
|
||||
OP_REQUIRES(context, scale.dims() == 1,
|
||||
errors::InvalidArgument("scale must be 1-dimensional",
|
||||
@ -1404,10 +1425,27 @@ class FusedBatchNormGradOpBase : public OpKernel {
|
||||
errors::InvalidArgument(
|
||||
"saved variance must be 1-dimensional",
|
||||
saved_maybe_inv_var_or_pop_var.shape().DebugString()));
|
||||
bool use_reshape = (x.dims() == 5);
|
||||
auto x_shape = x.shape();
|
||||
TensorShape dest_shape;
|
||||
if (use_reshape) {
|
||||
const int64 in_batch = GetTensorDim(x, tensor_format_, 'N');
|
||||
int64 in_planes = GetTensorDim(x, tensor_format_, '0');
|
||||
int64 in_rows = GetTensorDim(x, tensor_format_, '1');
|
||||
int64 in_cols = GetTensorDim(x, tensor_format_, '2');
|
||||
const int64 in_depth = GetTensorDim(x, tensor_format_, 'C');
|
||||
dest_shape = ShapeFromFormat(tensor_format_, in_batch,
|
||||
{{in_planes, in_rows * in_cols}}, in_depth);
|
||||
OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
|
||||
errors::InvalidArgument("Error during tensor copy."));
|
||||
OP_REQUIRES(context, y_backprop.CopyFrom(y_backprop, dest_shape),
|
||||
errors::InvalidArgument("Error during tensor copy."));
|
||||
}
|
||||
|
||||
Tensor* x_backprop = nullptr;
|
||||
auto alloc_shape = use_reshape ? dest_shape : x_shape;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, x.shape(), &x_backprop));
|
||||
context->allocate_output(0, alloc_shape, &x_backprop));
|
||||
|
||||
const TensorShape& scale_offset_shape = scale.shape();
|
||||
Tensor* scale_backprop = nullptr;
|
||||
@ -1441,15 +1479,20 @@ class FusedBatchNormGradOpBase : public OpKernel {
|
||||
offset_backprop, use_reserved_space, tensor_format_);
|
||||
} else {
|
||||
// Necessary layout conversion is currently done in python.
|
||||
CHECK(tensor_format_ == FORMAT_NHWC)
|
||||
<< "The implementation of FusedBatchNormGrad with is_training=False "
|
||||
"only support "
|
||||
<< "NHWC tensor format for now.";
|
||||
OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC,
|
||||
errors::InvalidArgument(
|
||||
"The implementation of "
|
||||
"FusedBatchNormGrad with is_training=False only support "
|
||||
"NHWC tensor format for now."));
|
||||
functor::FusedBatchNormFreezeGrad<Device, T, U>()(
|
||||
context, y_backprop, x, scale, saved_mean_or_pop_mean,
|
||||
saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
|
||||
offset_backprop);
|
||||
}
|
||||
if (use_reshape) {
|
||||
OP_REQUIRES(context, x_backprop->CopyFrom(*x_backprop, x_shape),
|
||||
errors::InvalidArgument("Error during tensor copy."));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -530,6 +530,11 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DEFINE_GRAD_GPU_SPEC);
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
// Required for MSVC debug build
|
||||
TF_CALL_half(DEFINE_GRAD_GPU_SPEC)
|
||||
#endif
|
||||
|
||||
#undef DEFINE_GPU_SPEC
|
||||
#undef DEFINE_GRAD_GPU_SPEC
|
||||
|
||||
|
@ -296,6 +296,9 @@ def _gen_unranked_kernel_fatbin_impl(ctx):
|
||||
archs_trimmed.append(arch[3:])
|
||||
arch_flag = ",".join(archs_trimmed)
|
||||
|
||||
# TODO(b/169066682): Generate Fatbin when lowering GPU module.
|
||||
arch_flag = "75"
|
||||
|
||||
filename = "%s.a" % (name)
|
||||
gpu_bin = ctx.outputs.output
|
||||
ctx.actions.run(
|
||||
|
@ -43,7 +43,8 @@ namespace tensorflow {
|
||||
// We have to be able to detect and handle overflows in int32, so this function
|
||||
// uses doubles and int64's to make sure we have enough room.
|
||||
template <class T>
|
||||
int64 FloatToQuantizedUnclamped(float input, float range_min, float range_max) {
|
||||
inline int64 FloatToQuantizedUnclamped(float input, float range_min,
|
||||
float range_max) {
|
||||
const int64 lowest_quantized =
|
||||
static_cast<double>(Eigen::NumTraits<T>::lowest());
|
||||
if (range_min == range_max) {
|
||||
@ -60,6 +61,12 @@ int64 FloatToQuantizedUnclamped(float input, float range_min, float range_max) {
|
||||
return quantized;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline int64 FloatToQuantizedUnclamped<float>(float input, float range_min,
|
||||
float range_max) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
// This converts the float into the final quantized type, clamping/saturating
|
||||
// any over or underflows.
|
||||
template <class T>
|
||||
|
@ -22,6 +22,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
@ -252,6 +256,70 @@ template struct functor::SpaceToDepthOpFunctor<GPUDevice, uint8, FORMAT_NHWC>;
|
||||
// NCHW_VECT_C with 4 x qint8 can be treated as NCHW int32.
|
||||
template struct functor::SpaceToDepthOpFunctor<GPUDevice, int32, FORMAT_NCHW>;
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#define FORCE_DEPTH(TYPE, NAME, NUM, DEVICE) \
|
||||
template <> \
|
||||
struct functor::SpaceToDepthOpFunctor<DEVICE, TYPE, NUM> { \
|
||||
void operator()(const DEVICE& d, \
|
||||
typename TTypes<TYPE, 4>::ConstTensor input, \
|
||||
int block_size, typename TTypes<TYPE, 4>::Tensor output) { \
|
||||
LOG(FATAL) << "Should not be called."; \
|
||||
} \
|
||||
}; \
|
||||
void _force_SpaceToDepthOpFunctor##NAME( \
|
||||
const DEVICE& d, typename TTypes<TYPE, 4>::ConstTensor input, \
|
||||
int block_size, typename TTypes<TYPE, 4>::Tensor output) { \
|
||||
functor::SpaceToDepthOpFunctor<DEVICE, TYPE, NUM> op; \
|
||||
op(d, input, block_size, output); \
|
||||
}
|
||||
|
||||
#define FORCE_DEPTH2(TYPE, NAME, DEVICE) \
|
||||
FORCE_DEPTH(TYPE, NAME, FORMAT_NCHW, DEVICE) \
|
||||
FORCE_DEPTH(TYPE, NAME##_2, FORMAT_NHWC, DEVICE)
|
||||
|
||||
FORCE_DEPTH2(__int64, int64, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH2(unsigned __int64, uint64, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH2(unsigned int, uint, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH2(unsigned short, ushort, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH2(short, short, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH2(signed char, char, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH2(unsigned char, char, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH2(bfloat16, bfloat16, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH2(double, double, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH2(complex64, complex64, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH2(complex128, complex128, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH2(bool, bool, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH2(tensorflow::tstring, tstring, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH2(tensorflow::ResourceHandle, ResourceHandle,
|
||||
Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH2(tensorflow::Variant, variant, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH2(Eigen::QInt8, qint8, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH2(Eigen::half, half, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH2(float, float, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH2(int, int, Eigen::ThreadPoolDevice)
|
||||
FORCE_DEPTH2(Eigen::QInt8, qint8gpu, GPUDevice)
|
||||
|
||||
// Special case for int, FORMAT_NHWC
|
||||
template <>
|
||||
struct functor::SpaceToDepthOpFunctor<GPUDevice, int, FORMAT_NHWC> {
|
||||
void operator()(const GPUDevice& d,
|
||||
typename TTypes<int, 4>::ConstTensor input, int block_size,
|
||||
typename TTypes<int, 4>::Tensor output) {
|
||||
LOG(FATAL) << "Should not be called.";
|
||||
}
|
||||
};
|
||||
void _force_SpaceToDepthOpFunctor_int(
|
||||
const GPUDevice& d, typename TTypes<int, 4>::ConstTensor input,
|
||||
int block_size, typename TTypes<int, 4>::Tensor output) {
|
||||
functor::SpaceToDepthOpFunctor<GPUDevice, int, FORMAT_NHWC> op;
|
||||
op(d, input, block_size, output);
|
||||
}
|
||||
|
||||
#undef FORCE_DEPTH
|
||||
#undef FORCE_DEPTH2
|
||||
|
||||
#endif
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
@ -58,3 +58,77 @@ op {
|
||||
has_minimum: true
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "SnapshotDatasetV2"
|
||||
input_arg {
|
||||
name: "input_dataset"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
input_arg {
|
||||
name: "path"
|
||||
type: DT_STRING
|
||||
}
|
||||
input_arg {
|
||||
name: "reader_func_other_args"
|
||||
type_list_attr: "Treader_func_args"
|
||||
}
|
||||
input_arg {
|
||||
name: "shard_func_other_args"
|
||||
type_list_attr: "Tshard_func_args"
|
||||
}
|
||||
output_arg {
|
||||
name: "handle"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "output_shapes"
|
||||
type: "list(shape)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "compression"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "reader_prefix"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "writer_prefix"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "reader_func"
|
||||
type: "func"
|
||||
}
|
||||
attr {
|
||||
name: "shard_func"
|
||||
type: "func"
|
||||
}
|
||||
attr {
|
||||
name: "Treader_func_args"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
}
|
||||
attr {
|
||||
name: "Tshard_func_args"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
}
|
||||
}
|
||||
|
@ -221,7 +221,7 @@ REGISTER_OP("FusedBatchNormV3")
    .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr("exponential_avg_factor: float = 1.0")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetConvnetDataFormat2D3DAttrString())
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormV3Shape);

@ -308,7 +308,7 @@ REGISTER_OP("FusedBatchNormGradV3")
    .Attr("T: {half, bfloat16, float}")
    .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetConvnetDataFormat2D3DAttrString())
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormGradShape);
// --------------------------------------------------------------------------

@ -44435,6 +44435,20 @@ op {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "reader_prefix"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "writer_prefix"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "reader_func"
|
||||
type: "func"
|
||||
|
@ -1,7 +1,7 @@
# Platform-specific build configurations.

load("@com_google_protobuf//:protobuf.bzl", "proto_gen")
load("//tensorflow:tensorflow.bzl", "clean_dep", "if_not_windows", "if_tpu")
load("//tensorflow:tensorflow.bzl", "clean_dep", "if_libtpu", "if_not_windows")
load("//tensorflow/core/platform:build_config_root.bzl", "if_static")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
@ -814,4 +814,4 @@ def if_llvm_system_z_available(then, otherwise = []):
    })

def tf_tpu_dependencies():
    return if_tpu(["//tensorflow/core/tpu/kernels"])
    return if_libtpu(["//tensorflow/core/tpu/kernels"])

@ -406,7 +406,6 @@ tf_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -62,7 +62,6 @@ tf_cc_test(
|
||||
"//tensorflow/core/profiler/utils:xplane_visitor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||
|
@ -21,7 +21,6 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/env_time.h"
|
||||
|
@ -108,7 +108,7 @@ limitations under the License.

#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
#define TF_GRAPH_DEF_VERSION 542  // Updated: 2020/10/2
#define TF_GRAPH_DEF_VERSION 543  // Updated: 2020/10/3

// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
//

@ -5,13 +5,11 @@ load(
|
||||
"//tensorflow/core/platform:build_config.bzl",
|
||||
"tf_proto_library",
|
||||
)
|
||||
load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") # buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "tf_kernel_library") # buildifier: disable=same-origin-load
|
||||
|
||||
# Config setting to enable go/libtpu support.
|
||||
WITH_TPU_SUPPORT = "//tensorflow:with_tpu_support"
|
||||
|
||||
DEFAULT = "//conditions:default"
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
@ -44,10 +42,10 @@ cc_library(
|
||||
name = "tpu_compile_op_common",
|
||||
srcs = ["tpu_compile_op_common.cc"],
|
||||
hdrs = ["tpu_compile_op_common.h"],
|
||||
deps = select({
|
||||
WITH_TPU_SUPPORT: [":tpu_compilation_metrics"],
|
||||
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_metrics"],
|
||||
}) + [
|
||||
deps = if_libtpu(
|
||||
[":tpu_compilation_metrics"],
|
||||
["//tensorflow/core/tpu/kernels:tpu_compilation_metrics"],
|
||||
) + [
|
||||
":tpu_compilation_cache_entry_unloader",
|
||||
":tpu_compilation_cache_interface",
|
||||
":tpu_compilation_metrics_hdrs",
|
||||
@ -97,14 +95,10 @@ tf_kernel_library(
|
||||
name = "tpu_configuration_ops",
|
||||
srcs = ["tpu_configuration_ops.cc"],
|
||||
hdrs = ["tpu_configuration_ops.h"],
|
||||
copts = select({
|
||||
WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
|
||||
DEFAULT: [],
|
||||
}),
|
||||
deps = select({
|
||||
WITH_TPU_SUPPORT: [":tpu_util"],
|
||||
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_util"],
|
||||
}) + [
|
||||
deps = if_libtpu(
|
||||
[":tpu_util"],
|
||||
["//tensorflow/core/tpu/kernels:tpu_util"],
|
||||
) + [
|
||||
":tpu_compilation_cache_factory",
|
||||
":tpu_compilation_cache_interface",
|
||||
":tpu_compilation_cache_local_lookup",
|
||||
@ -346,10 +340,10 @@ cc_library(
|
||||
name = "tpu_compilation_cache_interface",
|
||||
srcs = ["tpu_compilation_cache_interface.cc"],
|
||||
hdrs = ["tpu_compilation_cache_interface.h"],
|
||||
deps = select({
|
||||
WITH_TPU_SUPPORT: [":tpu_compilation_metrics"],
|
||||
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_metrics"],
|
||||
}) + [
|
||||
deps = if_libtpu(
|
||||
[":tpu_compilation_metrics"],
|
||||
["//tensorflow/core/tpu/kernels:tpu_compilation_metrics"],
|
||||
) + [
|
||||
":compiled_subgraph",
|
||||
":tpu_compilation_cache_common_proto_cc",
|
||||
":tpu_compilation_cache_entry",
|
||||
@ -424,10 +418,7 @@ cc_library(
|
||||
cc_library(
|
||||
name = "tpu_compilation_metrics",
|
||||
srcs = ["tpu_compilation_metrics.cc"],
|
||||
copts = select({
|
||||
WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
|
||||
DEFAULT: [],
|
||||
}),
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
":tpu_compilation_metrics_hdrs",
|
||||
],
|
||||
@ -529,14 +520,11 @@ cc_library(
|
||||
cc_library(
|
||||
name = "tpu_compilation_cache_rpc_support_hdrs",
|
||||
hdrs = ["tpu_compilation_cache_rpc_support.h"],
|
||||
copts = select({
|
||||
WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
|
||||
DEFAULT: [],
|
||||
}),
|
||||
deps = select({
|
||||
WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"], # build_cleaner: keep
|
||||
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto"], # build_cleaner: keep
|
||||
}) + [
|
||||
copts = tf_copts(),
|
||||
deps = if_libtpu(
|
||||
[":tpu_compilation_cache_proto_cc"],
|
||||
["//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto"],
|
||||
) + [
|
||||
":tpu_compilation_cache_entry",
|
||||
":tpu_compilation_cache_interface",
|
||||
":tpu_compilation_cache_lookup",
|
||||
@ -550,10 +538,7 @@ cc_library(
|
||||
cc_library(
|
||||
name = "tpu_compilation_cache_rpc_support",
|
||||
srcs = ["tpu_compilation_cache_rpc_support.cc"],
|
||||
copts = select({
|
||||
WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
|
||||
DEFAULT: [],
|
||||
}),
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
":tpu_compilation_cache_common_proto_cc",
|
||||
":tpu_compilation_cache_proto_cc",
|
||||
@ -572,14 +557,11 @@ cc_library(
|
||||
name = "tpu_compilation_cache_rpc_lookup",
|
||||
srcs = ["tpu_compilation_cache_rpc_lookup.cc"],
|
||||
hdrs = ["tpu_compilation_cache_rpc_lookup.h"],
|
||||
copts = select({
|
||||
WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
|
||||
DEFAULT: [],
|
||||
}),
|
||||
deps = select({
|
||||
WITH_TPU_SUPPORT: [":tpu_compilation_cache_rpc_support"],
|
||||
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support"],
|
||||
}) + [
|
||||
copts = tf_copts(),
|
||||
deps = if_libtpu(
|
||||
[":tpu_compilation_cache_rpc_support"],
|
||||
["//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support"],
|
||||
) + [
|
||||
":tpu_compilation_cache_grpc",
|
||||
":tpu_compilation_cache_interface",
|
||||
":tpu_compilation_cache_lookup",
|
||||
@ -617,14 +599,11 @@ cc_library(
|
||||
name = "tpu_compilation_cache_grpc",
|
||||
srcs = ["tpu_compilation_cache_grpc.cc"],
|
||||
hdrs = ["tpu_compilation_cache_grpc.h"],
|
||||
copts = select({
|
||||
WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
|
||||
DEFAULT: [],
|
||||
}),
|
||||
deps = select({
|
||||
WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"],
|
||||
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto"],
|
||||
}) + [
|
||||
copts = tf_copts(),
|
||||
deps = if_libtpu(
|
||||
[":tpu_compilation_cache_proto_cc"],
|
||||
["//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto"],
|
||||
) + [
|
||||
":tpu_compilation_cache_common_proto_cc",
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
@ -634,20 +613,17 @@ cc_library(
|
||||
name = "tpu_compilation_cache_service",
|
||||
srcs = ["tpu_compilation_cache_service.cc"],
|
||||
hdrs = ["tpu_compilation_cache_service.h"],
|
||||
copts = select({
|
||||
WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
|
||||
DEFAULT: [],
|
||||
}),
|
||||
deps = select({
|
||||
WITH_TPU_SUPPORT: [
|
||||
":tpu_compilation_cache_rpc_support", # build_cleaner: keep
|
||||
":tpu_compilation_cache_proto_cc", # build_cleaner: keep
|
||||
copts = tf_copts(),
|
||||
deps = if_libtpu(
|
||||
[
|
||||
":tpu_compilation_cache_rpc_support",
|
||||
":tpu_compilation_cache_proto_cc",
|
||||
],
|
||||
DEFAULT: [
|
||||
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support", # build_cleaner: keep
|
||||
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto", # build_cleaner: keep
|
||||
[
|
||||
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support",
|
||||
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto",
|
||||
],
|
||||
}) + [
|
||||
) + [
|
||||
":tpu_compilation_cache_common_proto_cc",
|
||||
":tpu_compilation_cache_grpc",
|
||||
":tpu_compilation_cache_interface",
|
||||
@ -704,10 +680,7 @@ cc_library(
|
||||
name = "tpu_compile_op_impl",
|
||||
srcs = ["tpu_compile_op_impl.cc"],
|
||||
hdrs = ["tpu_compile_op_impl.h"],
|
||||
copts = select({
|
||||
WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
|
||||
DEFAULT: [],
|
||||
}),
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
":tpu_compilation_cache_key",
|
||||
":tpu_compile_c_api_hdrs",
|
||||
@ -952,14 +925,11 @@ cc_library(
|
||||
name = "tpu_pod_state",
|
||||
srcs = ["tpu_pod_state.cc"],
|
||||
hdrs = ["tpu_pod_state.h"],
|
||||
copts = select({
|
||||
WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
|
||||
DEFAULT: [],
|
||||
}),
|
||||
deps = select({
|
||||
WITH_TPU_SUPPORT: [":tpu_util"],
|
||||
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_util"],
|
||||
}) + [
|
||||
copts = tf_copts(),
|
||||
deps = if_libtpu(
|
||||
[":tpu_util"],
|
||||
["//tensorflow/core/tpu/kernels:tpu_util"],
|
||||
) + [
|
||||
":tpu_compilation_cache_service",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
|
@ -30,11 +30,11 @@ namespace tensorflow {
namespace tpu {

static const char* grpcTpuCompilationCacheService_method_names[] = {
#if defined(LIBTFTPU)
#if defined(LIBTPU_ON_GCE)
    "/tensorflow.tpu.TpuCompilationCacheServiceExternal/GetTpuProgram",
#else   // LIBTFTPU
#else   // LIBTPU_ON_GCE
    "/tensorflow.tpu.TpuCompilationCacheService/GetTpuProgram",
#endif  // LIBTFTPU
#endif  // LIBTPU_ON_GCE
};

std::unique_ptr<grpc::TpuCompilationCacheService::Stub>

@ -35,7 +35,7 @@ limitations under the License.
|
||||
|
||||
#include <functional>
|
||||
|
||||
#if defined(LIBTFTPU)
|
||||
#if defined(LIBTPU_ON_GCE)
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
|
||||
#else
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" // copybara"
|
||||
@ -48,7 +48,7 @@ namespace grpc {
|
||||
class TpuCompilationCacheService final {
|
||||
public:
|
||||
using RequestType = ::tensorflow::tpu::GetTpuProgramRequest;
|
||||
#if defined(LIBTFTPU)
|
||||
#if defined(LIBTPU_ON_GCE)
|
||||
using ResponseType = ::tensorflow::tpu::GetTpuProgramResponseExternal;
|
||||
#else
|
||||
using ResponseType = ::tensorflow::tpu::GetTpuProgramResponse;
|
||||
@ -59,7 +59,7 @@ class TpuCompilationCacheService final {
|
||||
enum class MethodId { kGetTpuProgram = 0 };
|
||||
|
||||
static constexpr char const* service_full_name() {
|
||||
#if defined(LIBTFTPU)
|
||||
#if defined(LIBTPU_ON_GCE)
|
||||
return "tensorflow.tpu.TpuCompilationCacheServiceExternal";
|
||||
#else
|
||||
return "tensorflow.tpu.TpuCompilationCacheService";
|
||||
|
@ -25,7 +25,7 @@ namespace tensorflow {
|
||||
namespace tpu {
|
||||
namespace {
|
||||
|
||||
#if defined(LIBTFTPU)
|
||||
#if defined(LIBTPU_ON_GCE)
|
||||
using ResponseType = GetTpuProgramResponseExternal;
|
||||
#else
|
||||
using ResponseType = GetTpuProgramResponse;
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#if defined(LIBTFTPU)
|
||||
#if defined(LIBTPU_ON_GCE)
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
|
||||
#endif
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
|
||||
@ -30,7 +30,7 @@ std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials() {
|
||||
return ::grpc::InsecureChannelCredentials(); // NOLINT
|
||||
}
|
||||
|
||||
#if defined(LIBTFTPU)
|
||||
#if defined(LIBTPU_ON_GCE)
|
||||
template <>
|
||||
Status DeserializeRpcResponseToCacheEntry<GetTpuProgramResponseExternal>(
|
||||
absl::string_view local_proto_key, GetTpuProgramResponseExternal* response,
|
||||
@ -156,6 +156,6 @@ xla::StatusOr<std::vector<::grpc::Slice>> SerializeCacheEntryToBufferSlices(
|
||||
|
||||
return std::vector<::grpc::Slice>{::grpc::Slice(encoded_header)};
|
||||
}
|
||||
#endif // LIBTFTPU
|
||||
#endif // LIBTPU_ON_GCE
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
@ -19,7 +19,7 @@ namespace tpu {
|
||||
|
||||
// TODO(henrytan): remove this once `TpuCompilationCache` migration to OSS is
|
||||
// completed.
|
||||
#if defined(LIBTFTPU)
|
||||
#if defined(LIBTPU_ON_GCE)
|
||||
/* static */
|
||||
void TpuCompilationMetrics::IncrementCacheLookupCount(
|
||||
bool is_cache_hit, absl::string_view session_name) {
|
||||
@ -36,7 +36,7 @@ void TpuCompilationMetrics::IncrementCompilationCount(
|
||||
absl::string_view session_name) {
|
||||
// A placeholder for tracking metrics.
|
||||
}
|
||||
#endif // LIBTFTPU
|
||||
#endif // LIBTPU_ON_GCE
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
@ -68,11 +68,11 @@ class TpuCompileOpImplFactory : public CompileOpImplFactory {
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(LIBTFTPU)
|
||||
#if defined(LIBTPU_ON_GCE)
|
||||
REGISTER_MODULE_INITIALIZER(tpu_compile_op_impl_factory, {
|
||||
VLOG(1) << "register TpuCompileOpImplFactory()";
|
||||
CompileOpImplFactory::Register(new TpuCompileOpImplFactory());
|
||||
});
|
||||
#endif // LIBTFTPU
|
||||
#endif // LIBTPU_ON_GCE
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/tpu/tpu_api.h"
|
||||
|
||||
#if defined(LIBTFTPU)
|
||||
#if defined(LIBTPU_ON_GCE)
|
||||
#include "tensorflow/core/tpu/kernels/tpu_util.h"
|
||||
#else
|
||||
#include "tensorflow/core/tpu/kernels/tpu_util.h" // copybara"
|
||||
@ -54,7 +54,7 @@ xla::StatusOr<std::unique_ptr<TpuCompilationCacheService>>
|
||||
ConstructCacheService(ResourceMgr* rmgr, int serving_port,
|
||||
tpu::TpuCompilationCacheInterface* compilation_cache) {
|
||||
xla::StatusOr<std::unique_ptr<::grpc::ServerBuilder>> server_builder;
|
||||
#if defined(LIBTFTPU)
|
||||
#if defined(LIBTPU_ON_GCE)
|
||||
server_builder = tpu::CreateServerBuilder(serving_port);
|
||||
#else
|
||||
server_builder = tpu::CreateServerBuilderGoogle(serving_port);
|
||||
|
@ -286,10 +286,8 @@ cc_library(
|
||||
":cl_command_queue",
|
||||
":cl_context",
|
||||
":cl_device",
|
||||
":cl_kernel",
|
||||
":precision",
|
||||
":program_cache",
|
||||
":tensor",
|
||||
":tensor_type",
|
||||
":util",
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
|
@ -18,7 +18,6 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>

#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
#include "tensorflow/lite/delegates/gpu/cl/util.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"

@ -26,59 +25,6 @@ namespace tflite {
namespace gpu {
namespace cl {
namespace {

std::string GetKernelOneLayerTextureArray() {
  return R"(

__kernel void main_function(__write_only image2d_array_t dst) {
  int X = (int)(get_global_id(0));
  int Y = (int)(get_global_id(1));

  write_imagef(dst, (int4)(X, Y, 0, 0), (float4)(2.0, 2.0, 2.0, 2.0));
}
)";
}

// Some Adreno GPUs older than the 6xx series have a bug with one-layer
// texture arrays (b/131099086): if a kernel writes something to such a
// texture, reads return zeroes instead of the actual values. The same kernel
// works correctly when the texture array has more than one layer. This code
// detects that bug.
absl::Status CheckKernelSupportOfOneLayerTextureArray(Environment* env,
                                                      bool* result) {
  // No bug on Adreno 6xx
  if (env->device().info_.adreno_info.gpu_version >= 600) {
    *result = true;
    return absl::OkStatus();
  }
  CLKernel kernel;
  RETURN_IF_ERROR(env->program_cache()->GetOrCreateCLKernel(
      GetKernelOneLayerTextureArray(), "main_function", env->context(),
      env->device(), &kernel));

  Tensor tensor;
  const BHWC shape(1, 4, 4, 4);
  RETURN_IF_ERROR(CreateTensor(
      env->context(), shape,
      {DataType::FLOAT32, TensorStorageType::TEXTURE_ARRAY, Layout::HWC},
      &tensor));
  RETURN_IF_ERROR(kernel.SetMemory(0, tensor.GetMemoryPtr()));
  RETURN_IF_ERROR(env->queue()->DispatchImplicit(kernel, {4, 4, 1}, {4, 4, 1}));
  TensorFloat32 tensor_gpu;
  tensor_gpu.shape = shape;
  tensor_gpu.data.resize(shape.DimensionsProduct());
  RETURN_IF_ERROR(tensor.ReadData(env->queue(), &tensor_gpu));

  *result = true;
  for (int i = 0; i < 64; ++i) {
    if (tensor_gpu.data[i] != 2.0) {
      *result = false;
      break;
    }
  }
  return absl::OkStatus();
}

absl::Status CreateEnvironment(Environment* result, bool shared,
                               cl_context_properties egl_context,
                               cl_context_properties egl_display) {
@ -99,16 +45,7 @@ absl::Status CreateEnvironment(Environment* result, bool shared,
  *result = Environment(std::move(gpu), std::move(context), std::move(queue),
                        std::move(profiling_queue));

  if (result->device().IsAdreno() && result->device().SupportsTextureArray()) {
    bool supports_one_layer;
    RETURN_IF_ERROR(
        CheckKernelSupportOfOneLayerTextureArray(result, &supports_one_layer));
    if (!supports_one_layer) {
      result->GetDevicePtr()->DisableOneLayerTextureArray();
    }
  }

  return absl::OkStatus();
  return result->Init();
}

}  // namespace
@ -141,10 +78,12 @@ Environment& Environment::operator=(Environment&& environment) {

absl::Status Environment::Init() {
  if (device().IsAdreno() && device().SupportsTextureArray()) {
    bool supports_one_layer;
    RETURN_IF_ERROR(
        CheckKernelSupportOfOneLayerTextureArray(this, &supports_one_layer));
    if (!supports_one_layer) {
    // Some Adreno GPUs older than the 6xx series have a bug with one-layer
    // texture arrays (b/131099086): if a kernel writes something to such a
    // texture, reads return zeroes instead of the actual values. The same
    // kernel works correctly when the texture array has more than one layer.
    if (device().info_.adreno_info.gpu_version < 600) {
      GetDevicePtr()->DisableOneLayerTextureArray();
    }
  }

@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
#include "tensorflow/lite/delegates/gpu/cl/precision.h"
#include "tensorflow/lite/delegates/gpu/cl/program_cache.h"
#include "tensorflow/lite/delegates/gpu/cl/tensor.h"
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"

12
tensorflow/lite/micro/cortex_m_gcc_generic/README.md
Normal file
@ -0,0 +1,12 @@
# Generic Cortex-Mx customizations

This customization requires a definition of where the debug log should go. The
purpose of the generic Cortex-Mx target is to generate a TFLM library file for
use in application projects outside of this repo. Because the chip HAL and the
board-specific layer are defined only in the application project, the TFLM
library cannot write the debug log anywhere on its own. Instead, the
application layer registers a callback function that writes the TFLM kernel
debug log (see the sketch below).

# Usage

See debug_log_callback.h
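The callback contract is documented in debug_log_callback.h below. As a quick illustration (not part of this commit), an application project might register its sink roughly as follows; the printf transport is an assumed stand-in, and a real board would typically forward the string to a UART, ITM, or semihosting channel:

// Illustrative sketch only: application-side registration of the TFLM debug
// log callback declared in debug_log_callback.h. The printf transport is an
// assumption; swap in whatever output channel the board provides.
#include <cstdio>

#include "tensorflow/lite/micro/cortex_m_gcc_generic/debug_log_callback.h"

namespace {
void DebugLogToStdio(const char* s) {
  printf("%s", s);  // hypothetical transport; could be a UART write instead
}
}  // namespace

int main() {
  // Register before the first DebugLog() call, i.e. before the TFLM
  // interpreter is created and invoked.
  RegisterDebugLogCallback(DebugLogToStdio);
  // ... build the TFLM interpreter and call Invoke() here ...
  return 0;
}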
@ -13,14 +13,31 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implementation of the DebugLog() function that prints to the debug logger
// on a generic Cortex-M device.

#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus

#include "tensorflow/lite/micro/debug_log.h"

#ifndef TF_LITE_STRIP_ERROR_STRINGS
#include <cstdio>
#endif
#include "tensorflow/lite/micro/cortex_m_gcc_generic/debug_log_callback.h"

extern "C" void DebugLog(const char* s) {
static DebugLogCallback debug_log_callback = nullptr;

void RegisterDebugLogCallback(void (*cb)(const char* s)) {
  debug_log_callback = cb;
}

void DebugLog(const char* s) {
#ifndef TF_LITE_STRIP_ERROR_STRINGS
  fprintf(stderr, "%s", s);
  if (debug_log_callback != nullptr) {
    debug_log_callback(s);
  }
#endif
}

#ifdef __cplusplus
}  // extern "C"
#endif  // __cplusplus
@ -0,0 +1,49 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_CORTEX_M_GCC_GENERIC_DEBUG_LOG_CALLBACK_H_
#define TENSORFLOW_LITE_MICRO_CORTEX_M_GCC_GENERIC_DEBUG_LOG_CALLBACK_H_

// The application layer must implement and register a callback before calling
// the network, in a way similar to
//
//   void debug_log_printf(const char* s)
//   {
//     printf(s);
//   }
//
//   int main(void)
//   {
//     // Register callback for printing debug log
//     RegisterDebugLogCallback(debug_log_printf);
//
//     // now call the network
//     TfLiteStatus invoke_status = interpreter->Invoke();
//   }

#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus

typedef void (*DebugLogCallback)(const char* s);

// Registers an application-specific callback for debug logging. It must be
// called before the first call to DebugLog().
void RegisterDebugLogCallback(DebugLogCallback callback);

#ifdef __cplusplus
}  // extern "C"
#endif  // __cplusplus

#endif  // TENSORFLOW_LITE_MICRO_CORTEX_M_GCC_GENERIC_DEBUG_LOG_CALLBACK_H_
@ -15,9 +15,17 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_MICRO_DEBUG_LOG_H_
#define TENSORFLOW_LITE_MICRO_DEBUG_LOG_H_

#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus

// This function should be implemented by each target platform, and provide a
// way for strings to be output to some text stream. For more information, see
// tensorflow/lite/micro/debug_log.cc.
extern "C" void DebugLog(const char* s);
void DebugLog(const char* s);

#ifdef __cplusplus
}  // extern "C"
#endif  // __cplusplus

#endif  // TENSORFLOW_LITE_MICRO_DEBUG_LOG_H_

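As the comment above notes, each target platform supplies its own definition of DebugLog(). A hedged sketch (not code from this commit) of how a hosted build could satisfy that declaration, mirroring the fprintf(stderr, ...) behaviour shown earlier in this diff; embedded ports would route the string to a UART, ITM, or semihosting channel instead:

// Hedged sketch of a hosted-platform DebugLog() implementation; not part of
// this commit. Embedded targets would replace stderr with a board-specific
// output channel.
#include <cstdio>

#include "tensorflow/lite/micro/debug_log.h"

extern "C" void DebugLog(const char* s) { fprintf(stderr, "%s", s); }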
@ -52,4 +52,7 @@ tensorflow/lite/micro/tools/ci_build/test_stm32f4.sh PRESUBMIT
echo "Running Arduino tests at `date`"
tensorflow/lite/micro/tools/ci_build/test_arduino.sh

echo "Running cortex_m_gcc_generic tests at `date`"
tensorflow/lite/micro/tools/ci_build/test_cortex_m_gcc_generic.sh

echo "Finished all micro tests at `date`"

46
tensorflow/lite/micro/tools/ci_build/test_cortex_m_gcc_generic.sh
Executable file
@ -0,0 +1,46 @@
#!/usr/bin/env bash
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Tests the microcontroller code using a Cortex-M4/M4F platform.

set -e

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
ROOT_DIR=${SCRIPT_DIR}/../../../../..
cd "${ROOT_DIR}"

source tensorflow/lite/micro/tools/ci_build/helper_functions.sh

TARGET=cortex_m_gcc_generic

# TODO(b/143715361): downloading first to allow for parallel builds.
readable_run make -f tensorflow/lite/micro/tools/make/Makefile TAGS=cmsis-nn TARGET=${TARGET} CORTEX_M_CORE=M4F third_party_downloads

# Build for Cortex-M4 (no FPU) without CMSIS
readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean
readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} CORTEX_M_CORE=M4 microlite

# Build for Cortex-M4F (FPU present) without CMSIS
readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean
readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} CORTEX_M_CORE=M4F microlite

# Build for Cortex-M4 (no FPU) with CMSIS
readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean
readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TAGS=cmsis-nn TARGET=${TARGET} CORTEX_M_CORE=M4 microlite

# Build for Cortex-M4F (FPU present) with CMSIS
readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean
readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TAGS=cmsis-nn TARGET=${TARGET} CORTEX_M_CORE=M4F microlite
@ -118,4 +118,11 @@ ifneq ($(filter cmsis-nn,$(ALL_TAGS)),)
    $(CMSIS_PATH)CMSIS/DSP/Include/dsp/matrix_functions.h


  # Need to add the CMSIS Core includes path.
  # All other CMSIS header files are included with their relative path
  # in the CMSIS-NN micro kernel source files in
  # tensorflow/lite/micro/kernels/cmsis-nn
  INCLUDES += \
    -I$(CMSIS_PATH)/CMSIS/Core/Include

endif

@ -1,51 +0,0 @@
# Generic Makefile target for ARM Cortex M4 builds.
# REQUIRED:
#  - TOOLCHAIN_PATH: The path to the ARM GCC toolchain to use.

ifeq ($(TARGET), cortex_m4_generic)
  TARGET_ARCH := arm
  TARGET_TOOLCHAIN_PREFIX := arm-none-eabi-
  export PATH := $(TOOLCHAIN_PATH):$(PATH)

  PLATFORM_FLAGS = \
    -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \
    -DTF_LITE_STATIC_MEMORY \
    -DNDEBUG \
    -DTF_LITE_MCU_DEBUG_LOG \
    -D __FPU_PRESENT=1 \
    -DARM_MATH_CM4 \
    -fno-rtti \
    -fmessage-length=0 \
    -fno-exceptions \
    -fno-unwind-tables \
    -ffunction-sections \
    -fdata-sections \
    -funsigned-char \
    -MMD \
    -mcpu=cortex-m4 \
    -mthumb \
    -mfpu=fpv4-sp-d16 \
    -mfloat-abi=softfp \
    -std=gnu++11 \
    -Wvla \
    -Wall \
    -Wextra \
    -Wno-shadow \
    -Wno-missing-field-initializers \
    -Wno-strict-aliasing \
    -Wno-type-limits \
    -Wno-unused-function \
    -Wno-unused-parameter \
    -fno-delete-null-pointer-checks \
    -fno-threadsafe-statics \
    -fomit-frame-pointer \
    -fno-use-cxa-atexit \
    -O3

  CXXFLAGS += $(PLATFORM_FLAGS)
  CCFLAGS += $(PLATFORM_FLAGS)

  LDFLAGS += -Wl,--gc-sections

endif

@ -0,0 +1,36 @@
# Generic Makefile target for ARM Cortex Mx gcc builds.
ifeq ($(TARGET), cortex_m_gcc_generic)
  TARGET_ARCH := arm
  TARGET_TOOLCHAIN_PREFIX := arm-none-eabi-
  export PATH := $(MAKEFILE_DIR)/downloads/gcc_embedded/bin/:$(PATH)

  $(eval $(call add_third_party_download,$(GCC_EMBEDDED_URL),$(GCC_EMBEDDED_MD5),gcc_embedded,))

  PLATFORM_FLAGS = \
    -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \
    -DTF_LITE_MCU_DEBUG_LOG \
    -fmessage-length=0 \
    -fno-exceptions \
    -fno-unwind-tables \
    -ffunction-sections \
    -fdata-sections \
    -funsigned-char \
    -mcpu=cortex-m4 \
    -mfpu=fpv4-sp-d16 \
    -mthumb \
    -fomit-frame-pointer

  ifeq ($(CORTEX_M_CORE), M4F)
    PLATFORM_FLAGS += -mfloat-abi=hard
  else ifeq ($(CORTEX_M_CORE), M4)
    PLATFORM_FLAGS += -mfloat-abi=softfp
  else ifeq ($(CORTEX_M_CORE), )
    $(error CORTEX_M_CORE=[M4|M4F] not defined on the command line)
  else
    $(error invalid target defined in command line option CORTEX_M_CORE=[M4|M4F])
  endif

  CXXFLAGS += $(PLATFORM_FLAGS)
  CCFLAGS += $(PLATFORM_FLAGS)

endif
@ -2825,6 +2825,7 @@ tf_py_test(
        ":framework_combinations",
        ":framework_for_generated_wrappers",
        ":framework_test_lib",
        ":lookup_ops",
        ":platform_test",
        ":random_ops",
        ":resource_variable_ops",

@ -116,6 +116,33 @@ def _is_none_or_undef(value):
          or isinstance(value, variables.Undefined))


def _verify_tf_condition(cond, tag):
  """Ensures that the condition can be used in a TF control flow."""
  extra_hint = 'to check for None, use `is not None`'
  cond = ops.convert_to_tensor_v2(cond)

  if cond.dtype != dtypes.bool:
    raise ValueError(
        'condition of {} expected to be `tf.bool` scalar, got {}'
        '; to use as boolean Tensor, use `tf.cast`'
        '; {}'.format(tag, cond, extra_hint))

  if cond.shape is None or cond.shape.ndims is None:
    # TODO(mdan): Consider an explicit size check, if not too slow.
    cond = array_ops.reshape(cond, ())

  elif cond.shape.ndims > 0:
    known_dims = [d for d in cond.shape.as_list() if d is not None]
    if np.prod(known_dims) > 1:
      raise ValueError(
          'condition of {} expected to be `tf.bool` scalar, got {}'
          '; {}'.format(tag, cond, extra_hint))
    else:
      cond = array_ops.reshape(cond, ())

  return cond


def _verify_loop_init_vars(init_vars, symbol_names, first_iter_vars=None):
  """Ensures that all values in the state are valid to use in a TF loop.

@ -1038,7 +1065,7 @@ def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts):
      loop_vars = loop_vars[1:]

    set_state(loop_vars)
    return test()
    return _verify_tf_condition(test(), 'while loop')

  def aug_body(*loop_vars):
    if require_one_iteration:
@ -1141,6 +1168,8 @@ def if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts):
def _tf_if_stmt(
    cond, body, orelse, get_state, set_state, symbol_names, nouts):
  """Overload of if_stmt that stages a TF cond."""
  cond = _verify_tf_condition(cond, 'if statement')

  if not nouts:
    prev_get_state, prev_set_state = get_state, set_state
    # Control flow V1 wants at least one output.

@ -35,6 +35,7 @@ from tensorflow.python.autograph.utils import testing
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@ -46,6 +47,20 @@ from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test


def _unranked_item(value):
  rand_rank = random_ops.random_uniform(
      shape=(), minval=3, maxval=4, dtype=dtypes.int32)
  rand_shape = array_ops.ones([rand_rank], dtype=dtypes.int32)
  return array_ops.fill(rand_shape, value)


def _partial_shaped_bools():
  rand_vect = math_ops.range(
      random_ops.random_uniform(
          shape=(), minval=2, maxval=3, dtype=dtypes.int32))
  return array_ops.expand_dims_v2(rand_vect, 0) < 0


class ForLoopTest(testing.AutoGraphTestCase):

  def test_tensor(self):
@ -871,6 +886,60 @@ class WhileLoopTest(testing.AutoGraphTestCase):
    with self.assertRaisesRegex(ValueError, r"'s'.* shape \(1,\) after"):
      self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32))

  def _fixed_while_loop(self, cond_fn):
    def test_():
      return cond_fn(s)

    def body():
      nonlocal s
      s += 1

    def set_state(loop_vars):
      nonlocal s
      s, = loop_vars

    s = constant_op.constant(0)
    control_flow.while_stmt(
        test=test_,
        body=body,
        get_state=lambda: (s,),
        set_state=set_state,
        symbol_names=('s',),
        opts={})
    return s

  def _assertFixedLoopResult(self, cond, expected):
    def test_fn():
      return self._fixed_while_loop(cond)
    self.assertEqual(test_fn(), expected)

  def test_tensor_legal_cond_scalar(self):
    self._assertFixedLoopResult(lambda s: constant_op.constant(False), 0)
    self._assertFixedLoopResult(lambda s: s < 2, 2)

  def test_tensor_legal_cond_single_element_nd(self):
    self._assertFixedLoopResult(lambda s: constant_op.constant([[False]]), 0)
    self._assertFixedLoopResult(lambda s: _unranked_item(False), 0)

  def _assertCondCheckFails(self, cond):
    with self.assertRaisesRegex(
        ValueError, 'condition of while loop expected to be `tf.bool`'):
      self._fixed_while_loop(cond)

  def test_tensor_illegal_cond_not_bool(self):
    self._assertCondCheckFails(lambda s: constant_op.constant(1))
    self._assertCondCheckFails(lambda s: s)

  def test_tensor_illegal_cond_not_single_element(self):
    self._assertCondCheckFails(lambda s: constant_op.constant([1, 2, 3]))
    self._assertCondCheckFails(lambda s: constant_op.constant([True, False]))

  def test_tensor_illegal_cond_not_single_element_dynamic_shape(self):
    self._fixed_while_loop(lambda s: _partial_shaped_bools())
    # TODO(mdan): This error is quite bad. Measure the cost of an assertion.
    self.assertRaisesRuntime(
        errors_impl.InvalidArgumentError, 'requested shape has 1')


class IfStmtTest(testing.AutoGraphTestCase):

@ -1065,6 +1134,62 @@ class IfStmtTest(testing.AutoGraphTestCase):
        TypeError, "'x' has dtype int32.*but.*float32"):
      self._basic_cond(lambda: 1, lambda: 1.0)

  def _fixed_cond(self, cond_val):
    def body():
      nonlocal x
      x = 1

    def orelse():
      nonlocal x
      x = -1

    def set_state(cond_vars):
      nonlocal x
      x, = cond_vars

    x = 0
    control_flow.if_stmt(
        cond=cond_val,
        body=body,
        orelse=orelse,
        get_state=lambda: (x,),
        set_state=set_state,
        symbol_names=('x',),
        nouts=1)
    return x

  def _assertFixedCondResult(self, cond, expected):
    def test_fn():
      return self._fixed_cond(cond)
    self.assertEqual(test_fn(), expected)

  def test_tensor_legal_cond_scalar(self):
    self._assertFixedCondResult(constant_op.constant(True), 1)
    self._assertFixedCondResult(constant_op.constant(False), -1)

  def test_tensor_legal_cond_single_element_nd(self):
    self._assertFixedCondResult(constant_op.constant([[True]]), 1)
    self._assertFixedCondResult(constant_op.constant([[False]]), -1)
    self._assertFixedCondResult(_unranked_item(True), 1)
    self._assertFixedCondResult(_unranked_item(False), -1)

  def _assertCondCheckFails(self, cond):
    with self.assertRaisesRegex(
        ValueError, 'condition of if statement expected to be `tf.bool`'):
      self._fixed_cond(cond)

  def test_tensor_illegal_cond_not_bool(self):
    self._assertCondCheckFails(constant_op.constant(1))

  def test_tensor_illegal_cond_not_single_element(self):
    self._assertCondCheckFails(constant_op.constant([1, 2, 3]))
    self._assertCondCheckFails(constant_op.constant([True, False]))

  def test_tensor_illegal_cond_not_single_element_dynamic_shape(self):
    self._fixed_cond(_partial_shaped_bools())
    # TODO(mdan): This error is quite bad. Measure the cost of an assertion.
    self.assertRaisesRuntime(
        errors_impl.InvalidArgumentError, 'requested shape has 1')

if __name__ == '__main__':
  test.main()

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function

import re
import sys
import types
import unittest

@ -81,18 +82,29 @@ class AutoGraphTestCase(test.TestCase):
    @def_function.function(autograph=False)  # Testing autograph itself.
    def fn_wrapper():
      self.assertions = []
      self.raises_cm = None
      self.graph_assertions = []
      self.trace_log = []
      fn()
      targets = [args for _, args in self.assertions]
      return targets

    tensors = fn_wrapper()
    try:
      tensors = fn_wrapper()

    for assertion in self.graph_assertions:
      assertion(fn_wrapper.get_concrete_function().graph)
      for assertion in self.graph_assertions:
        assertion(fn_wrapper.get_concrete_function().graph)

      actuals = self.evaluate(tensors)

    except:  # pylint:disable=bare-except
      if self.raises_cm is not None:
        # Note: Yes, the Raises and function contexts cross.
        self.raises_cm.__exit__(*sys.exc_info())
        return
      else:
        raise

    actuals = self.evaluate(tensors)
    for (assertion, _), values in zip(self.assertions, actuals):
      assertion(*values)

@ -109,6 +121,7 @@ class AutoGraphTestCase(test.TestCase):
    super().setUp()
    self.variables = {}
    self.trace_log = []
    self.raises_cm = None
    op_callbacks.add_op_callback(self._op_callback)

  def tearDown(self):
@ -145,3 +158,9 @@ class AutoGraphTestCase(test.TestCase):

  def assertDictEqual(self, *args):
    self.assertions.append((super().assertDictEqual, list(args)))

  def assertRaisesRuntime(self, *args):
    if self.raises_cm is not None:
      raise ValueError('cannot use more than one assertRaisesRuntime in a test')
    self.raises_cm = self.assertRaisesRegex(*args)
    self.raises_cm.__enter__()

@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export
# This value changes every day with an automatic CL. It can be modified in code
# via `forward_compatibility_horizon()` or with the environment variable
# TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 10, 2)
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 10, 3)
_FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
_FORWARD_COMPATIBILITY_DATE_NUMBER = None

Some files were not shown because too many files have changed in this diff.