Merge branch 'master' of https://github.com/tensorflow/tensorflow

commit 8a13ac7b5e

.bazelrc
@@ -323,8 +323,6 @@ build:windows --copt=/experimental:preprocessor
 build:windows --host_copt=/experimental:preprocessor
 
 # Misc build options we need for windows.
-build:windows --linkopt=/DEBUG
-build:windows --host_linkopt=/DEBUG
 build:windows --linkopt=/OPT:REF
 build:windows --host_linkopt=/OPT:REF
 build:windows --linkopt=/OPT:ICF
@@ -206,6 +206,9 @@
     `fit()`. Running multiple batches inside a single `tf.function` call can
     greatly improve performance on TPUs or small models with a large Python
     overhead.
+  * Improvements to Keras preprocessing layers:
+    * TextVectorization can now accept a vocabulary list or file as an
+      init arg.
 * `tf.function` / AutoGraph:
 
   * Added `experimental_follow_type_hints` argument for `tf.function`. When
@@ -3,7 +3,7 @@
 load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
 load(
     "//tensorflow:tensorflow.bzl",
-    "if_tpu",
+    "if_libtpu",
     "tf_cc_test",
     "tf_copts",
     "tf_cuda_cc_test",
@@ -289,7 +289,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/lib/llvm_rtti",
-    ] + if_tpu(
+    ] + if_libtpu(
         if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
         if_true = [],
     ),
@@ -354,7 +354,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/lib/llvm_rtti",
-    ] + if_tpu(
+    ] + if_libtpu(
        if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
        if_true = [],
    ),
@@ -39,7 +39,7 @@ limitations under the License.
 #include "tensorflow/c/eager/tfe_op_internal.h"
 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
 #include "tensorflow/c/tf_tensor_internal.h"
-#if defined(PLATFORM_GOOGLE) && !defined(LIBTFTPU)
+#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
 #include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
 #endif
 #include "tensorflow/core/common_runtime/device.h"
@@ -729,7 +729,7 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
 
 TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
   if (opts->use_tfrt) {
-#if defined(PLATFORM_GOOGLE) && !defined(LIBTFTPU)
+#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
     return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
 #else
     status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
@@ -42,13 +42,15 @@ cc_library(
     name = "reader",
     srcs = ["reader.cc"],
     hdrs = ["reader.h"],
-    deps = [":constants"] + if_not_mobile([
+    deps = [
+        ":constants",
+        "//tensorflow/core:protos_all_cc",
+    ] + if_not_mobile([
         # TODO(b/111634734): :lib and :protos_all contain dependencies that
         # cannot be built on mobile platforms. Instead, include the appropriate
         # tf_lib depending on the build platform.
         "@com_google_absl//absl/memory:memory",
         "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
     ]),
 )
 
@@ -4,7 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
 load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_test")
 
 # buildifier: disable=same-origin-load
-load("//tensorflow:tensorflow.bzl", "if_tpu", "tf_copts")
+load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts")
 load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
 
 # buildifier: disable=same-origin-load
@@ -77,7 +77,7 @@ cc_library(
         "//tensorflow/compiler/jit/kernels:xla_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
-    ] + if_tpu(
+    ] + if_libtpu(
         if_false = ["//tensorflow/compiler/xla/service:cpu_plugin"],
         if_true = [],
     ),
@@ -114,7 +114,7 @@ cc_library(
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:lib",
-    ] + if_tpu(
+    ] + if_libtpu(
         if_false = [
             "//tensorflow/compiler/xla/service:cpu_plugin",  # buildcleaner: keep
         ],
@@ -141,7 +141,7 @@ cc_library(
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:lib",
         "//tensorflow/core/common_runtime/gpu:gpu_init",
-    ] + if_tpu(
+    ] + if_libtpu(
         if_false = [
             "//tensorflow/compiler/xla/service:gpu_plugin",  # buildcleaner: keep
         ],
@@ -375,7 +375,7 @@ cc_library(
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/platform:logging",
-    ] + if_tpu(
+    ] + if_libtpu(
         if_false = [
             "//tensorflow/compiler/mlir:array_container_utils",
             "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
@@ -47,7 +47,7 @@ limitations under the License.
 #include "tensorflow/core/public/version.h"
 #include "tensorflow/core/util/dump_graph.h"
 
-#if !defined(LIBTFTPU)
+#if !defined(LIBTPU_ON_GCE)
 #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
 #include "tensorflow/compiler/mlir/utils/array_container_utils.h"
 #endif
@@ -289,7 +289,7 @@ Status XlaCompilationCache::CompileSingleOp(
       });
   const ConfigProto* config = ctx->function_library()->config_proto();
   bool use_mlir = config && config->experimental().enable_mlir_bridge();
-#ifdef LIBTFTPU
+#ifdef LIBTPU_ON_GCE
   if (use_mlir && has_tensor_list_arg) {
     LOG(WARNING) << "MLIR is not supported in this environment.";
   }
@@ -157,6 +157,9 @@ def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
   >];
 }
 
+def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt",
+    [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CbrtOp;
+
 def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil",
     [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CeilOp;
 
@@ -1423,4 +1426,21 @@ def HLO_FusionOp : HLO_Op<"fusion", []> {
   let hasCustomHLOConverter = 1;
 }
 
+// This is an op for purposes internal to XLA/GPU.
+def HLO_BitcastOp : HLO_Op<"bitcast", [NoSideEffect]>, BASE_HLO_BitcastOp {
+  let arguments = (ins HLO_Tensor:$operand);
+  let results = (outs HLO_Tensor);
+  let hasCustomHLOConverter = 1;
+}
+
+def HLO_ReducePrecisionOp: HLO_Op<"reduce_precision", [SameOperandsAndResultShape]>,
+    BASE_HLO_ReducePrecisionOp {
+  let arguments = (ins
+    HLO_FpTensor:$operand,
+    I32Attr:$exponent_bits,
+    I32Attr:$mantissa_bits
+  );
+  let results = (outs HLO_FpTensor:$output);
+}
+
 #endif // HLO_OPS
@@ -127,6 +127,17 @@ class BASE_HLO_AbsOp {
   }];
 }
 
+class BASE_HLO_CbrtOp {
+  string summary = "Cubic root operator";
+
+  string description = [{
+    Returns element-wise cubic root of the operand.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
 class BASE_HLO_CeilOp {
   string summary = "Ceil operator";
 
@@ -1336,4 +1347,17 @@ class BASE_HLO_WhileOp {
   }];
 }
 
+class BASE_HLO_BitcastOp {
+  string summary = "Bitcast operator";
+
+  string description = [{
+    This op changes the shape of the input in a way that leaves the physical
+    arrangement of elements unchanged.
+
+    However, the op needs layout information to make sense of "physical
+    arrangement of elements". Layout support in MHLO is currently under
+    exploration.
+  }];
+}
+
 #endif // HLO_OPS_BASE
@@ -1193,3 +1193,24 @@ func @incompatible_shapes(%arg0: tensor<?xf32>, %shape: tensor<2xindex>) -> tensor<?xf32> {
   %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
+
+// -----
+
+func @cbrt(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
+  %0 = "mhlo.cbrt"(%arg) : (tensor<2x4xf32>) -> tensor<2x4xf32>
+  return %0 : tensor<2x4xf32>
+}
+
+// -----
+
+func @bitcast(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
+  %0 = "mhlo.bitcast"(%arg) : (tensor<2x4xf32>) -> tensor<2x4xf32>
+  return %0 : tensor<2x4xf32>
+}
+
+// -----
+
+func @bitcast(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
+  %0 = "mhlo.reduce_precision"(%arg) {exponent_bits=2 : i32, mantissa_bits=3 : i32} : (tensor<2x4xf32>) -> tensor<2x4xf32>
+  return %0 : tensor<2x4xf32>
+}
@@ -74,8 +74,8 @@ tool_names = [
    'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
    'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
    'json_to_flatbuffer', 'xla-gpu-opt', 'xla-mlir-gpu-opt', 'xla-opt',
-    'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_kernel', 'tf_to_gpu_binary',
-    'xla-thunks-opt', 'tfjs-opt'
+    'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_gpu_binary', 'xla-thunks-opt',
+    'tfjs-opt'
 ]
 tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
 llvm_config.add_tool_substitutions(tools, tool_dirs)
@@ -1,4 +1,5 @@
 load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
 load(
     "//third_party/mlir:tblgen.bzl",
     "gentbl",
@@ -226,3 +227,23 @@ cc_library(
     ],
     alwayslink = 1,
 )
+
+tf_python_pybind_extension(
+    name = "tfr_wrapper",
+    srcs = ["python/tfr_wrapper.cc"],
+    module_name = "tfr_wrapper",
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/compiler/mlir/tfr",
+        "//tensorflow/python:pybind11_lib",
+        "//tensorflow/python:pybind11_status",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Parser",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:Shape",
+        "@llvm-project//mlir:StandardOps",
+        "@pybind11",
+    ],
+)
tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc (new file, 58 lines)
@@ -0,0 +1,58 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
+#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/AsmState.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Verifier.h"  // from @llvm-project
+#include "mlir/Parser.h"  // from @llvm-project
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
+#include "tensorflow/python/lib/core/pybind11_lib.h"
+#include "tensorflow/python/lib/core/pybind11_status.h"
+
+PYBIND11_MODULE(tfr_wrapper, m) {
+  m.def("verify", [](std::string input) {
+    mlir::MLIRContext ctx(/*loadAllDialects=*/true);
+    auto& registry = ctx.getDialectRegistry();
+    registry.insert<mlir::scf::SCFDialect, mlir::TF::TensorFlowDialect,
+                    mlir::StandardOpsDialect, mlir::shape::ShapeDialect,
+                    mlir::TFR::TFRDialect>();
+    ctx.getDialectRegistry().loadAll(&ctx);
+
+    llvm::SourceMgr source_mgr = llvm::SourceMgr();
+    source_mgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
+                                  llvm::SMLoc());
+    auto module = mlir::parseSourceFile(source_mgr, &ctx);
+    if (!module) {
+      return false;
+    }
+
+    mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &ctx);
+    if (failed(mlir::verify(*module))) {
+      module->emitError("Invalid MLIR module: failed verification.");
+      return false;
+    }
+    return true;
+  });
+}
@@ -105,10 +105,7 @@ tf_cc_binary(
 tf_cc_binary(
     name = "tf_to_kernel",
     srcs = ["tf_to_kernel.cc"],
-    visibility = [
-        "//tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel:__pkg__",
-        "//tensorflow/core/kernels/mlir_generated:__pkg__",
-    ],
+    visibility = ["//tensorflow/core/kernels/mlir_generated:__pkg__"],
     deps = [
         ":kernel_creator",
         "//tensorflow/compiler/mlir:init_mlir",
@@ -162,7 +159,7 @@ cc_library(
 
 cc_library(
     name = "tf_cuda_runtime_wrappers",
-    srcs = ["tf_cuda_runtime_wrappers.cpp"],
+    srcs = ["tf_cuda_runtime_wrappers.cc"],
     compatible_with = get_compatible_with_cloud(),
     deps = [
         "//tensorflow/core/platform/default/build_config:stream_executor_cuda",
@@ -174,8 +174,7 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
 Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only,
                       llvm::ArrayRef<uint32_t> same_shape,
                       llvm::StringRef gpu_binary_attr_name,
-                      llvm::ArrayRef<uint32_t> architectures,
-                      bool generate_fatbin) {
+                      int32_t architecture) {
   mlir::PassManager pm(module.getContext());
   applyTensorflowAndCLOptions(pm);
 
@@ -188,7 +187,7 @@ Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only,
   }
   kernel_pm.addPass(mlir::createStripDebugInfoPass());
   kernel_pm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToBlobPass(
-      gpu_binary_attr_name, architectures, generate_fatbin));
+      gpu_binary_attr_name, architecture));
 
   if (!gpu_binary_only) {
     pm.addPass(mlir::kernel_gen::transforms::CreateTFKernelToLLVMPass());
|
|||||||
|
|
||||||
StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
|
StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
|
||||||
mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only,
|
mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only,
|
||||||
llvm::ArrayRef<uint32_t> architectures, llvm::ArrayRef<uint32_t> tile_sizes,
|
int32_t architecture, llvm::ArrayRef<uint32_t> tile_sizes,
|
||||||
llvm::ArrayRef<uint32_t> same_shape,
|
llvm::ArrayRef<uint32_t> same_shape,
|
||||||
llvm::ArrayRef<uint32_t> unroll_factors, bool generate_fatbin) {
|
llvm::ArrayRef<uint32_t> unroll_factors) {
|
||||||
mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
|
mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
|
||||||
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
|
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
@@ -222,8 +221,7 @@ StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
   TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
 #endif
   TF_RETURN_IF_ERROR(LowerGPUToLLVM(module.get(), gpu_binary_only, same_shape,
-                                    kGpuBinaryAttrName, architectures,
-                                    generate_fatbin));
+                                    kGpuBinaryAttrName, architecture));
   return module;
 }
 
@@ -38,10 +38,9 @@ namespace kernel_gen {
 // false, lowers the host side to LLVM Dialect.
 xla::StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
     mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only,
-    llvm::ArrayRef<uint32_t> architectures = {75},
-    llvm::ArrayRef<uint32_t> tile_sizes = {16, 64},
+    int32_t architecture = 75, llvm::ArrayRef<uint32_t> tile_sizes = {16, 64},
     llvm::ArrayRef<uint32_t> same_shape = {},
-    llvm::ArrayRef<uint32_t> unroll_factors = {}, bool generate_fatbin = true);
+    llvm::ArrayRef<uint32_t> unroll_factors = {});
 
 // Extracts gpu_binary from the converted module.
 xla::StatusOr<std::string> ExtractGpuBinary(mlir::ModuleOp module);
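Editor's note: the hunk above changes the public `GenerateKernelForTfCode` entry point from an `architectures` list back to a single `architecture`. The following is a minimal caller-side sketch of the new call shape, not part of the commit; the wrapper function name, the `tf_code` argument, and the include paths are illustrative assumptions.

// Sketch only: invoking GenerateKernelForTfCode with the new single-arch API.
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h"

xla::StatusOr<mlir::OwningModuleRef> BuildKernelModule(llvm::StringRef tf_code) {
  mlir::MLIRContext context;
  // gpu_binary_only=false also lowers the host side to the LLVM dialect.
  return tensorflow::kernel_gen::GenerateKernelForTfCode(
      context, tf_code, /*gpu_binary_only=*/false,
      /*architecture=*/75,          // one sm_75 target; fatbin support is gone
      /*tile_sizes=*/{16, 64},
      /*same_shape=*/{},
      /*unroll_factors=*/{});
}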
@@ -1,5 +1,6 @@
 // RUN: tf_to_gpu_binary --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=70
 func @tanh(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  %0 = "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  %0 = "tf.Tanh"(%arg0) { }
+      : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
@@ -1,17 +0,0 @@
-load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
-
-package(licenses = ["notice"])
-
-glob_lit_tests(
-    data = [
-        "//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_kernel",
-        "@llvm-project//mlir:run_lit.sh",
-    ],
-    default_tags = [
-        # We need access to the CUDA SDK.
-        "gpu",
-        "no_rocm",
-    ],
-    driver = "//tensorflow/compiler/mlir:run_lit.sh",
-    test_file_exts = ["mlir"],
-)
@@ -1,6 +0,0 @@
-// RUN: tf_to_kernel --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=70,75
-
-func @tanh(%arg: tensor<*xf32>) -> tensor<*xf32> {
-  %0 = "tf.Tanh"(%arg) : (tensor<*xf32>) -> tensor<*xf32>
-  return %0 : tensor<*xf32>
-}
@@ -20,9 +20,9 @@ limitations under the License.
 #include <cassert>
 #include <numeric>
 
-#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
-#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/raw_ostream.h"
-#include "third_party/llvm/llvm-project/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/ExecutionEngine/CRunnerUtils.h"  // from @llvm-project
 
 #if GOOGLE_CUDA
 #include "third_party/gpus/cuda/include/cuda.h"
@@ -48,7 +48,7 @@ xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
       mlir::OwningModuleRef module,
       GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/true,
                               architecture, tile_sizes, same_shape,
-                              unroll_factors, /*generate_fatbin=*/false));
+                              unroll_factors));
   // Extract gpu_binary.
   TF_ASSIGN_OR_RETURN(std::string gpu_binary, ExtractGpuBinary(*module));
 
@@ -95,8 +95,7 @@ xla::StatusOr<std::string> EmitToBinary(mlir::ModuleOp module) {
 }
 
 xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
-                llvm::ArrayRef<uint32_t> architectures,
-                llvm::ArrayRef<uint32_t> tile_sizes,
+                int32_t architecture, llvm::ArrayRef<uint32_t> tile_sizes,
                 llvm::ArrayRef<uint32_t> same_shape,
                 llvm::ArrayRef<uint32_t> unroll_factors) {
   // Read TF code.
@@ -108,7 +107,7 @@ xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
   TF_ASSIGN_OR_RETURN(
       mlir::OwningModuleRef module,
       GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/false,
-                              architectures, tile_sizes, same_shape,
+                              architecture, tile_sizes, same_shape,
                               unroll_factors));
   // Get binary.
   TF_ASSIGN_OR_RETURN(std::string binary, EmitToBinary(*module));
@@ -130,8 +129,8 @@ int main(int argc, char** argv) {
   llvm::cl::opt<std::string> output_file(
       "output", llvm::cl::desc("output file"), llvm::cl::value_desc("filename"),
       llvm::cl::init("foo.bin"));
-  llvm::cl::list<uint32_t> architectures(
-      "arch", llvm::cl::desc("target architectures (e.g. 50 for sm_50)"),
+  llvm::cl::list<int32_t> architecture(
+      "arch", llvm::cl::desc("target architecture (e.g. 50 for sm_50)"),
       llvm::cl::OneOrMore, llvm::cl::CommaSeparated);
   llvm::cl::list<uint32_t> tile_sizes(
       "tile_sizes", llvm::cl::desc("tile sizes to use"), llvm::cl::ZeroOrMore,
@@ -152,7 +151,7 @@ int main(int argc, char** argv) {
   llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n");
 
   auto status =
-      tensorflow::kernel_gen::Run(input_file, output_file, architectures,
+      tensorflow::kernel_gen::Run(input_file, output_file, architecture.front(),
                                   tile_sizes, same_shape, unroll_factors);
   if (!status.ok()) {
     LOG(ERROR) << status;
@@ -117,7 +117,6 @@ cc_library(
         "@llvm-project//mlir:AllPassesAndDialects",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:Transforms",
-        "@llvm-project//llvm:TransformUtils",
         "//tensorflow/compiler/mlir/hlo",
         "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo",
         "//tensorflow/compiler/mlir/hlo:lhlo",
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "llvm/Transforms/Utils/Cloning.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/Target/NVVMIR.h"  // from @llvm-project
 #include "mlir/Target/ROCDLIR.h"  // from @llvm-project
@@ -50,12 +49,9 @@ using xla::InternalError;
 class GpuKernelToBlobPass
     : public GpuKernelToBlobPassBase<GpuKernelToBlobPass> {
  public:
-  GpuKernelToBlobPass(mlir::StringRef blob_annotation,
-                      llvm::ArrayRef<uint32_t> architectures,
-                      bool generate_fatbin) {
+  GpuKernelToBlobPass(mlir::StringRef blob_annotation, int32_t arch) {
     blob_annotation_ = blob_annotation.str();
-    architectures_ = architectures;
-    generate_fatbin_ = generate_fatbin;
+    arch_ = arch;
   }
 
   void runOnOperation() override {
@@ -73,17 +69,7 @@ class GpuKernelToBlobPass
 
   xla::StatusOr<std::vector<uint8_t>> GetGpuBinaryBlob(
       mlir::gpu::GPUModuleOp gpu_module) {
-    if (architectures_.empty()) {
-      return InternalError("Expected at least one GPU architecture.");
-    }
-    if (!generate_fatbin_ && architectures_.size() > 1) {
-      return InternalError(
-          "Can only generate machine code for more than one architecture as a "
-          "fatbin.");
-    }
-
     llvm::LLVMContext llvmContext;
-
 #if TENSORFLOW_USE_ROCM
     auto llvmModule = mlir::translateModuleToROCDLIR(gpu_module, llvmContext);
     if (!llvmModule) {
@@ -95,14 +81,9 @@ class GpuKernelToBlobPass
     xla::HloModuleConfig config;
     config.set_debug_options(xla::GetDebugOptionsFromFlags());
 
-    // TODO(b/169066682): Support fatbin on ROCm.
-    if (generate_fatbin_) {
-      return InternalError("Fatbins are not yet supported for ROCm.");
-    }
 
-    uint32_t arch = architectures_.front();
     std::string libdevice_dir = tensorflow::RocdlRoot();
-    return xla::gpu::amdgpu::CompileToHsaco(llvmModule.get(), arch, config,
+    return xla::gpu::amdgpu::CompileToHsaco(llvmModule.get(), arch_, config,
                                             libdevice_dir);
 
 #elif GOOGLE_CUDA
@@ -121,42 +102,19 @@ class GpuKernelToBlobPass
       target->Options.AllowFPOpFusion = llvm::FPOpFusion::FPOpFusionMode::Fast;
     };
 
-    // Compile and collect requested cubin and PTX images.
-    std::vector<tensorflow::se::CubinOrPTXImage> images;
+    int32_t cc_major = arch_ / 10;
+    int32_t cc_minor = arch_ % 10;
     TF_ASSIGN_OR_RETURN(std::string libdevice_dir, GetLibdeviceDir(config));
-    auto gpu_asm_opts = xla::gpu::PtxOptsFromConfig(config);
-    for (uint32_t arch : architectures_) {
-      int32_t cc_major = arch / 10;
-      int32_t cc_minor = arch % 10;
-      // Module may be changed by CompileToPtx.
-      auto llvmModuleCopy = llvm::CloneModule(*llvmModule);
-      TF_ASSIGN_OR_RETURN(
-          std::string ptx,
-          xla::gpu::nvptx::CompileToPtx(llvmModuleCopy.get(),
-                                        std::make_pair(cc_major, cc_minor),
-                                        config, libdevice_dir, enable_fusion));
-      // TODO(b/169066682): If compute_XX profile, collect PTX image here.
-      VLOG(1) << ptx;
-      TF_ASSIGN_OR_RETURN(std::vector<uint8_t> gpu_asm,
-                          tensorflow::se::CompileGpuAsm(
-                              cc_major, cc_minor, ptx.c_str(), gpu_asm_opts));
-
-      if (!generate_fatbin_) {
-        // Skip fatbin generation and return the first and only GPU machine
-        // code.
-        return gpu_asm;
-      }
-
-      // Collect cubin image.
-      images.push_back({absl::StrCat("sm_", arch), std::move(gpu_asm)});
-    }
-
-    // TODO(b/169870789): Revisit the use of fatbins.
-    // Bundle cubin and PTX images into a single fatbin.
-    return tensorflow::se::BundleGpuAsm(images,
-                                        gpu_asm_opts.preferred_cuda_dir);
+    TF_ASSIGN_OR_RETURN(
+        std::string ptx,
+        xla::gpu::nvptx::CompileToPtx(llvmModule.get(),
+                                      std::make_pair(cc_major, cc_minor),
+                                      config, libdevice_dir, enable_fusion));
+    VLOG(1) << ptx;
+    return tensorflow::se::CompileGpuAsm(cc_major, cc_minor, ptx.c_str(),
+                                         xla::gpu::PtxOptsFromConfig(config));
 #endif
 
     return InternalError(
         "Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined."
         " Did you specify either --config=rocm or --config=cuda ?");
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<gpu::GPUModuleOp>> CreateGpuKernelToBlobPass(
|
std::unique_ptr<OperationPass<gpu::GPUModuleOp>> CreateGpuKernelToBlobPass(
|
||||||
mlir::StringRef blob_annotation, ArrayRef<uint32_t> architectures,
|
mlir::StringRef blob_annotation, int32_t architecture) {
|
||||||
bool generate_fatbin) {
|
return std::make_unique<GpuKernelToBlobPass>(blob_annotation, architecture);
|
||||||
return std::make_unique<GpuKernelToBlobPass>(blob_annotation, architectures,
|
|
||||||
generate_fatbin);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace transforms
|
} // namespace transforms
|
||||||
|
@@ -61,8 +61,7 @@ CreatePropagateTensorFlowABIKnowledgePass(
 
 // Pass to annotate GPU Module with its PTX.
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>> CreateGpuKernelToBlobPass(
-    mlir::StringRef blob_annotation = "", ArrayRef<uint32_t> architectures = {},
-    bool generate_fatbin = true);
+    mlir::StringRef blob_annotation = "", int32_t architecture = 0);
 
 // Pass to unfuse batch norm.
 std::unique_ptr<FunctionPass> CreateUnfuseBatchNormPass();
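Editor's note: the blob pass factory above now takes a single `int32_t architecture`. A minimal sketch of wiring it into a pass manager follows, mirroring the updated call site in kernel_creator.cc earlier in this diff; it is not part of the commit, and the attribute-name string used here is an illustrative assumption (the real code passes its own kGpuBinaryAttrName constant).

// Sketch only: registering CreateGpuKernelToBlobPass with the new scalar arch.
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"

void AddBlobPass(mlir::OpPassManager& kernel_pm) {
  kernel_pm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToBlobPass(
      /*blob_annotation=*/"gpu.binary_blob",  // hypothetical attribute name
      /*architecture=*/70));                  // single sm_70 target
}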
@@ -53,10 +53,7 @@ def GpuKernelToBlobPass : Pass<"gpu-kernel-to-blob", "gpu::GPUModuleOp"> {
   let options = [
     Option<"blob_annotation_", "blob-annotation", "std::string",
            /*default=*/"", "Blob attribute name">,
-    ListOption<"architectures_", "arch", "uint32_t", "GPU architectures">,
-    Option<"generate_fatbin_", "generate-fatbin", "bool", /*default=*/"true",
-           "Bundle machine code for the different architectures in one "
-           "fatbin.">,
+    Option<"arch_", "arch", "int32_t", /*default=*/"0", "GPU architecture">,
   ];
   let constructor = "transforms::CreateGpuKernelToBlobPass()";
 }
@@ -681,6 +681,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
       NoAttributeCase(kAnd, AndOp);
       NoAttributeCase(kAtan2, Atan2Op);
       NoAttributeCase(kBitcastConvert, BitcastConvertOp);
+      NoAttributeCase(kCbrt, CbrtOp);
       NoAttributeCase(kConvert, ConvertOp);
       NoAttributeCase(kCeil, CeilOp);
       NoAttributeCase(kClamp, ClampOp);
@@ -738,6 +739,20 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
                                          &fusion.fused_computation()));
       return fusion.getOperation();
     }
+    case HloOpcode::kBitcast:
+      return func_builder
+          ->create<mlir::mhlo::BitcastOp>(loc, result_type, operands,
+                                          attributes)
+          .getOperation();
+    case HloOpcode::kReducePrecision: {
+      auto op = func_builder->create<mlir::mhlo::ReducePrecisionOp>(
+          loc, result_type, operands[0], attributes);
+      op.exponent_bitsAttr(func_builder->getIntegerAttr(
+          func_builder->getI32Type(), instruction->exponent_bits()));
+      op.mantissa_bitsAttr(func_builder->getIntegerAttr(
+          func_builder->getI32Type(), instruction->mantissa_bits()));
+      return op.getOperation();
+    }
     case HloOpcode::kAddDependency:
       // Arbitrary op code that I suspect we will not implement for quite a
       // while and allows testing handling of unknown ops. Selected because it
@@ -762,17 +777,10 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
                       ImportInstructionImpl(instruction, func_builder));
   if (op == nullptr) return op;
 
-  // Best-effort propagation of the layouts. These layouts serve as performance
-  // hints to the backend.
+  // See MlirToHloConversionOptions for more about layouts.
   //
   // Minor-to-major is a permutation of [0, rank), presenting tensor dimensions
   // in physical minor-to-major order.
-  //
-  // Note that non-array shapes are not carrying layouts, and users have to
-  // figure out the proper layouts of them through context. This is one of the
-  // reasons why the attribute-based solution is temporary.
-  //
-  // TODO(timshen): Investigate the necessity of having layouts in MHLO.
   if (instruction->shape().IsArray() &&
       instruction->shape().layout() !=
           LayoutUtil::MakeDescendingLayout(
@@ -499,12 +499,14 @@ class ConvertToHloModule {
   // single value.
   explicit ConvertToHloModule(
       mlir::ModuleOp module, bool use_tuple_args, bool return_tuple,
-      tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn)
+      tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+      MlirToHloConversionOptions options)
       : module_(module),
         module_builder_("main"),
         use_tuple_args_(use_tuple_args),
         return_tuple_(return_tuple),
-        shape_representation_fn_(shape_representation_fn) {
+        shape_representation_fn_(shape_representation_fn),
+        options_(options) {
     if (!shape_representation_fn_)
       shape_representation_fn_ = tensorflow::IdentityShapeRepresentationFn();
   }
@@ -585,6 +587,8 @@ class ConvertToHloModule {
 
   // Unique suffix to give to the name of the next lowered region.
   size_t region_id_ = 0;
+
+  MlirToHloConversionOptions options_;
 };
 
 }  // namespace
@@ -1078,6 +1082,15 @@ LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) {
   return success();
 }
 
+LogicalResult ExportXlaOp(BitcastOp op, OpLoweringContext ctx) {
+  auto& value_map = *ctx.values;
+  xla::XlaOp operand;
+  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
+  value_map[op] = xla::internal::XlaBuilderFriend::BuildBitcast(
+      ctx.builder, operand, xla::TypeToShape(op.getType()));
+  return success();
+}
+
 }  // namespace
 }  // namespace mhlo
 }  // namespace mlir
@@ -1087,18 +1100,19 @@ LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) {
 namespace mlir {
 namespace {
 
-StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
+StatusOr<xla::Literal> CreateArrayLiteralFromAttr(ElementsAttr attr,
+                                                  xla::Layout layout) {
   if (attr.isa<OpaqueElementsAttr>())
     return tensorflow::errors::Unimplemented(
         "Opaque elements attr not supported");
 
   xla::Shape shape = xla::TypeToShape(attr.getType());
 
 #define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type)        \
   case xla_type: {                                          \
     xla::Array<cpp_type> source_data(shape.dimensions());   \
     source_data.SetValues(attr.getValues<cpp_type>());      \
-    return xla::LiteralUtil::CreateFromArray(source_data);  \
+    return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout); \
   }
 
   switch (shape.element_type()) {
|
|||||||
}
|
}
|
||||||
xla::Array<xla::half> source_data(shape.dimensions());
|
xla::Array<xla::half> source_data(shape.dimensions());
|
||||||
source_data.SetValues(values);
|
source_data.SetValues(values);
|
||||||
return xla::LiteralUtil::CreateFromArray(source_data);
|
return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout);
|
||||||
}
|
}
|
||||||
case xla::PrimitiveType::BF16: {
|
case xla::PrimitiveType::BF16: {
|
||||||
xla::Array<double> source_data(shape.dimensions());
|
xla::Array<double> source_data(shape.dimensions());
|
||||||
@ -1145,7 +1159,7 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
|
|||||||
}
|
}
|
||||||
source_data.SetValues(values_double);
|
source_data.SetValues(values_double);
|
||||||
return xla::LiteralUtil::ConvertF64ToBF16(
|
return xla::LiteralUtil::ConvertF64ToBF16(
|
||||||
xla::LiteralUtil::CreateFromArray(source_data));
|
xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout));
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return tensorflow::errors::Internal(absl::StrCat(
|
return tensorflow::errors::Internal(absl::StrCat(
|
||||||
@@ -1154,25 +1168,33 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
 #undef ELEMENTS_ATTR_TO_LITERAL
 }
 
+xla::Layout ExtractLayout(mlir::Operation* op, int rank) {
+  if (auto attr =
+          op->getAttrOfType<mlir::DenseIntElementsAttr>("minor_to_major")) {
+    llvm::SmallVector<int64, 4> minor_to_major;
+    minor_to_major.reserve(attr.size());
+    for (const llvm::APInt& i : attr) {
+      minor_to_major.push_back(i.getZExtValue());
+    }
+    return xla::LayoutUtil::MakeLayout(minor_to_major);
+  }
+  return xla::LayoutUtil::MakeDescendingLayout(rank);
+}
+
 LogicalResult ConvertToHloModule::Lower(
     mlir::Operation* inst, bool is_entry_function,
     llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
     xla::XlaBuilder* builder,
     ConvertToHloModule::ValueLoweringMap* value_lowering,
     xla::XlaComputation* result) {
-  // See hlo_function_importer.cc for documentation about layouts in MHLO.
-  auto propagate_layouts = [](mlir::Operation* inst, xla::XlaOp xla_op) {
-    auto attr =
-        inst->getAttrOfType<mlir::DenseIntElementsAttr>("minor_to_major");
-    if (!attr) return;
-    auto* v = xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
-                  ->mutable_shape()
-                  ->mutable_layout()
-                  ->mutable_minor_to_major();
-    v->Clear();
-    for (const llvm::APInt& i : attr) {
-      *v->Add() = i.getZExtValue();
+  // See MlirToHloConversionOptions for more about layouts.
+  auto propagate_layouts = [this](mlir::Operation* inst, xla::XlaOp xla_op) {
+    if (options_.propagate_layouts) {
+      auto* shape = xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
+                        ->mutable_shape();
+      if (shape->tuple_shapes().empty())
+        *shape->mutable_layout() =
+            ExtractLayout(inst, shape->dimensions().size()).ToProto();
     }
   };
 
@@ -1216,12 +1238,14 @@ LogicalResult ConvertToHloModule::Lower(
   }
 
   if (matchPattern(inst, m_Constant(&const_attr))) {
-    auto literal_or = CreateLiteralFromAttr(const_attr);
+    xla::Layout layout;
+    layout = ExtractLayout(inst, const_attr.getType().getRank());
+    auto literal_or = CreateArrayLiteralFromAttr(const_attr, layout);
     if (!literal_or.ok())
       return inst->emitError(literal_or.status().ToString());
     auto constant = xla::ConstantLiteral(builder, literal_or.ValueOrDie());
     value_map[inst->getResult(0)] = constant;
-    propagate_layouts(inst, constant);
     return success();
   }
 
@@ -1674,22 +1698,24 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
 }  // namespace
 
 Status ConvertRegionToComputation(mlir::Region* region,
-                                  xla::XlaComputation* func) {
+                                  xla::XlaComputation* func,
+                                  MlirToHloConversionOptions options) {
   mlir::ModuleOp module;
-  ConvertToHloModule converter(module, true, true, {});
+  ConvertToHloModule converter(module, true, true, {}, options);
   if (failed(converter.LowerRegionAsComputation(region, func)))
     return tensorflow::errors::Internal(
         "failed to convert region to computation");
   return Status::OK();
 }
 
-Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
-                           bool use_tuple_args, bool return_tuple,
-                           const tensorflow::XlaHelpers::ShapeRepresentationFn
-                               shape_representation_fn) {
+Status ConvertMlirHloToHlo(
+    mlir::ModuleOp module, xla::HloProto* hlo_proto, bool use_tuple_args,
+    bool return_tuple,
+    const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    MlirToHloConversionOptions options) {
   mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
   ConvertToHloModule converter(module, use_tuple_args, return_tuple,
-                               shape_representation_fn);
+                               shape_representation_fn, options);
   if (failed(converter.Run())) return diag_handler.ConsumeStatus();
   auto hlo_module = converter.ConsumeMainProto();
   hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
@@ -25,6 +25,18 @@ limitations under the License.
 
 namespace mlir {
 
+struct MlirToHloConversionOptions {
+  // Best-effort propagation of the layouts. These layouts serve as performance
+  // hints to the backend.
+  //
+  // Note that non-array shapes are not carrying layouts, and users have to
+  // figure out the proper layouts of them through context. This is one of the
+  // reasons why the attribute-based solution is temporary.
+  //
+  // TODO(timshen): Investigate the necessity of having layouts in MHLO.
+  bool propagate_layouts = false;
+};
+
 // Converts a MLIR module in HLO dialect into a HloModuleProto. If
 // use_tuple_args is set, then the entry computations's arguments are converted
 // to a tuple and passed as a single parameter.
@@ -32,15 +44,19 @@ namespace mlir {
 // are converted to a tuple even when there is only a single return value.
 // Multiple return values are always converted to a tuple and returned as a
 // single value.
+//
+// TODO(timshen): move other options into `options`.
 Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto,
                            bool use_tuple_args, bool return_tuple,
                            const tensorflow::XlaHelpers::ShapeRepresentationFn
-                               shape_representation_fn = nullptr);
+                               shape_representation_fn = nullptr,
+                           MlirToHloConversionOptions options = {});
 
 // Converts a region to a computation. It returns a standalone module that
 // contains the converted region as the entry computation.
 Status ConvertRegionToComputation(mlir::Region* region,
-                                  ::xla::XlaComputation* func);
+                                  ::xla::XlaComputation* func,
+                                  MlirToHloConversionOptions options = {});
 
 // Creates XlaOp equivalent of a given MLIR operation using the operand info
 // from `value_lowering` map.
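Editor's note: the two hunks above introduce `MlirToHloConversionOptions` and thread it through `ConvertMlirHloToHlo`, replacing the old unconditional layout propagation with an opt-in flag. A minimal caller-side sketch follows; it is not part of the commit, and the wrapper function name and include path are illustrative assumptions.

// Sketch only: opting into best-effort layout propagation during MHLO export.
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"  // assumed path

void ExportWithLayouts(mlir::ModuleOp module, xla::HloProto* proto) {
  mlir::MlirToHloConversionOptions options;
  options.propagate_layouts = true;  // honor any minor_to_major attributes
  auto status = mlir::ConvertMlirHloToHlo(module, proto,
                                          /*use_tuple_args=*/false,
                                          /*return_tuple=*/false,
                                          /*shape_representation_fn=*/nullptr,
                                          options);
  (void)status;  // real code would check status.ok() and report errors
}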
@@ -1102,3 +1102,33 @@ func @main(%arg: tensor<3xui64>) -> tuple<tensor<3xui64>, tensor<2x2xui32>> {
   %0 = "mhlo.rng_bit_generator"(%arg) {rng_algorithm = 2 : i32} : (tensor<3xui64>) -> tuple<tensor<3xui64>, tensor<2x2xui32>>
   return %0 : tuple<tensor<3xui64>, tensor<2x2xui32>>
 }
+
+// -----
+
+// CHECK:  HloModule
+func @main(%arg: tensor<3x4xf32>) -> tensor<3x4xf32> {
+  // CHECK:  %[[ARG0:.*]] = f32[3,4] parameter(0)
+  // CHECK:  ROOT %[[RESULT:.*]] = f32[3,4] cbrt(f32[3,4] %[[ARG0]])
+  %0 = "mhlo.cbrt"(%arg) : (tensor<3x4xf32>) -> tensor<3x4xf32>
+  return %0 : tensor<3x4xf32>
+}
+
+// -----
+
+// CHECK:  HloModule
+func @main(%arg: tensor<3x4xf32>) -> tensor<3x4xf32> {
+  // CHECK:  %[[ARG0:.*]] = f32[3,4] parameter(0)
+  // CHECK:  ROOT %[[RESULT:.*]] = f32[3,4] reduce-precision(f32[3,4] %[[ARG0]]), exponent_bits=8, mantissa_bits=10
+  %0 = "mhlo.reduce_precision"(%arg) {exponent_bits = 8 : i32, mantissa_bits = 10 : i32} : (tensor<3x4xf32>) -> tensor<3x4xf32>
+  return %0 : tensor<3x4xf32>
+}
+
+// -----
+
+// CHECK:  HloModule
+func @main(%arg: tensor<3x4xf32>) -> tensor<3x4x1xf32> {
+  // CHECK:  %[[ARG0:.*]] = f32[3,4] parameter(0)
+  // CHECK:  ROOT %[[RESULT:.*]] = f32[3,4,1] bitcast(f32[3,4] %[[ARG0]])
+  %0 = "mhlo.bitcast"(%arg) : (tensor<3x4xf32>) -> tensor<3x4x1xf32>
+  return %0 : tensor<3x4x1xf32>
+}
@@ -1014,3 +1014,26 @@ add {
   ROOT %rng-bit-generator.2 = (u64[3], u32[2,2]) rng-bit-generator(u64[3] %Arg_0.1), algorithm=rng_philox
 }
 
+// CHECK-LABEL: func @cbrt
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xf32>)
+%cbrt (Arg_0.1: f32[3,4]) -> f32[3,4] {
+  %Arg_0.1 = f32[3,4] parameter(0)
+  // CHECK: "mhlo.cbrt"(%[[ARG0]]) : (tensor<3x4xf32>) -> tensor<3x4xf32>
+  ROOT %cbrt = f32[3,4] cbrt(f32[3,4] %Arg_0.1)
+}
+
+// CHECK-LABEL: func @bitcast
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xf32>) -> tensor<3x4x1xf32>
+%bitcast (Arg_0.1: f32[3,4]) -> f32[3,4,1] {
+  %Arg_0.1 = f32[3,4] parameter(0)
+  // CHECK: "mhlo.bitcast"(%[[ARG0]]) : (tensor<3x4xf32>) -> tensor<3x4x1xf32>
+  ROOT %bitcast = f32[3,4,1] bitcast(f32[3,4] %Arg_0.1)
+}
+
+// CHECK-LABEL: func @reduce_precision
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xf32>)
+%reduce_precision (Arg_0.1: f32[3,4]) -> f32[3,4] {
+  %Arg_0.1 = f32[3,4] parameter(0)
+  // CHECK: "mhlo.reduce_precision"(%[[ARG0]]) {exponent_bits = 8 : i32, mantissa_bits = 10 : i32} : (tensor<3x4xf32>) -> tensor<3x4xf32>
+  ROOT %reduce_precision = f32[3,4] reduce-precision(f32[3,4] %Arg_0.1), exponent_bits=8, mantissa_bits=10
+}
@@ -26,5 +26,9 @@ func @main(%arg0: tensor<128x224x224x4xf16>, %arg1: tensor<64x7x7x4xf16>) -> ten
     rhs_dilations = dense<1> : tensor<2xi64>,
     window_strides = dense<2> : tensor<2xi64>
   } : (tensor<128x224x224x4xf16>, tensor<64x7x7x4xf16>)-> tensor<128x64x112x112xf16> loc("root.42")
+
+  // CHECK: s32[1,1]{0,1} constant({ {42} })
+  %cst_1 = "std.constant"() {value = dense<[[42]]> : tensor<1x1xi32>, minor_to_major = dense<[0, 1]> : tensor<2xindex>} : () -> tensor<1x1xi32>
+
   return %0 : tensor<128x64x112x112xf16>
 }
@@ -129,8 +129,11 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunctionImpl(
   if (!module) return mlir::failure();
 
   HloProto hloProto;
+  mlir::MlirToHloConversionOptions options;
+  options.propagate_layouts = with_layouts;
   Status status = mlir::ConvertMlirHloToHlo(
-      module, &hloProto, emit_use_tuple_arg, emit_return_tuple);
+      module, &hloProto, emit_use_tuple_arg, emit_return_tuple,
+      /*shape_representation_fn=*/nullptr, options);
   if (!status.ok()) {
     LOG(ERROR) << "Module conversion failed: " << status;
     return mlir::failure();
@@ -1,5 +1,5 @@
 load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
-load("//tensorflow:tensorflow.bzl", "if_tpu", "tf_cc_binary", "tf_cc_test", "tf_copts", "tf_cuda_cc_test", "tf_openmp_copts")
+load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_cc_binary", "tf_cc_test", "tf_copts", "tf_cuda_cc_test", "tf_openmp_copts")
 load(
     "//tensorflow/core/platform/default:cuda_build_defs.bzl",
     "if_cuda_is_configured",
@@ -298,7 +298,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/stream_executor:platform",
-    ] + if_tpu(
+    ] + if_libtpu(
         if_false = [
             "//tensorflow/compiler/xla/service:cpu_plugin",
             "//tensorflow/compiler/xla/service/cpu:buffer_info_util",
@@ -369,7 +369,7 @@ cc_library(
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:ops",
        "//tensorflow/core:protos_all_cc",
-    ] + if_tpu(
+    ] + if_libtpu(
        if_false = [
            "//tensorflow/compiler/mlir:array_container_utils",
            "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
@@ -877,13 +877,13 @@ cc_library(
 
 cc_library(
     name = "mlir_bridge_pass_registration",
-    srcs = if_tpu(
+    srcs = if_libtpu(
         if_false = [
             "mlir_bridge_pass_registration.cc",
         ],
         if_true = [],
     ),
-    deps = if_tpu(
+    deps = if_libtpu(
         if_false = [
             ":mlir_bridge_pass",
             "//tensorflow/compiler/mlir:mlir_graph_optimization_pass_registration",
@@ -56,7 +56,7 @@ limitations under the License.
 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
 #include "tensorflow/core/util/dump_graph.h"
 
-#ifndef LIBTFTPU
+#ifndef LIBTPU_ON_GCE
 #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
 #include "tensorflow/compiler/mlir/utils/array_container_utils.h"
 #endif
@@ -733,7 +733,7 @@ Status XlaCompiler::CompileFunction(
   }
 
   VLOG(1) << "====================================================";
-#ifdef LIBTFTPU
+#ifdef LIBTPU_ON_GCE
   if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
     VLOG(1) << "MLIR is not supported in this environment.";
   }
@@ -149,6 +149,16 @@ XlaOp XlaBuilderFriend::BuildFusion(XlaBuilder* builder,
   });
 }
 
+XlaOp XlaBuilderFriend::BuildBitcast(XlaBuilder* builder, XlaOp operand,
+                                     const Shape& shape) {
+  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+    *instr.mutable_shape() = shape.ToProto();
+    return builder->AddInstruction(std::move(instr), HloOpcode::kBitcast,
+                                   {operand});
+  });
+}
+
 HloInstructionProto* XlaBuilderFriend::GetInstruction(XlaOp op) {
   return &op.builder()
               ->instructions_[op.builder()->handle_to_index_[op.handle_]];
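BuildBitcast follows the same pattern as BuildFusion: a friend-only helper that appends a raw HloInstructionProto with an explicit result shape. A rough call-site sketch, assuming an `XlaBuilder* builder` and an `XlaOp operand` are available (the shape value is made up for illustration; the real consumer is the MHLO exporter elsewhere in this commit):

    xla::Shape result_shape = xla::ShapeUtil::MakeShape(xla::F32, {3, 4, 1});
    xla::XlaOp bitcast =
        XlaBuilderFriend::BuildBitcast(builder, operand, result_shape);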
@@ -57,6 +57,9 @@ struct XlaBuilderFriend {
                            absl::string_view fusion_kind,
                            const XlaComputation& fused_computation);
 
+  static XlaOp BuildBitcast(XlaBuilder* builder, XlaOp operand,
+                            const Shape& shape);
+
   static HloInstructionProto* GetInstruction(XlaOp op);
 };
 
@@ -2,7 +2,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
 load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency")
 load(
     "//tensorflow:tensorflow.bzl",
-    "if_tpu",
+    "if_libtpu",
     "tf_cc_binary",
     "tf_cc_test",
 )
@@ -57,7 +57,7 @@ cc_library(
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
         tf_grpc_cc_dependency(),
-    ] + if_tpu(
+    ] + if_libtpu(
         if_false = ["//tensorflow/compiler/xla/service:cpu_plugin"],
         if_true = [],
     ),
@@ -1708,7 +1708,6 @@ cc_library(
     srcs = ["hlo_creation_utils.cc"],
     hdrs = [
         "hlo_creation_utils.h",
-        "//tensorflow/compiler/xla:literal_util",
     ],
     deps = [
         ":hlo",
@@ -217,6 +217,7 @@ cc_library(
         ":backend_configs_cc",
         ":buffer_allocations",
         ":gpu_constants",
+        ":gpu_conv_runner",
         ":gpu_executable",
         ":ir_emission_utils",
         ":nccl_all_reduce_thunk",
@@ -45,10 +45,7 @@ CholeskyThunk::CholeskyThunk(ThunkInfo thunk_info,
       info_buffer_(info_buffer),
       type_(type),
       batch_size_(batch_size),
-      a_batch_stride_(
-          n * n *
-          ShapeUtil::ByteSizeOfPrimitiveType(
-              thunk_info.hlo_instruction->operand(0)->shape().element_type())),
+      a_batch_stride_(n * n * ShapeUtil::ByteSizeOfPrimitiveType(type)),
       n_(n) {}
 
 Status CholeskyThunk::ExecuteOnStream(const ExecuteParams& params) {
@@ -31,7 +31,8 @@ namespace xla {
 namespace gpu {
 
 ConvolutionThunk::ConvolutionThunk(
-    ThunkInfo thunk_info, std::vector<BufferAllocation::Slice> operand_slices,
+    ThunkInfo thunk_info, GpuConvConfig&& config,
+    std::vector<BufferAllocation::Slice> operand_slices,
     BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice,
     BufferAllocation::Slice tuple_result_slice)
     : Thunk(Kind::kConvolution, thunk_info),
@@ -39,9 +40,7 @@ ConvolutionThunk::ConvolutionThunk(
       result_buffer_(result_slice),
       scratch_buffer_(scratch_slice),
       tuple_result_buffer_(tuple_result_slice),
-      config_(GetGpuConvConfig(
-                  Cast<HloCustomCallInstruction>(thunk_info.hlo_instruction))
-                  .ValueOrDie()) {}
+      config_(std::move(config)) {}
 
 Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) {
   const auto& buffer_allocations = *params.buffer_allocations;
@@ -43,7 +43,7 @@ class ConvolutionThunk : public Thunk {
   // write a tuple (result, scratch_memory) into `tuple_result_buffer`.
   //
   // operand_slices should be in the same order as cudnn_call->operands().
-  ConvolutionThunk(ThunkInfo thunk_info,
+  ConvolutionThunk(ThunkInfo thunk_info, GpuConvConfig&& config,
                    std::vector<BufferAllocation::Slice> operand_slices,
                    BufferAllocation::Slice result_slice,
                    BufferAllocation::Slice scratch_slice,
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
 #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"
@@ -238,9 +239,13 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
     auto conv_result_slice = GetAllocationSlice(*custom_call, {0});
     auto scratch_slice = GetAllocationSlice(*custom_call, {1});
 
+    TF_ASSIGN_OR_RETURN(
+        GpuConvConfig config,
+        GetGpuConvConfig(Cast<HloCustomCallInstruction>(custom_call)));
     AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
-        context_->GetThunkInfo(custom_call), std::move(operand_slices),
-        conv_result_slice, scratch_slice, tuple_result_slice));
+        context_->GetThunkInfo(custom_call), std::move(config),
+        std::move(operand_slices), conv_result_slice, scratch_slice,
+        tuple_result_slice));
     return Status::OK();
   }
 
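Taken together, these convolution hunks move the fallible GetGpuConvConfig call out of the ConvolutionThunk constructor (where the old code had to ValueOrDie()) into the emitter, where TF_ASSIGN_OR_RETURN can surface the failure as a Status. A generic sketch of the same pattern, with hypothetical names and types:

    // Compute the fallible configuration before constructing the object, then
    // move it in, so the constructor itself can no longer fail.
    StatusOr<Config> config_or = ComputeConfig(instruction);  // hypothetical
    if (!config_or.ok()) return config_or.status();
    auto thunk =
        absl::make_unique<SomeThunk>(std::move(config_or).ValueOrDie());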
@@ -1524,9 +1524,11 @@ StatusOr<int64> CompressInstruction(MemoryUsageTracker* memory_tracker,
 
   HloInstruction* compressed = computation->AddInstruction(
       HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best));
+  compressed->SetAndSanitizeName(best->name() + ".remat_compressed");
 
   HloInstruction* uncompressed = computation->AddInstruction(
       HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed));
+  uncompressed->SetAndSanitizeName(best->name() + ".remat_uncompressed");
 
   Item* compressed_item = instruction_list->CreateItem(compressed);
   compressed_item->placed = true;
@@ -68,9 +68,9 @@ load(
     "if_chromiumos",
     "if_cuda_or_rocm",
     "if_ios",
+    "if_libtpu",
     "if_mobile",
     "if_not_windows",
-    "if_tpu",
     "tf_android_core_proto_headers",
     "tf_cc_test",
     "tf_cc_test_mkl",
@@ -894,8 +894,7 @@ cc_library(
         "//tensorflow/c/kernels:summary_op_lib",
     ] + if_chromiumos(
         [],
-        # Non-tpu platforms don't need tpu dependency. It would be best to guard
-        # them by if_tpu. But there is no such flag yet.
+        # Non-tpu platforms don't need tpu dependency.
         [
             ":tpu_configuration_ops_op_lib",
             ":tpu_cross_replica_ops_op_lib",
@@ -916,7 +915,7 @@ cc_library(
     ]) + if_tensorrt([
         "//tensorflow/compiler/tf2tensorrt:trt_engine_resource_ops_op_lib",
         "//tensorflow/compiler/tf2tensorrt:trt_op_libs",
-    ]) + if_tpu(
+    ]) + if_libtpu(
         if_false = ["//tensorflow/compiler/mlir/tensorflow:mlir_passthrough_op"],
         if_true = [],
     ),
@@ -1,6 +1,6 @@
 load(
     "//tensorflow:tensorflow.bzl",
-    "if_tpu",
+    "if_libtpu",
     "tf_cc_test",
     "tf_cc_test_mkl",
     "tf_cc_tests",
@@ -93,7 +93,7 @@ cc_library(
     deps = [
         ":core_cpu",
         "//tensorflow/core/common_runtime/gpu:gpu_runtime",
-    ] + if_tpu(["//tensorflow/core/tpu:tpu_runtime"]),
+    ] + if_libtpu(["//tensorflow/core/tpu:tpu_runtime"]),
 )
 
 filegroup(
@@ -151,7 +151,7 @@ void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
 
 RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous(
     const DeviceMgr* device_mgr)
-    : device_mgr_(device_mgr) {}
+    : device_mgr_(device_mgr), local_(this) {}
 
 RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {}
 
@@ -176,7 +176,7 @@ void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) {
 
 PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous(
     const DeviceMgr* device_mgr)
-    : device_mgr_(device_mgr) {}
+    : device_mgr_(device_mgr), local_(nullptr) {}
 
 PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {}
 
@@ -1121,8 +1121,17 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) {
 }
 
 Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
+  string data_format_str;
+  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
+  TensorFormat data_format;
+  if (!FormatFromString(data_format_str, &data_format)) {
+    return errors::InvalidArgument("Invalid data format string: ",
+                                   data_format_str);
+  }
+  const int rank =
+      (data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4;
   ShapeHandle x;
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &x));
 
   bool is_training;
   TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
@@ -1131,14 +1140,8 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
     exponential_avg_factor = 1.0f;  // default value
   }
   int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5;
-  string data_format_str;
-  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
-  TensorFormat data_format;
-  if (!FormatFromString(data_format_str, &data_format)) {
-    return errors::InvalidArgument("Invalid data format string: ",
-                                   data_format_str);
-  }
-  int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
+
+  int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
   DimensionHandle channel_dim = c->Dim(x, channel_dim_index);
 
   // covers scale, offset, and if is_training is false, mean, variance
@@ -1191,13 +1194,6 @@ Status FusedBatchNormExShape(shape_inference::InferenceContext* c) {
 }
 
 Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
-  ShapeHandle y_backprop;
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
-  ShapeHandle x;
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));
-
-  bool is_training;
-  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
   string data_format_str;
   TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
   TensorFormat data_format;
@@ -1205,7 +1201,17 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
     return errors::InvalidArgument("Invalid data format string: ",
                                    data_format_str);
   }
-  int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
+  const int rank =
+      (data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4;
+  ShapeHandle y_backprop;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop));
+  ShapeHandle x;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &x));
+
+  bool is_training;
+  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
+
+  int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
   DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
   TF_RETURN_IF_ERROR(
       c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));
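The shape functions now derive the rank from the data_format attribute instead of hard-coding 4, so the 3-D (rank-5) batch-norm variants share the same inference code. A small worked example, assuming the tensor_format.h helpers behave as their names suggest:

    // For data_format_str == "NDHWC": FormatFromString(...) yields the NHWC
    // enum, rank == 5, and GetTensorFeatureDimIndex(5, data_format) == 4, so
    // the channel dimension of a [N, D, H, W, C] input is dimension 4.
    // For "NCDHW" the channel index resolves to 1 instead.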
@@ -187,6 +187,20 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
   CancellationToken token = CancellationManager::kInvalidToken;
   bool already_cancelled = false;
   if (cm != nullptr) {
+    // Increment the refcount when cancellation manager is present, to make
+    // sure the rendezvous outlives the recv and its cancel callbacks.
+    // This refcount is dropped in exactly one of the following cases:
+    // (1) Recv registers cancellation callback to cm, and then cm is
+    //     cancelled, unref in the cancellation callback;
+    // (2) Recv registers cancellation callback to cm, but cm is already
+    //     cancelled, unref in the already_cancelled check;
+    // (3) Recv is successful, and item done callback finishes deregistering
+    //     the cancellation callback, unref in the item done callback;
+    // (4) Recv is successful, but the item done callback fails to deregister
+    //     the cancellation callback because cm already StartCancel, in this
+    //     case the cancellation callback will be invoked by the cm anyway,
+    //     unref in the cancellation callback.
+    if (rc_owner_) rc_owner_->Ref();
     token = cm->get_cancellation_token();
     already_cancelled = !cm->RegisterCallback(token, [this, token, key_hash] {
       Item* item = nullptr;
@@ -230,10 +244,14 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
             Rendezvous::Args(), item->args, Tensor(), /*is_dead=*/false);
         delete item;
       }
+      // Unref case (1) and (4)
+      if (rc_owner_) rc_owner_->Unref();
     });
   }
   if (already_cancelled) {
     mu_.unlock();
+    // Unref case (2)
+    if (rc_owner_) rc_owner_->Unref();
     done(StatusGroup::MakeDerived(
              errors::Cancelled("RecvAsync is cancelled.")),
          Rendezvous::Args(), recv_args, Tensor(), /*is_dead=*/false);
@@ -250,10 +268,17 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
   // cancellation manager may no longer be live after `done` is called.
   queue->push_back(new Item(
       recv_args,
-      [cm, token, done = std::move(done)](
+      [this, cm, token, done = std::move(done)](
          const Status& s, const Rendezvous::Args& send_args,
          const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
-        cm->TryDeregisterCallback(token);
+        // TryDeregisterCallback returns true when the cancellation callback
+        // is successfully deregistered. If it fails because the CM already
+        // StartAbort, Unref will happen inside the cancellation callback
+        // when called by the CM.
+        if (cm->TryDeregisterCallback(token)) {
+          // Unref case (3)
+          if (this->rc_owner_) this->rc_owner_->Unref();
+        }
         done(s, send_args, recv_args, v, dead);
       },
       token));
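The four unref cases enumerated above all enforce one invariant: whoever might still run a callback holds exactly one reference until that callback can no longer run. A stripped-down, self-contained sketch of the same idea (plain C++, not the TensorFlow classes):

    #include <atomic>

    class RefCounted {
     public:
      void Ref() { count_.fetch_add(1, std::memory_order_relaxed); }
      void Unref() {
        if (count_.fetch_sub(1, std::memory_order_acq_rel) == 1) delete this;
      }

     protected:
      virtual ~RefCounted() = default;

     private:
      std::atomic<int> count_{1};
    };

    // Pattern: take a ref before handing `this` to an asynchronous callback,
    // and drop it on exactly one of the paths that can observe the callback's
    // fate (callback ran, registration failed, or callback was deregistered).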
@@ -35,7 +35,11 @@ namespace tensorflow {
 // is not expected to be needed.
 class LocalRendezvous {
  public:
-  LocalRendezvous() = default;
+  // If the class wrapping LocalRendezvous is refcounted (i.e., extending
+  // Rendezvous), pass in its pointer in constructor so the LocalRendezvous
+  // can make sure it outlives the async recv requests.
+  // Pass in nullptr if the wrapping class is not refcounted.
+  explicit LocalRendezvous(Rendezvous* owner) : rc_owner_(owner) {}
   ~LocalRendezvous();
 
   Status Send(const Rendezvous::ParsedKey& key,
@@ -62,6 +66,9 @@ class LocalRendezvous {
 
   typedef gtl::FlatMap<uint64, ItemQueue> Table;
 
+  // Pointer to the owner class of this LocalRendezvous if it is refcounted.
+  const Rendezvous* rc_owner_;
+
   // TODO(zhifengc): shard table_.
   mutex mu_;
   Table table_ TF_GUARDED_BY(mu_);
@@ -1152,22 +1152,17 @@ TEST(RegisteredKernels, GetRegisteredKernelsForOp) {
   EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU");
 }
 
-#define EXTRACT_KERNEL_NAME_AND_BUILDER_IMPL(kernel_name, kernel_builder, ...) \
-  constexpr char const* kKernelName = kernel_name;                            \
-  auto builder = []() {                                                       \
-    return std::unique_ptr<KernelDef const>(kernel_builder.Build());          \
-  };
-#define EXTRACT_KERNEL_NAME_AND_BUILDER(kernel_builder) \
-  TF_EXTRACT_KERNEL_NAME(EXTRACT_KERNEL_NAME_AND_BUILDER_IMPL, kernel_builder)
+// EXTRACT_KERNEL_NAME_TO_STRING wraps TF_EXTRACT_KERNEL_NAME for testing
+// (it involves quite a bit of macro-magic).
+#define EXTRACT_KERNEL_NAME_TO_STRING_IMPL(name, kernel_builder, ...) name
+#define EXTRACT_KERNEL_NAME_TO_STRING(kernel_builder) \
+  TF_EXTRACT_KERNEL_NAME(EXTRACT_KERNEL_NAME_TO_STRING_IMPL, kernel_builder)
 
 TEST(RegisterKernelMacro, ExtractName) {
-  constexpr char const* kName = "Foo";
-  constexpr char const* kLabel = "Label";
-  EXTRACT_KERNEL_NAME_AND_BUILDER(Name(kName).Label(kLabel));
-  EXPECT_THAT(kKernelName, ::testing::StrEq(kName));
-  std::unique_ptr<KernelDef const> kernel_def = builder();
-  EXPECT_THAT(kernel_def->op(), ::testing::StrEq(kName));
-  EXPECT_THAT(kernel_def->label(), ::testing::StrEq(kLabel));
+  static constexpr char const* kName = "Foo";
+  static constexpr char const* kExtractedName =
+      EXTRACT_KERNEL_NAME_TO_STRING(Name(kName).Label("Label"));
+  EXPECT_THAT(kExtractedName, ::testing::StrEq(kName));
 }
 
 }  // namespace
@@ -151,7 +151,7 @@ Status RendezvousInterface::Recv(const ParsedKey& key, const Args& args,
 namespace {
 class LocalRendezvousWrapper : public Rendezvous {
  public:
-  LocalRendezvousWrapper() = default;
+  LocalRendezvousWrapper() : impl_(this) {}
 
   Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val,
               const bool is_dead) override {
@@ -670,7 +670,25 @@ Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context,
 Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
     TransposeContext* context, utils::MutableNodeView* node) {
   DCHECK(IsDefaultLayoutSensitiveOp(*node->node()));
-  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
+  const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
+  const auto& shape = output_shape_attr->list().shape(0);
+  const int rank = shape.dim_size();
+  std::string src_format = context->src_format;
+  std::string dst_format = context->dst_format;
+  // Update the format from 4D to 5D layout if necessary.
+  bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW");
+  if (allow_5d) {
+    std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
+    std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
+    context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
+                                        dst_format_3d);
+  }
+  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) {
+    // Change back to the original layout due to early exit.
+    if (allow_5d) {
+      context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                          dst_format);
+    }
     return Status::OK();
   }
   VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@@ -679,6 +697,11 @@ Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
   TF_RETURN_IF_ERROR(UpdateNode(context, node));
   TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
   TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
+  // Change back the format from 5D to 4D layout.
+  if (allow_5d) {
+    context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                        dst_format);
+  }
   return context->graph_view->GetMutationBuilder()->Apply();
 }
 
|
|||||||
Status FusedBatchNormGradTransposer::TransposeNode(
|
Status FusedBatchNormGradTransposer::TransposeNode(
|
||||||
TransposeContext* context, utils::MutableNodeView* node) {
|
TransposeContext* context, utils::MutableNodeView* node) {
|
||||||
DCHECK(IsFusedBatchNormGrad(*node->node()));
|
DCHECK(IsFusedBatchNormGrad(*node->node()));
|
||||||
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
|
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
|
||||||
|
const auto& shape = output_shape_attr->list().shape(0);
|
||||||
|
const int rank = shape.dim_size();
|
||||||
|
std::string src_format = context->src_format;
|
||||||
|
std::string dst_format = context->dst_format;
|
||||||
|
// Update the format from 4D to 5D layout if necessary.
|
||||||
|
bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW");
|
||||||
|
if (allow_5d) {
|
||||||
|
std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
|
||||||
|
std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
|
||||||
|
context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
|
||||||
|
dst_format_3d);
|
||||||
|
}
|
||||||
|
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank) ||
|
||||||
!IsTraining(*node)) {
|
!IsTraining(*node)) {
|
||||||
|
// Change back to the original layout due to early exit.
|
||||||
|
if (allow_5d) {
|
||||||
|
context->AssignDeviceAndDataFormats(context->target_device, src_format,
|
||||||
|
dst_format);
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
||||||
@ -892,6 +933,11 @@ Status FusedBatchNormGradTransposer::TransposeNode(
|
|||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose));
|
UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose));
|
||||||
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
|
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
|
||||||
|
// Change back the format from 5D to 4D layout.
|
||||||
|
if (allow_5d) {
|
||||||
|
context->AssignDeviceAndDataFormats(context->target_device, src_format,
|
||||||
|
dst_format);
|
||||||
|
}
|
||||||
return context->graph_view->GetMutationBuilder()->Apply();
|
return context->graph_view->GetMutationBuilder()->Apply();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1438,29 +1438,41 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
   Status status;
 
-  if (fused_node.attr().at(kDataFormat).s() == "NCHW") {
+  string x_format = fused_node.attr().at(kDataFormat).s();
+  if (x_format == "NCHW" or x_format == "NCDHW") {
     // Need to reshape the last 4 inputs
     NodeDef new_shape;
     const string new_shape_name =
-        AddPrefixToNodeName("NCHWShape", fused_node.name());
+        AddPrefixToNodeName(x_format + "Shape", fused_node.name());
     new_shape.set_name(new_shape_name);
     new_shape.set_op("Const");
     new_shape.set_device(fused_node.device());
     *new_shape.add_input() = AsControlDependency(scale);
     (*new_shape.mutable_attr())["dtype"].set_type(DT_INT32);
-    Tensor t(DT_INT32, {4});
-    t.flat<int32>()(0) = 1;
-    t.flat<int32>()(1) = -1;
-    t.flat<int32>()(2) = 1;
-    t.flat<int32>()(3) = 1;
-    t.AsProtoTensorContent(
-        (*new_shape.mutable_attr())["value"].mutable_tensor());
+    if (x_format == "NCHW") {
+      Tensor t(DT_INT32, {4});
+      t.flat<int32>()(0) = 1;
+      t.flat<int32>()(1) = -1;
+      t.flat<int32>()(2) = 1;
+      t.flat<int32>()(3) = 1;
+      t.AsProtoTensorContent(
+          (*new_shape.mutable_attr())["value"].mutable_tensor());
+    } else {
+      Tensor t(DT_INT32, {5});
+      t.flat<int32>()(0) = 1;
+      t.flat<int32>()(1) = -1;
+      t.flat<int32>()(2) = 1;
+      t.flat<int32>()(3) = 1;
+      t.flat<int32>()(4) = 1;
+      t.AsProtoTensorContent(
+          (*new_shape.mutable_attr())["value"].mutable_tensor());
+    }
     mutation->AddNode(std::move(new_shape), &status);
     TF_RETURN_IF_ERROR(status);
 
     NodeDef reshaped_scale;
     reshaped_scale.set_name(
-        AddPrefixToNodeName("NCHWShapedScale", fused_node.name()));
+        AddPrefixToNodeName(x_format + "ShapedScale", fused_node.name()));
     reshaped_scale.set_op("Reshape");
     reshaped_scale.set_device(fused_node.device());
     *reshaped_scale.add_input() = scale;
@@ -1473,7 +1485,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
 
     NodeDef reshaped_offset;
     reshaped_offset.set_name(
-        AddPrefixToNodeName("NCHWShapedOffset", fused_node.name()));
+        AddPrefixToNodeName(x_format + "ShapedOffset", fused_node.name()));
     reshaped_offset.set_op("Reshape");
     reshaped_offset.set_device(fused_node.device());
     *reshaped_offset.add_input() = offset;
@@ -1486,7 +1498,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
 
     NodeDef reshaped_mean;
     reshaped_mean.set_name(
-        AddPrefixToNodeName("NCHWShapedMean", fused_node.name()));
+        AddPrefixToNodeName(x_format + "ShapedMean", fused_node.name()));
     reshaped_mean.set_op("Reshape");
     reshaped_mean.set_device(fused_node.device());
     *reshaped_mean.add_input() = mean;
@@ -1499,7 +1511,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
 
     NodeDef reshaped_variance;
     reshaped_variance.set_name(
-        AddPrefixToNodeName("NCHWShapedVariance", fused_node.name()));
+        AddPrefixToNodeName(x_format + "ShapedVariance", fused_node.name()));
     reshaped_variance.set_op("Reshape");
     reshaped_variance.set_device(fused_node.device());
     *reshaped_variance.add_input() = variance;
@@ -104,6 +104,37 @@ TF_CALL_GPU_ALL_TYPES(REGISTER);
 
 #undef REGISTER
 
+#if defined(_MSC_VER)
+// Required by MSVC non-release build to ensure the compiler sees all the
+// template expansions that are needed.
+#define FORCE_CONCAT(TYPE)                                              \
+  template <>                                                           \
+  void ConcatGPU<TYPE>(                                                 \
+      OpKernelContext * c,                                              \
+      const std::vector<                                                \
+          std::unique_ptr<typename TTypes<TYPE, 2>::ConstMatrix>>&      \
+          inputs_flat,                                                  \
+      Tensor* output, typename TTypes<TYPE, 2>::Tensor* output_flat) {  \
+    LOG(FATAL) << "Should not be called";                               \
+  }
+
+FORCE_CONCAT(tensorflow::Variant)
+FORCE_CONCAT(tensorflow::ResourceHandle)
+FORCE_CONCAT(unsigned short)
+FORCE_CONCAT(signed char)
+FORCE_CONCAT(tensorflow::tstring)
+FORCE_CONCAT(Eigen::QUInt8)
+FORCE_CONCAT(Eigen::QInt8)
+FORCE_CONCAT(Eigen::QUInt16)
+FORCE_CONCAT(Eigen::QInt16)
+FORCE_CONCAT(Eigen::QInt32)
+FORCE_CONCAT(unsigned int)
+FORCE_CONCAT(unsigned __int64)
+
+#undef FORCE_CONCAT
+
+#endif
+
 }  // namespace tensorflow
 
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
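This hunk and the other MSVC-only blocks later in the commit use the same trick: define an explicit specialization whose body simply aborts, so that non-release MSVC builds can resolve template symbols that are never actually called. A minimal, generic illustration of the technique (not TensorFlow code):

    #include <cstdlib>
    #include <iostream>

    template <typename T>
    void ConcatImpl(const T* a, const T* b);  // real definition lives elsewhere

    // Specialization that exists only to satisfy the linker; calling it is a bug.
    template <>
    void ConcatImpl<bool>(const bool*, const bool*) {
      std::cerr << "Should not be called\n";
      std::abort();
    }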
@@ -21,6 +21,13 @@ namespace tensorflow {
 namespace functor {
 DEFINE_UNARY1(conj, complex64);
 DEFINE_UNARY1(conj, complex128);
+
+#if defined(_MSC_VER)
+// Non-release build with MSVC needs these symbols.
+DEFINE_UNARY1(conj, float);
+DEFINE_UNARY1(conj, double);
+#endif
+
 }  // namespace functor
 }  // namespace tensorflow
 
@@ -155,14 +155,17 @@ class WindowDatasetOp::Dataset : public DatasetBase {
       std::vector<std::vector<Tensor>> window_elements;
       Status status = Status::OK();
       {
+        const size_t target_size = TargetBufferSize(window_size, window_stride);
+
         mutex_lock l(mu_);
-        if (!input_impl_ && buffer_.empty()) {
+        if (!input_impl_ &&
+            (buffer_.empty() ||
+             (dataset()->drop_remainder_ && buffer_.size() < target_size))) {
           *end_of_sequence = true;
           return Status::OK();
         }
 
         // Add elements to the buffer.
-        size_t target_size = TargetBufferSize(window_size, window_stride);
         if (input_impl_) {
           *end_of_sequence = false;
           for (size_t i = buffer_.size(); i < target_size && !*end_of_sequence;
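The behavioral change is confined to the end-of-sequence check: once the upstream iterator is exhausted, a partial buffer now also terminates the dataset when drop_remainder is set. Restated in isolation, using the names from the hunk (TargetBufferSize's formula is not shown here):

    // End of sequence iff there is no more input AND either the buffer is
    // empty, or it holds fewer than `target_size` elements and the caller
    // asked to drop the remainder.
    const bool end = !input_impl_ &&
                     (buffer_.empty() ||
                      (dataset()->drop_remainder_ &&
                       buffer_.size() < target_size));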
@@ -71,6 +71,27 @@ TF_CALL_int8(DEFINE_GPU_KERNELS);
 TF_CALL_uint32(DEFINE_GPU_KERNELS);
 #undef DEFINE_GPU_KERNELS
 
+#if defined(_MSC_VER)
+
+template <>
+struct functor::DenseUpdate<GPUDevice, tensorflow::Variant, ASSIGN> {
+  void operator()(const GPUDevice& d,
+                  typename TTypes<tensorflow::Variant>::Flat params,
+                  typename TTypes<tensorflow::Variant>::ConstFlat update) {
+    LOG(FATAL) << "Not handling type tensorflow::Variant";
+  }
+};
+
+// The function is required to force above template specialization. Without it
+// msvc compiler doesn't include the functor in the object file
+void _force_instantiation(
+    const GPUDevice& d, typename TTypes<tensorflow::Variant>::Flat params,
+    typename TTypes<tensorflow::Variant>::ConstFlat update) {
+  functor::DenseUpdate<GPUDevice, tensorflow::Variant, ASSIGN> x;
+  x(d, params, update);
+}
+#endif  // _MSC_VER
+
 }  // end namespace tensorflow
 
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -22,6 +22,10 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/gpu_kernel_helper.h"
 
+#if defined(_MSC_VER)
+#include "tensorflow/core/framework/register_types.h"
+#endif
+
 namespace tensorflow {
 namespace {
 
@@ -251,6 +255,62 @@ template struct functor::DepthToSpaceOpFunctor<GPUDevice, Eigen::half,
 // NCHW_VECT_C with 4 x qint8 can be treated as NCHW int32.
 template struct functor::DepthToSpaceOpFunctor<GPUDevice, int32, FORMAT_NCHW>;
 
+#if defined(_MSC_VER)
+#define FORCE_DEPTH(TYPE, NAME, NUM, DEVICE)                                   \
+  template <>                                                                  \
+  struct functor::DepthToSpaceOpFunctor<DEVICE, TYPE, NUM> {                   \
+    void operator()(const DEVICE& d,                                           \
+                    typename TTypes<TYPE, 4>::ConstTensor input,               \
+                    int block_size, typename TTypes<TYPE, 4>::Tensor output) { \
+      LOG(FATAL) << "Should not be called.";                                   \
+    }                                                                          \
+    void operator()(const DEVICE& d,                                           \
+                    typename TTypes<TYPE, 5>::ConstTensor input,               \
+                    int block_size, typename TTypes<TYPE, 5>::Tensor output) { \
+      LOG(FATAL) << "Should not be called.";                                   \
+    }                                                                          \
+  };                                                                           \
+  void _force_DepthToSpaceOpFunctor##NAME(                                     \
+      const DEVICE& d, typename TTypes<TYPE, 4>::ConstTensor input,            \
+      int block_size, typename TTypes<TYPE, 4>::Tensor output) {               \
+    functor::DepthToSpaceOpFunctor<DEVICE, TYPE, NUM> op;                      \
+    op(d, input, block_size, output);                                          \
+  }                                                                            \
+  void _force_DepthToSpaceOpFunctor##NAME##_2(                                 \
+      const DEVICE& d, typename TTypes<TYPE, 5>::ConstTensor input,            \
+      int block_size, typename TTypes<TYPE, 5>::Tensor output) {               \
+    functor::DepthToSpaceOpFunctor<DEVICE, TYPE, NUM> op;                      \
+    op(d, input, block_size, output);                                          \
+  }
+
+FORCE_DEPTH(__int64, int64, FORMAT_NCHW, Eigen::ThreadPoolDevice)
+FORCE_DEPTH(unsigned __int64, uint64, FORMAT_NCHW, Eigen::ThreadPoolDevice)
+FORCE_DEPTH(unsigned int, uint, FORMAT_NCHW, Eigen::ThreadPoolDevice)
+FORCE_DEPTH(int, int, FORMAT_NCHW, Eigen::ThreadPoolDevice)
+FORCE_DEPTH(unsigned short, ushort, FORMAT_NCHW, Eigen::ThreadPoolDevice)
+FORCE_DEPTH(short, short, FORMAT_NCHW, Eigen::ThreadPoolDevice)
+FORCE_DEPTH(unsigned char, uchar, FORMAT_NCHW, Eigen::ThreadPoolDevice)
+FORCE_DEPTH(signed char, char, FORMAT_NCHW, Eigen::ThreadPoolDevice)
+FORCE_DEPTH(bfloat16, bfloat16, FORMAT_NCHW, Eigen::ThreadPoolDevice)
+FORCE_DEPTH(double, double, FORMAT_NCHW, Eigen::ThreadPoolDevice)
+FORCE_DEPTH(complex64, complex64, FORMAT_NCHW, Eigen::ThreadPoolDevice)
+FORCE_DEPTH(complex128, complex128, FORMAT_NCHW, Eigen::ThreadPoolDevice)
+FORCE_DEPTH(bool, bool, FORMAT_NCHW, Eigen::ThreadPoolDevice)
+FORCE_DEPTH(tensorflow::tstring, tstring, FORMAT_NCHW, Eigen::ThreadPoolDevice)
+FORCE_DEPTH(tensorflow::ResourceHandle, ResourceHandle, FORMAT_NCHW,
+            Eigen::ThreadPoolDevice)
+FORCE_DEPTH(tensorflow::Variant, variant, FORMAT_NCHW, Eigen::ThreadPoolDevice)
+FORCE_DEPTH(Eigen::QInt8, qint8, FORMAT_NCHW, Eigen::ThreadPoolDevice)
+FORCE_DEPTH(Eigen::QInt8, qint8_2, FORMAT_NHWC, Eigen::ThreadPoolDevice)
+FORCE_DEPTH(Eigen::half, half, FORMAT_NCHW, Eigen::ThreadPoolDevice)
+FORCE_DEPTH(float, float, FORMAT_NCHW, Eigen::ThreadPoolDevice)
+FORCE_DEPTH(Eigen::QInt8, qint8, FORMAT_NCHW, GPUDevice)
+FORCE_DEPTH(Eigen::QInt8, qint8_2, FORMAT_NHWC, GPUDevice)
+
+#undef FORCE_DEPTH
+
+#endif
+
 }  // end namespace tensorflow
 
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -131,8 +131,8 @@ SpatialMaxPooling(const Input& input, DenseIndex patchRows,
           .extract_image_patches(
               patchRows, patchCols, strideRows, strideCols, in_strideRows,
               in_strideCols, padding_type,
-              -Eigen::NumTraits<typename internal::remove_const<
-                  typename internal::traits<Input>::Scalar>::type>::highest())
+              Eigen::NumTraits<typename internal::remove_const<
+                  typename internal::traits<Input>::Scalar>::type>::lowest())
           .maximum(reduction_dims)
           .reshape(post_reduce_dims);
 }
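The padding value for max-pooling changes from -highest() to lowest(). For floating-point scalars the two coincide, but for other scalar types they do not, so -highest() was a slightly wrong identity element for the max reduction. A tiny standalone check using std::numeric_limits (Eigen::NumTraits mirrors these values for standard types):

    #include <iostream>
    #include <limits>

    int main() {
      std::cout << -std::numeric_limits<int>::max() << "\n";    // -2147483647
      std::cout << std::numeric_limits<int>::lowest() << "\n";  // -2147483648
      // For float the two are equal:
      std::cout << (-std::numeric_limits<float>::max() ==
                    std::numeric_limits<float>::lowest())
                << "\n";  // 1
    }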
@@ -1241,15 +1241,15 @@ class FusedBatchNormOpBase : public OpKernel {
   // If use_reserved_space is false, we don't have 5th output.
   virtual void ComputeWithReservedSpace(OpKernelContext* context,
                                         bool use_reserved_space) {
-    const Tensor& x = context->input(0);
+    Tensor x = context->input(0);
     const Tensor& scale = context->input(1);
     const Tensor& offset = context->input(2);
     const Tensor& estimated_mean = context->input(3);
     const Tensor& estimated_variance = context->input(4);
     const Tensor* side_input = has_side_input_ ? &context->input(5) : nullptr;
 
-    OP_REQUIRES(context, x.dims() == 4,
-                errors::InvalidArgument("input must be 4-dimensional",
+    OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5,
+                errors::InvalidArgument("input must be 4 or 5-dimensional",
                                         x.shape().DebugString()));
     OP_REQUIRES(context, scale.dims() == 1,
                 errors::InvalidArgument("scale must be 1-dimensional",
@@ -1264,6 +1264,21 @@ class FusedBatchNormOpBase : public OpKernel {
         context, estimated_variance.dims() == 1,
         errors::InvalidArgument("estimated_variance must be 1-dimensional",
                                 estimated_variance.shape().DebugString()));
+    bool use_reshape = (x.dims() == 5);
+    auto x_shape = x.shape();
+    TensorShape dest_shape;
+    if (use_reshape) {
+      const int64 in_batch = GetTensorDim(x, tensor_format_, 'N');
+      int64 in_planes = GetTensorDim(x, tensor_format_, '0');
+      int64 in_rows = GetTensorDim(x, tensor_format_, '1');
+      int64 in_cols = GetTensorDim(x, tensor_format_, '2');
+      const int64 in_depth = GetTensorDim(x, tensor_format_, 'C');
+      dest_shape = ShapeFromFormat(tensor_format_, in_batch,
+                                   {{in_planes, in_rows * in_cols}}, in_depth);
+      OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
+                  errors::InvalidArgument("Error during tensor copy."));
+    }
+
     if (has_side_input_) {
       OP_REQUIRES(context, side_input->shape() == x.shape(),
                   errors::InvalidArgument(
@@ -1282,8 +1297,10 @@ class FusedBatchNormOpBase : public OpKernel {
     }
 
     Tensor* y = nullptr;
+    auto alloc_shape = use_reshape ? dest_shape : x_shape;
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
-                                {0}, 0, x.shape(), &y));
+                                {0}, 0, alloc_shape, &y));
 
     Tensor* batch_mean = nullptr;
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                 {3}, 1, scale.shape(), &batch_mean));
@@ -1310,6 +1327,10 @@ class FusedBatchNormOpBase : public OpKernel {
           batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
           tensor_format_, use_reserved_space);
     }
+    if (use_reshape) {
+      OP_REQUIRES(context, y->CopyFrom(*y, x_shape),
+                  errors::InvalidArgument("Error during tensor copy."));
+    }
   }
 
  private:
|
|||||||
|
|
||||||
virtual void ComputeWithReservedSpace(OpKernelContext* context,
|
virtual void ComputeWithReservedSpace(OpKernelContext* context,
|
||||||
bool use_reserved_space) {
|
bool use_reserved_space) {
|
||||||
const Tensor& y_backprop = context->input(0);
|
Tensor y_backprop = context->input(0);
|
||||||
const Tensor& x = context->input(1);
|
Tensor x = context->input(1);
|
||||||
const Tensor& scale = context->input(2);
|
const Tensor& scale = context->input(2);
|
||||||
// When is_training=True, batch mean and variance/inverted variance are
|
// When is_training=True, batch mean and variance/inverted variance are
|
||||||
// saved in the forward pass to be reused here. When is_training=False,
|
// saved in the forward pass to be reused here. When is_training=False,
|
||||||
@ -1387,11 +1408,11 @@ class FusedBatchNormGradOpBase : public OpKernel {
|
|||||||
// saves inverted variance.
|
// saves inverted variance.
|
||||||
const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);
|
const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);
|
||||||
|
|
||||||
OP_REQUIRES(context, y_backprop.dims() == 4,
|
OP_REQUIRES(context, y_backprop.dims() == 4 or y_backprop.dims() == 5,
|
||||||
errors::InvalidArgument("input must be 4-dimensional",
|
errors::InvalidArgument("input must be 4 or 5-dimensional",
|
||||||
y_backprop.shape().DebugString()));
|
y_backprop.shape().DebugString()));
|
||||||
OP_REQUIRES(context, x.dims() == 4,
|
OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5,
|
||||||
errors::InvalidArgument("input must be 4-dimensional",
|
errors::InvalidArgument("input must be 4 or 5-dimensional",
|
||||||
x.shape().DebugString()));
|
x.shape().DebugString()));
|
||||||
OP_REQUIRES(context, scale.dims() == 1,
|
OP_REQUIRES(context, scale.dims() == 1,
|
||||||
errors::InvalidArgument("scale must be 1-dimensional",
|
errors::InvalidArgument("scale must be 1-dimensional",
|
||||||
@ -1404,10 +1425,27 @@ class FusedBatchNormGradOpBase : public OpKernel {
|
|||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
"saved variance must be 1-dimensional",
|
"saved variance must be 1-dimensional",
|
||||||
saved_maybe_inv_var_or_pop_var.shape().DebugString()));
|
saved_maybe_inv_var_or_pop_var.shape().DebugString()));
|
||||||
|
bool use_reshape = (x.dims() == 5);
|
||||||
|
auto x_shape = x.shape();
|
||||||
|
TensorShape dest_shape;
|
||||||
|
if (use_reshape) {
|
||||||
|
const int64 in_batch = GetTensorDim(x, tensor_format_, 'N');
|
||||||
|
int64 in_planes = GetTensorDim(x, tensor_format_, '0');
|
||||||
|
int64 in_rows = GetTensorDim(x, tensor_format_, '1');
|
||||||
|
int64 in_cols = GetTensorDim(x, tensor_format_, '2');
|
||||||
|
const int64 in_depth = GetTensorDim(x, tensor_format_, 'C');
|
||||||
|
dest_shape = ShapeFromFormat(tensor_format_, in_batch,
|
||||||
|
{{in_planes, in_rows * in_cols}}, in_depth);
|
||||||
|
OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
|
||||||
|
errors::InvalidArgument("Error during tensor copy."));
|
||||||
|
OP_REQUIRES(context, y_backprop.CopyFrom(y_backprop, dest_shape),
|
||||||
|
errors::InvalidArgument("Error during tensor copy."));
|
||||||
|
}
|
||||||
|
|
||||||
Tensor* x_backprop = nullptr;
|
Tensor* x_backprop = nullptr;
|
||||||
|
auto alloc_shape = use_reshape ? dest_shape : x_shape;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
context->allocate_output(0, x.shape(), &x_backprop));
|
context->allocate_output(0, alloc_shape, &x_backprop));
|
||||||
|
|
||||||
const TensorShape& scale_offset_shape = scale.shape();
|
const TensorShape& scale_offset_shape = scale.shape();
|
||||||
Tensor* scale_backprop = nullptr;
|
Tensor* scale_backprop = nullptr;
|
||||||
@ -1441,15 +1479,20 @@ class FusedBatchNormGradOpBase : public OpKernel {
|
|||||||
offset_backprop, use_reserved_space, tensor_format_);
|
offset_backprop, use_reserved_space, tensor_format_);
|
||||||
} else {
|
} else {
|
||||||
// Necessary layout conversion is currently done in python.
|
// Necessary layout conversion is currently done in python.
|
||||||
CHECK(tensor_format_ == FORMAT_NHWC)
|
OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC,
|
||||||
<< "The implementation of FusedBatchNormGrad with is_training=False "
|
errors::InvalidArgument(
|
||||||
"only support "
|
"The implementation of "
|
||||||
<< "NHWC tensor format for now.";
|
"FusedBatchNormGrad with is_training=False only support "
|
||||||
|
"NHWC tensor format for now."));
|
||||||
functor::FusedBatchNormFreezeGrad<Device, T, U>()(
|
functor::FusedBatchNormFreezeGrad<Device, T, U>()(
|
||||||
context, y_backprop, x, scale, saved_mean_or_pop_mean,
|
context, y_backprop, x, scale, saved_mean_or_pop_mean,
|
||||||
saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
|
saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
|
||||||
offset_backprop);
|
offset_backprop);
|
||||||
}
|
}
|
||||||
|
if (use_reshape) {
|
||||||
|
OP_REQUIRES(context, x_backprop->CopyFrom(*x_backprop, x_shape),
|
||||||
|
errors::InvalidArgument("Error during tensor copy."));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@@ -530,6 +530,11 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);

 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DEFINE_GRAD_GPU_SPEC);

+#if defined(_MSC_VER)
+// Required for MSVC debug build
+TF_CALL_half(DEFINE_GRAD_GPU_SPEC)
+#endif
+
 #undef DEFINE_GPU_SPEC
 #undef DEFINE_GRAD_GPU_SPEC

@@ -296,6 +296,9 @@ def _gen_unranked_kernel_fatbin_impl(ctx):
         archs_trimmed.append(arch[3:])
     arch_flag = ",".join(archs_trimmed)

+    # TODO(b/169066682): Generate Fatbin when lowering GPU module.
+    arch_flag = "75"
+
     filename = "%s.a" % (name)
     gpu_bin = ctx.outputs.output
     ctx.actions.run(
@@ -43,7 +43,8 @@ namespace tensorflow {
 // We have to be able to detect and handle overflows in int32, so this function
 // uses doubles and int64's to make sure we have enough room.
 template <class T>
-int64 FloatToQuantizedUnclamped(float input, float range_min, float range_max) {
+inline int64 FloatToQuantizedUnclamped(float input, float range_min,
+                                       float range_max) {
   const int64 lowest_quantized =
       static_cast<double>(Eigen::NumTraits<T>::lowest());
   if (range_min == range_max) {
@@ -60,6 +61,12 @@ int64 FloatToQuantizedUnclamped(float input, float range_min, float range_max) {
   return quantized;
 }

+template <>
+inline int64 FloatToQuantizedUnclamped<float>(float input, float range_min,
+                                              float range_max) {
+  return -1;
+}
+
 // This converts the float into the final quantized type, clamping/saturating
 // any over or underflows.
 template <class T>
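As a side note on the overflow concern mentioned in the comment above: the sketch below is not TensorFlow's implementation — the helper name and the simple affine mapping are illustrative assumptions — but it shows why computing the unclamped quantized value in double/int64 keeps an out-of-range input detectable instead of letting it wrap in 32-bit arithmetic.

    #include <cstdint>
    #include <iostream>

    // Illustrative affine quantization sketch, not the TensorFlow formula:
    // map [range_min, range_max] onto a nominal int8 range, but return the
    // result as int64 so out-of-range inputs stay visible for later clamping.
    int64_t FloatToQuantizedUnclampedSketch(float input, float range_min,
                                            float range_max) {
      const double lowest = -128.0;
      const double highest = 127.0;
      const double scale = (highest - lowest) / (range_max - range_min);
      return static_cast<int64_t>(lowest + (input - range_min) * scale);
    }

    int main() {
      // 10.0 is far outside [-1, 1]; the int64 result (about 1274) is well
      // above 127, so a caller can detect the overflow and clamp it instead
      // of silently wrapping in a 32-bit integer.
      std::cout << FloatToQuantizedUnclampedSketch(10.0f, -1.0f, 1.0f) << "\n";
      return 0;
    }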
@@ -22,6 +22,10 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/gpu_kernel_helper.h"

+#if defined(_MSC_VER)
+#include "tensorflow/core/framework/register_types.h"
+#endif
+
 namespace tensorflow {

 typedef Eigen::GpuDevice GPUDevice;
@@ -252,6 +256,70 @@ template struct functor::SpaceToDepthOpFunctor<GPUDevice, uint8, FORMAT_NHWC>;
 // NCHW_VECT_C with 4 x qint8 can be treated as NCHW int32.
 template struct functor::SpaceToDepthOpFunctor<GPUDevice, int32, FORMAT_NCHW>;

+#if defined(_MSC_VER)
+#define FORCE_DEPTH(TYPE, NAME, NUM, DEVICE)                                  \
+  template <>                                                                 \
+  struct functor::SpaceToDepthOpFunctor<DEVICE, TYPE, NUM> {                  \
+    void operator()(const DEVICE& d,                                          \
+                    typename TTypes<TYPE, 4>::ConstTensor input,              \
+                    int block_size, typename TTypes<TYPE, 4>::Tensor output) { \
+      LOG(FATAL) << "Should not be called.";                                  \
+    }                                                                         \
+  };                                                                          \
+  void _force_SpaceToDepthOpFunctor##NAME(                                    \
+      const DEVICE& d, typename TTypes<TYPE, 4>::ConstTensor input,           \
+      int block_size, typename TTypes<TYPE, 4>::Tensor output) {              \
+    functor::SpaceToDepthOpFunctor<DEVICE, TYPE, NUM> op;                     \
+    op(d, input, block_size, output);                                         \
+  }
+
+#define FORCE_DEPTH2(TYPE, NAME, DEVICE)       \
+  FORCE_DEPTH(TYPE, NAME, FORMAT_NCHW, DEVICE) \
+  FORCE_DEPTH(TYPE, NAME##_2, FORMAT_NHWC, DEVICE)
+
+FORCE_DEPTH2(__int64, int64, Eigen::ThreadPoolDevice)
+FORCE_DEPTH2(unsigned __int64, uint64, Eigen::ThreadPoolDevice)
+FORCE_DEPTH2(unsigned int, uint, Eigen::ThreadPoolDevice)
+FORCE_DEPTH2(unsigned short, ushort, Eigen::ThreadPoolDevice)
+FORCE_DEPTH2(short, short, Eigen::ThreadPoolDevice)
+FORCE_DEPTH2(signed char, char, Eigen::ThreadPoolDevice)
+FORCE_DEPTH2(unsigned char, char, Eigen::ThreadPoolDevice)
+FORCE_DEPTH2(bfloat16, bfloat16, Eigen::ThreadPoolDevice)
+FORCE_DEPTH2(double, double, Eigen::ThreadPoolDevice)
+FORCE_DEPTH2(complex64, complex64, Eigen::ThreadPoolDevice)
+FORCE_DEPTH2(complex128, complex128, Eigen::ThreadPoolDevice)
+FORCE_DEPTH2(bool, bool, Eigen::ThreadPoolDevice)
+FORCE_DEPTH2(tensorflow::tstring, tstring, Eigen::ThreadPoolDevice)
+FORCE_DEPTH2(tensorflow::ResourceHandle, ResourceHandle,
+             Eigen::ThreadPoolDevice)
+FORCE_DEPTH2(tensorflow::Variant, variant, Eigen::ThreadPoolDevice)
+FORCE_DEPTH2(Eigen::QInt8, qint8, Eigen::ThreadPoolDevice)
+FORCE_DEPTH2(Eigen::half, half, Eigen::ThreadPoolDevice)
+FORCE_DEPTH2(float, float, Eigen::ThreadPoolDevice)
+FORCE_DEPTH2(int, int, Eigen::ThreadPoolDevice)
+FORCE_DEPTH2(Eigen::QInt8, qint8gpu, GPUDevice)
+
+// Special case for int, FORMAT_NHWC
+template <>
+struct functor::SpaceToDepthOpFunctor<GPUDevice, int, FORMAT_NHWC> {
+  void operator()(const GPUDevice& d,
+                  typename TTypes<int, 4>::ConstTensor input, int block_size,
+                  typename TTypes<int, 4>::Tensor output) {
+    LOG(FATAL) << "Should not be called.";
+  }
+};
+void _force_SpaceToDepthOpFunctor_int(
+    const GPUDevice& d, typename TTypes<int, 4>::ConstTensor input,
+    int block_size, typename TTypes<int, 4>::Tensor output) {
+  functor::SpaceToDepthOpFunctor<GPUDevice, int, FORMAT_NHWC> op;
+  op(d, input, block_size, output);
+}
+
+#undef FORCE_DEPTH
+#undef FORCE_DEPTH2
+
+#endif
+
 }  // end namespace tensorflow

 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -58,3 +58,77 @@ op {
     has_minimum: true
   }
 }
+op {
+  name: "SnapshotDatasetV2"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "path"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "reader_func_other_args"
+    type_list_attr: "Treader_func_args"
+  }
+  input_arg {
+    name: "shard_func_other_args"
+    type_list_attr: "Tshard_func_args"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "compression"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "reader_prefix"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "writer_prefix"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "reader_func"
+    type: "func"
+  }
+  attr {
+    name: "shard_func"
+    type: "func"
+  }
+  attr {
+    name: "Treader_func_args"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Tshard_func_args"
+    type: "list(type)"
+    has_minimum: true
+  }
+}
@@ -221,7 +221,7 @@ REGISTER_OP("FusedBatchNormV3")
     .Attr("U: {float}")
     .Attr("epsilon: float = 0.0001")
     .Attr("exponential_avg_factor: float = 1.0")
-    .Attr(GetConvnetDataFormatAttrString())
+    .Attr(GetConvnetDataFormat2D3DAttrString())
     .Attr("is_training: bool = true")
     .SetShapeFn(shape_inference::FusedBatchNormV3Shape);

@@ -308,7 +308,7 @@ REGISTER_OP("FusedBatchNormGradV3")
     .Attr("T: {half, bfloat16, float}")
     .Attr("U: {float}")
     .Attr("epsilon: float = 0.0001")
-    .Attr(GetConvnetDataFormatAttrString())
+    .Attr(GetConvnetDataFormat2D3DAttrString())
     .Attr("is_training: bool = true")
     .SetShapeFn(shape_inference::FusedBatchNormGradShape);
 // --------------------------------------------------------------------------
@@ -44435,6 +44435,20 @@ op {
       s: ""
     }
   }
+  attr {
+    name: "reader_prefix"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "writer_prefix"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
   attr {
     name: "reader_func"
     type: "func"
@@ -1,7 +1,7 @@
 # Platform-specific build configurations.

 load("@com_google_protobuf//:protobuf.bzl", "proto_gen")
-load("//tensorflow:tensorflow.bzl", "clean_dep", "if_not_windows", "if_tpu")
+load("//tensorflow:tensorflow.bzl", "clean_dep", "if_libtpu", "if_not_windows")
 load("//tensorflow/core/platform:build_config_root.bzl", "if_static")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
@@ -814,4 +814,4 @@ def if_llvm_system_z_available(then, otherwise = []):
     })

 def tf_tpu_dependencies():
-    return if_tpu(["//tensorflow/core/tpu/kernels"])
+    return if_libtpu(["//tensorflow/core/tpu/kernels"])
@@ -406,7 +406,6 @@ tf_cc_test(
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "@com_google_absl//absl/strings",
-        "@com_google_googletest//:gtest",
     ],
 )

@@ -62,7 +62,6 @@ tf_cc_test(
         "//tensorflow/core/profiler/utils:xplane_visitor",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
-        "@com_google_googletest//:gtest",
     ],
 )

@@ -16,7 +16,6 @@ limitations under the License.
 #include <ostream>
 #include <string>

-#include <gmock/gmock.h>
 #include "absl/strings/string_view.h"
 #include "absl/types/optional.h"
 #include "tensorflow/core/framework/step_stats.pb.h"
@@ -21,7 +21,6 @@ limitations under the License.
 #include <utility>
 #include <vector>

-#include <gmock/gmock.h>
 #include "absl/strings/str_cat.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/env_time.h"
@@ -108,7 +108,7 @@ limitations under the License.

 #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 542  // Updated: 2020/10/2
+#define TF_GRAPH_DEF_VERSION 543  // Updated: 2020/10/3

 // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
 //
@@ -5,13 +5,11 @@ load(
     "//tensorflow/core/platform:build_config.bzl",
     "tf_proto_library",
 )
+load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts")
 load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency")  # buildifier: disable=same-origin-load
 load("//tensorflow:tensorflow.bzl", "tf_kernel_library")  # buildifier: disable=same-origin-load

 # Config setting to enable go/libtpu support.
-WITH_TPU_SUPPORT = "//tensorflow:with_tpu_support"
-
-DEFAULT = "//conditions:default"

 package(
     default_visibility = [
@@ -44,10 +42,10 @@ cc_library(
     name = "tpu_compile_op_common",
     srcs = ["tpu_compile_op_common.cc"],
     hdrs = ["tpu_compile_op_common.h"],
-    deps = select({
-        WITH_TPU_SUPPORT: [":tpu_compilation_metrics"],
-        DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_metrics"],
-    }) + [
+    deps = if_libtpu(
+        [":tpu_compilation_metrics"],
+        ["//tensorflow/core/tpu/kernels:tpu_compilation_metrics"],
+    ) + [
         ":tpu_compilation_cache_entry_unloader",
         ":tpu_compilation_cache_interface",
         ":tpu_compilation_metrics_hdrs",
@@ -97,14 +95,10 @@ tf_kernel_library(
     name = "tpu_configuration_ops",
     srcs = ["tpu_configuration_ops.cc"],
     hdrs = ["tpu_configuration_ops.h"],
-    copts = select({
-        WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
-        DEFAULT: [],
-    }),
-    deps = select({
-        WITH_TPU_SUPPORT: [":tpu_util"],
-        DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_util"],
-    }) + [
+    deps = if_libtpu(
+        [":tpu_util"],
+        ["//tensorflow/core/tpu/kernels:tpu_util"],
+    ) + [
         ":tpu_compilation_cache_factory",
         ":tpu_compilation_cache_interface",
         ":tpu_compilation_cache_local_lookup",
@@ -346,10 +340,10 @@ cc_library(
     name = "tpu_compilation_cache_interface",
     srcs = ["tpu_compilation_cache_interface.cc"],
    hdrs = ["tpu_compilation_cache_interface.h"],
-    deps = select({
-        WITH_TPU_SUPPORT: [":tpu_compilation_metrics"],
-        DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_metrics"],
-    }) + [
+    deps = if_libtpu(
+        [":tpu_compilation_metrics"],
+        ["//tensorflow/core/tpu/kernels:tpu_compilation_metrics"],
+    ) + [
         ":compiled_subgraph",
         ":tpu_compilation_cache_common_proto_cc",
         ":tpu_compilation_cache_entry",
@@ -424,10 +418,7 @@ cc_library(
 cc_library(
     name = "tpu_compilation_metrics",
     srcs = ["tpu_compilation_metrics.cc"],
-    copts = select({
-        WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
-        DEFAULT: [],
-    }),
+    copts = tf_copts(),
     deps = [
         ":tpu_compilation_metrics_hdrs",
     ],
@@ -529,14 +520,11 @@ cc_library(
 cc_library(
     name = "tpu_compilation_cache_rpc_support_hdrs",
     hdrs = ["tpu_compilation_cache_rpc_support.h"],
-    copts = select({
-        WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
-        DEFAULT: [],
-    }),
-    deps = select({
-        WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"],  # build_cleaner: keep
-        DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto"],  # build_cleaner: keep
-    }) + [
+    copts = tf_copts(),
+    deps = if_libtpu(
+        [":tpu_compilation_cache_proto_cc"],
+        ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto"],
+    ) + [
         ":tpu_compilation_cache_entry",
         ":tpu_compilation_cache_interface",
         ":tpu_compilation_cache_lookup",
@@ -550,10 +538,7 @@ cc_library(
 cc_library(
     name = "tpu_compilation_cache_rpc_support",
     srcs = ["tpu_compilation_cache_rpc_support.cc"],
-    copts = select({
-        WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
-        DEFAULT: [],
-    }),
+    copts = tf_copts(),
     deps = [
         ":tpu_compilation_cache_common_proto_cc",
         ":tpu_compilation_cache_proto_cc",
@@ -572,14 +557,11 @@ cc_library(
     name = "tpu_compilation_cache_rpc_lookup",
     srcs = ["tpu_compilation_cache_rpc_lookup.cc"],
     hdrs = ["tpu_compilation_cache_rpc_lookup.h"],
-    copts = select({
-        WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
-        DEFAULT: [],
-    }),
-    deps = select({
-        WITH_TPU_SUPPORT: [":tpu_compilation_cache_rpc_support"],
-        DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support"],
-    }) + [
+    copts = tf_copts(),
+    deps = if_libtpu(
+        [":tpu_compilation_cache_rpc_support"],
+        ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support"],
+    ) + [
         ":tpu_compilation_cache_grpc",
         ":tpu_compilation_cache_interface",
         ":tpu_compilation_cache_lookup",
@@ -617,14 +599,11 @@ cc_library(
     name = "tpu_compilation_cache_grpc",
     srcs = ["tpu_compilation_cache_grpc.cc"],
     hdrs = ["tpu_compilation_cache_grpc.h"],
-    copts = select({
-        WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
-        DEFAULT: [],
-    }),
-    deps = select({
-        WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"],
-        DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto"],
-    }) + [
+    copts = tf_copts(),
+    deps = if_libtpu(
+        [":tpu_compilation_cache_proto_cc"],
+        ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto"],
+    ) + [
         ":tpu_compilation_cache_common_proto_cc",
         tf_grpc_cc_dependency(),
     ],
@@ -634,20 +613,17 @@ cc_library(
     name = "tpu_compilation_cache_service",
     srcs = ["tpu_compilation_cache_service.cc"],
     hdrs = ["tpu_compilation_cache_service.h"],
-    copts = select({
-        WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
-        DEFAULT: [],
-    }),
-    deps = select({
-        WITH_TPU_SUPPORT: [
-            ":tpu_compilation_cache_rpc_support",  # build_cleaner: keep
-            ":tpu_compilation_cache_proto_cc",  # build_cleaner: keep
-        ],
-        DEFAULT: [
-            "//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support",  # build_cleaner: keep
-            "//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto",  # build_cleaner: keep
-        ],
-    }) + [
+    copts = tf_copts(),
+    deps = if_libtpu(
+        [
+            ":tpu_compilation_cache_rpc_support",
+            ":tpu_compilation_cache_proto_cc",
+        ],
+        [
+            "//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support",
+            "//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto",
+        ],
+    ) + [
         ":tpu_compilation_cache_common_proto_cc",
         ":tpu_compilation_cache_grpc",
         ":tpu_compilation_cache_interface",
@@ -704,10 +680,7 @@ cc_library(
     name = "tpu_compile_op_impl",
     srcs = ["tpu_compile_op_impl.cc"],
     hdrs = ["tpu_compile_op_impl.h"],
-    copts = select({
-        WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
-        DEFAULT: [],
-    }),
+    copts = tf_copts(),
     deps = [
         ":tpu_compilation_cache_key",
         ":tpu_compile_c_api_hdrs",
@@ -952,14 +925,11 @@ cc_library(
     name = "tpu_pod_state",
     srcs = ["tpu_pod_state.cc"],
     hdrs = ["tpu_pod_state.h"],
-    copts = select({
-        WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
-        DEFAULT: [],
-    }),
-    deps = select({
-        WITH_TPU_SUPPORT: [":tpu_util"],
-        DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_util"],
-    }) + [
+    copts = tf_copts(),
+    deps = if_libtpu(
+        [":tpu_util"],
+        ["//tensorflow/core/tpu/kernels:tpu_util"],
+    ) + [
         ":tpu_compilation_cache_service",
         "//tensorflow/c:tf_status",
         "//tensorflow/c:tf_status_helper",
@@ -30,11 +30,11 @@ namespace tensorflow {
 namespace tpu {

 static const char* grpcTpuCompilationCacheService_method_names[] = {
-#if defined(LIBTFTPU)
+#if defined(LIBTPU_ON_GCE)
     "/tensorflow.tpu.TpuCompilationCacheServiceExternal/GetTpuProgram",
-#else   // LIBTFTPU
+#else   // LIBTPU_ON_GCE
     "/tensorflow.tpu.TpuCompilationCacheService/GetTpuProgram",
-#endif  // LIBTFTPU
+#endif  // LIBTPU_ON_GCE
 };

 std::unique_ptr<grpc::TpuCompilationCacheService::Stub>
@@ -35,7 +35,7 @@ limitations under the License.

 #include <functional>

-#if defined(LIBTFTPU)
+#if defined(LIBTPU_ON_GCE)
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
 #else
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"  // copybara"
@@ -48,7 +48,7 @@ namespace grpc {
 class TpuCompilationCacheService final {
  public:
   using RequestType = ::tensorflow::tpu::GetTpuProgramRequest;
-#if defined(LIBTFTPU)
+#if defined(LIBTPU_ON_GCE)
   using ResponseType = ::tensorflow::tpu::GetTpuProgramResponseExternal;
 #else
   using ResponseType = ::tensorflow::tpu::GetTpuProgramResponse;
@@ -59,7 +59,7 @@ class TpuCompilationCacheService final {
   enum class MethodId { kGetTpuProgram = 0 };

   static constexpr char const* service_full_name() {
-#if defined(LIBTFTPU)
+#if defined(LIBTPU_ON_GCE)
     return "tensorflow.tpu.TpuCompilationCacheServiceExternal";
 #else
     return "tensorflow.tpu.TpuCompilationCacheService";
@@ -25,7 +25,7 @@ namespace tensorflow {
 namespace tpu {
 namespace {

-#if defined(LIBTFTPU)
+#if defined(LIBTPU_ON_GCE)
 using ResponseType = GetTpuProgramResponseExternal;
 #else
 using ResponseType = GetTpuProgramResponse;
@@ -17,7 +17,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
 #include "tensorflow/core/platform/casts.h"
-#if defined(LIBTFTPU)
+#if defined(LIBTPU_ON_GCE)
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
 #endif
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
@@ -30,7 +30,7 @@ std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials() {
   return ::grpc::InsecureChannelCredentials();  // NOLINT
 }

-#if defined(LIBTFTPU)
+#if defined(LIBTPU_ON_GCE)
 template <>
 Status DeserializeRpcResponseToCacheEntry<GetTpuProgramResponseExternal>(
     absl::string_view local_proto_key, GetTpuProgramResponseExternal* response,
@@ -156,6 +156,6 @@ xla::StatusOr<std::vector<::grpc::Slice>> SerializeCacheEntryToBufferSlices(

   return std::vector<::grpc::Slice>{::grpc::Slice(encoded_header)};
 }
-#endif  // LIBTFTPU
+#endif  // LIBTPU_ON_GCE
 }  // namespace tpu
 }  // namespace tensorflow
@@ -19,7 +19,7 @@ namespace tpu {

 // TODO(henrytan): remove this once `TpuCompilationCache` migration to OSS is
 // completed.
-#if defined(LIBTFTPU)
+#if defined(LIBTPU_ON_GCE)
 /* static */
 void TpuCompilationMetrics::IncrementCacheLookupCount(
     bool is_cache_hit, absl::string_view session_name) {
@@ -36,7 +36,7 @@ void TpuCompilationMetrics::IncrementCompilationCount(
     absl::string_view session_name) {
   // A placeholder for tracking metrics.
 }
-#endif  // LIBTFTPU
+#endif  // LIBTPU_ON_GCE

 }  // namespace tpu
 }  // namespace tensorflow
@@ -68,11 +68,11 @@ class TpuCompileOpImplFactory : public CompileOpImplFactory {
   }
 };

-#if defined(LIBTFTPU)
+#if defined(LIBTPU_ON_GCE)
 REGISTER_MODULE_INITIALIZER(tpu_compile_op_impl_factory, {
   VLOG(1) << "register TpuCompileOpImplFactory()";
   CompileOpImplFactory::Register(new TpuCompileOpImplFactory());
 });
-#endif  // LIBTFTPU
+#endif  // LIBTPU_ON_GCE
 }  // namespace tpu
 }  // namespace tensorflow
@@ -18,7 +18,7 @@ limitations under the License.
 #include "tensorflow/c/tf_status_helper.h"
 #include "tensorflow/core/tpu/tpu_api.h"

-#if defined(LIBTFTPU)
+#if defined(LIBTPU_ON_GCE)
 #include "tensorflow/core/tpu/kernels/tpu_util.h"
 #else
 #include "tensorflow/core/tpu/kernels/tpu_util.h"  // copybara"
@@ -54,7 +54,7 @@ xla::StatusOr<std::unique_ptr<TpuCompilationCacheService>>
 ConstructCacheService(ResourceMgr* rmgr, int serving_port,
                       tpu::TpuCompilationCacheInterface* compilation_cache) {
   xla::StatusOr<std::unique_ptr<::grpc::ServerBuilder>> server_builder;
-#if defined(LIBTFTPU)
+#if defined(LIBTPU_ON_GCE)
   server_builder = tpu::CreateServerBuilder(serving_port);
 #else
   server_builder = tpu::CreateServerBuilderGoogle(serving_port);
@@ -286,10 +286,8 @@ cc_library(
         ":cl_command_queue",
         ":cl_context",
         ":cl_device",
-        ":cl_kernel",
         ":precision",
         ":program_cache",
-        ":tensor",
         ":tensor_type",
         ":util",
         "//tensorflow/lite/delegates/gpu/common:data_type",
@@ -18,7 +18,6 @@ limitations under the License.
 #include <string>
 #include <vector>

-#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
 #include "tensorflow/lite/delegates/gpu/cl/util.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"

@@ -26,59 +25,6 @@ namespace tflite {
 namespace gpu {
 namespace cl {
 namespace {
-
-std::string GetKernelOneLayerTextureArray() {
-  return R"(
-
-__kernel void main_function(__write_only image2d_array_t dst) {
-  int X = (int)(get_global_id(0));
-  int Y = (int)(get_global_id(1));
-
-  write_imagef(dst, (int4)(X, Y, 0, 0), (float4)(2.0, 2.0, 2.0, 2.0));
-}
-)";
-}
-
-// Some Adreno < 600 have bug with one layer texture array. b/131099086
-// If we have one layer texture array and will write smt from kernel to this
-// texture, we will get zeroes instead of actual values.
-// The same kernel will work, if we use texture array with more than one layer.
-// With help of this code we can detect this bug.
-absl::Status CheckKernelSupportOfOneLayerTextureArray(Environment* env,
-                                                      bool* result) {
-  // No bug on Adreno 6xx
-  if (env->device().info_.adreno_info.gpu_version >= 600) {
-    *result = true;
-    return absl::OkStatus();
-  }
-  CLKernel kernel;
-  RETURN_IF_ERROR(env->program_cache()->GetOrCreateCLKernel(
-      GetKernelOneLayerTextureArray(), "main_function", env->context(),
-      env->device(), &kernel));
-
-  Tensor tensor;
-  const BHWC shape(1, 4, 4, 4);
-  RETURN_IF_ERROR(CreateTensor(
-      env->context(), shape,
-      {DataType::FLOAT32, TensorStorageType::TEXTURE_ARRAY, Layout::HWC},
-      &tensor));
-  RETURN_IF_ERROR(kernel.SetMemory(0, tensor.GetMemoryPtr()));
-  RETURN_IF_ERROR(env->queue()->DispatchImplicit(kernel, {4, 4, 1}, {4, 4, 1}));
-  TensorFloat32 tensor_gpu;
-  tensor_gpu.shape = shape;
-  tensor_gpu.data.resize(shape.DimensionsProduct());
-  RETURN_IF_ERROR(tensor.ReadData(env->queue(), &tensor_gpu));
-
-  *result = true;
-  for (int i = 0; i < 64; ++i) {
-    if (tensor_gpu.data[i] != 2.0) {
-      *result = false;
-      break;
-    }
-  }
-  return absl::OkStatus();
-}
-
 absl::Status CreateEnvironment(Environment* result, bool shared,
                                cl_context_properties egl_context,
                                cl_context_properties egl_display) {
@@ -99,16 +45,7 @@ absl::Status CreateEnvironment(Environment* result, bool shared,
   *result = Environment(std::move(gpu), std::move(context), std::move(queue),
                         std::move(profiling_queue));

-  if (result->device().IsAdreno() && result->device().SupportsTextureArray()) {
-    bool supports_one_layer;
-    RETURN_IF_ERROR(
-        CheckKernelSupportOfOneLayerTextureArray(result, &supports_one_layer));
-    if (!supports_one_layer) {
-      result->GetDevicePtr()->DisableOneLayerTextureArray();
-    }
-  }
-
-  return absl::OkStatus();
+  return result->Init();
 }

 }  // namespace
@@ -141,10 +78,12 @@ Environment& Environment::operator=(Environment&& environment) {

 absl::Status Environment::Init() {
   if (device().IsAdreno() && device().SupportsTextureArray()) {
-    bool supports_one_layer;
-    RETURN_IF_ERROR(
-        CheckKernelSupportOfOneLayerTextureArray(this, &supports_one_layer));
-    if (!supports_one_layer) {
+    // Some Adreno < 600 have bug with one layer texture array. b/131099086
+    // If we have one layer texture array and will write smt from kernel to this
+    // texture, we will get zeroes instead of actual values.
+    // The same kernel will work, if we use texture array with more than one
+    // layer.
+    if (device().info_.adreno_info.gpu_version < 600) {
       GetDevicePtr()->DisableOneLayerTextureArray();
     }
   }
@@ -21,7 +21,6 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
 #include "tensorflow/lite/delegates/gpu/cl/precision.h"
 #include "tensorflow/lite/delegates/gpu/cl/program_cache.h"
-#include "tensorflow/lite/delegates/gpu/cl/tensor.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"

tensorflow/lite/micro/cortex_m_gcc_generic/README.md (new file, 12 lines)
@@ -0,0 +1,12 @@
+# Generic Cortex-Mx customizations
+
+The customization requires a definition where the debug log goes to. The purpose
+of the generic Cortex-Mx target is to generate a TFLM library file for use in
+application projects outside of this repo. As the chip HAL and the board
+specific layer are only defined in the application project, the TFLM library
+cannot write the debug log anywhere. Instead, we allow the application layer to
+register a callback function for writing the TFLM kernel debug log.
+
+# Usage
+
+See debug_log_callback.h
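To make the registration flow described above concrete, here is a minimal sketch of what an application might do; it mirrors the usage comment in debug_log_callback.h added later in this change, the function name DebugLogPrintf is arbitrary, and the printf sink is only an illustrative choice (a real board would typically forward to a UART or similar):

    #include <cstdio>

    #include "tensorflow/lite/micro/cortex_m_gcc_generic/debug_log_callback.h"
    #include "tensorflow/lite/micro/debug_log.h"

    // Application-provided sink for TFLM debug output.
    extern "C" void DebugLogPrintf(const char* s) { std::printf("%s", s); }

    int main() {
      // Must run before the first DebugLog() call (i.e. before invoking the
      // interpreter) so kernel debug output has somewhere to go.
      RegisterDebugLogCallback(DebugLogPrintf);
      DebugLog("TFLM debug logging is now routed to the application.\n");
      return 0;
    }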
@@ -13,14 +13,31 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+// Implementation for the DebugLog() function that prints to the debug logger on
+// an generic cortex-m device.
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
 #include "tensorflow/lite/micro/debug_log.h"

-#ifndef TF_LITE_STRIP_ERROR_STRINGS
-#include <cstdio>
-#endif
+#include "tensorflow/lite/micro/cortex_m_gcc_generic/debug_log_callback.h"

-extern "C" void DebugLog(const char* s) {
+static DebugLogCallback debug_log_callback = nullptr;
+
+void RegisterDebugLogCallback(void (*cb)(const char* s)) {
+  debug_log_callback = cb;
+}
+
+void DebugLog(const char* s) {
 #ifndef TF_LITE_STRIP_ERROR_STRINGS
-  fprintf(stderr, "%s", s);
+  if (debug_log_callback != nullptr) {
+    debug_log_callback(s);
+  }
 #endif
 }
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_MICRO_CORTEX_M_GCC_GENERIC_DEBUG_LOG_CALLBACK_H_
+#define TENSORFLOW_LITE_MICRO_CORTEX_M_GCC_GENERIC_DEBUG_LOG_CALLBACK_H_
+
+// The application layer must implement and register a callback before calling
+// the network in a way similar to
+//
+// void debug_log_printf(const char* s)
+// {
+//   printf(s);
+// }
+//
+// int main(void)
+// {
+//   // Register callback for printing debug log
+//   RegisterDebugLogCallback(debug_log_printf);
+//
+//   // now call the network
+//   TfLiteStatus invoke_status = interpreter->Invoke();
+// }
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef void (*DebugLogCallback)(const char* s);
+
+// Registers and application-specific callback for debug logging. It must be
+// called before the first call to DebugLog().
+void RegisterDebugLogCallback(DebugLogCallback callback);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // TENSORFLOW_LITE_MICRO_CORTEX_M_GCC_GENERIC_DEBUG_LOG_CALLBACK_H_
@@ -15,9 +15,17 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_MICRO_DEBUG_LOG_H_
 #define TENSORFLOW_LITE_MICRO_DEBUG_LOG_H_

+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
 // This function should be implemented by each target platform, and provide a
 // way for strings to be output to some text stream. For more information, see
 // tensorflow/lite/micro/debug_log.cc.
-extern "C" void DebugLog(const char* s);
+void DebugLog(const char* s);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus

 #endif  // TENSORFLOW_LITE_MICRO_DEBUG_LOG_H_
@@ -52,4 +52,7 @@ tensorflow/lite/micro/tools/ci_build/test_stm32f4.sh PRESUBMIT
 echo "Running Arduino tests at `date`"
 tensorflow/lite/micro/tools/ci_build/test_arduino.sh

+echo "Running cortex_m_gcc_generic tests at `date`"
+tensorflow/lite/micro/tools/ci_build/test_cortex_m_gcc_generic.sh
+
 echo "Finished all micro tests at `date`"
tensorflow/lite/micro/tools/ci_build/test_cortex_m_gcc_generic.sh (new executable file, 46 lines)
@@ -0,0 +1,46 @@
+#!/usr/bin/env bash
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Tests the microcontroller code using a Cortex-M4/M4F platform.
+
+set -e
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+ROOT_DIR=${SCRIPT_DIR}/../../../../..
+cd "${ROOT_DIR}"
+
+source tensorflow/lite/micro/tools/ci_build/helper_functions.sh
+
+TARGET=cortex_m_gcc_generic
+
+# TODO(b/143715361): downloading first to allow for parallel builds.
+readable_run make -f tensorflow/lite/micro/tools/make/Makefile TAGS=cmsis-nn TARGET=${TARGET} CORTEX_M_CORE=M4F third_party_downloads
+
+# Build for Cortex-M4 (no FPU) without CMSIS
+readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean
+readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} CORTEX_M_CORE=M4 microlite
+
+# Build for Cortex-M4F (FPU present) without CMSIS
+readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean
+readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} CORTEX_M_CORE=M4F microlite
+
+# Build for Cortex-M4 (no FPU) with CMSIS
+readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean
+readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TAGS=cmsis-nn TARGET=${TARGET} CORTEX_M_CORE=M4 microlite
+
+# Build for Cortex-M4 (FPU present) with CMSIS
+readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean
+readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TAGS=cmsis-nn TARGET=${TARGET} CORTEX_M_CORE=M4F microlite
@@ -118,4 +118,11 @@ ifneq ($(filter cmsis-nn,$(ALL_TAGS)),)
     $(CMSIS_PATH)CMSIS/DSP/Include/dsp/matrix_functions.h

+
+  # Need to add the CMSIS Core includes path.
+  # All other CMSIS header files are included with their relative path
+  # in the CMSIS-NN micro kernel source files in
+  # tensorflow/lite/micro/kernels/cmsis-nn
+  INCLUDES += \
+    -I$(CMSIS_PATH)/CMSIS/Core/Include

 endif
|
|||||||
# Generic Makefile target for ARM Cortex M4 builds.
|
|
||||||
# REQUIRED:
|
|
||||||
# - TOOLCHAIN_PATH: The path to the ARM GCC toolchain to use.
|
|
||||||
|
|
||||||
ifeq ($(TARGET), cortex_m4_generic)
|
|
||||||
TARGET_ARCH := arm
|
|
||||||
TARGET_TOOLCHAIN_PREFIX := arm-none-eabi-
|
|
||||||
export PATH := $(TOOLCHAIN_PATH):$(PATH)
|
|
||||||
|
|
||||||
PLATFORM_FLAGS = \
|
|
||||||
-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \
|
|
||||||
-DTF_LITE_STATIC_MEMORY \
|
|
||||||
-DNDEBUG \
|
|
||||||
-DTF_LITE_MCU_DEBUG_LOG \
|
|
||||||
-D __FPU_PRESENT=1 \
|
|
||||||
-DARM_MATH_CM4 \
|
|
||||||
-fno-rtti \
|
|
||||||
-fmessage-length=0 \
|
|
||||||
-fno-exceptions \
|
|
||||||
-fno-unwind-tables \
|
|
||||||
-ffunction-sections \
|
|
||||||
-fdata-sections \
|
|
||||||
-funsigned-char \
|
|
||||||
-MMD \
|
|
||||||
-mcpu=cortex-m4 \
|
|
||||||
-mthumb \
|
|
||||||
-mfpu=fpv4-sp-d16 \
|
|
||||||
-mfloat-abi=softfp \
|
|
||||||
-std=gnu++11 \
|
|
||||||
-Wvla \
|
|
||||||
-Wall \
|
|
||||||
-Wextra \
|
|
||||||
-Wno-shadow \
|
|
||||||
-Wno-missing-field-initializers \
|
|
||||||
-Wno-strict-aliasing \
|
|
||||||
-Wno-type-limits \
|
|
||||||
-Wno-unused-function \
|
|
||||||
-Wno-unused-parameter \
|
|
||||||
-fno-delete-null-pointer-checks \
|
|
||||||
-fno-threadsafe-statics \
|
|
||||||
-fomit-frame-pointer \
|
|
||||||
-fno-use-cxa-atexit \
|
|
||||||
-O3
|
|
||||||
|
|
||||||
CXXFLAGS += $(PLATFORM_FLAGS)
|
|
||||||
CCFLAGS += $(PLATFORM_FLAGS)
|
|
||||||
|
|
||||||
LDFLAGS += -Wl,--gc-sections
|
|
||||||
|
|
||||||
endif
|
|
||||||
|
|
@@ -0,0 +1,36 @@
+# Generic Makefile target for ARM Cortex Mx gcc builds.
+ifeq ($(TARGET), cortex_m_gcc_generic)
+  TARGET_ARCH := arm
+  TARGET_TOOLCHAIN_PREFIX := arm-none-eabi-
+  export PATH := $(MAKEFILE_DIR)/downloads/gcc_embedded/bin/:$(PATH)
+
+  $(eval $(call add_third_party_download,$(GCC_EMBEDDED_URL),$(GCC_EMBEDDED_MD5),gcc_embedded,))
+
+  PLATFORM_FLAGS = \
+    -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \
+    -DTF_LITE_MCU_DEBUG_LOG \
+    -fmessage-length=0 \
+    -fno-exceptions \
+    -fno-unwind-tables \
+    -ffunction-sections \
+    -fdata-sections \
+    -funsigned-char \
+    -mcpu=cortex-m4 \
+    -mfpu=fpv4-sp-d16 \
+    -mthumb \
+    -fomit-frame-pointer
+
+  ifeq ($(CORTEX_M_CORE), M4F)
+    PLATFORM_FLAGS += -mfloat-abi=hard
+  else ifeq ($(CORTEX_M_CORE), M4)
+    PLATFORM_FLAGS += -mfloat-abi=softfp
+  else ifeq ($(CORTEX_M_CORE), )
+    $(error CORTEX_M_CORE=[M4|M4F] not defined on the command line)
+  else
+    $(error invalid target defined in command line option CORTEX_M_CORE=[M4|M4F])
+  endif
+
+  CXXFLAGS += $(PLATFORM_FLAGS)
+  CCFLAGS += $(PLATFORM_FLAGS)
+
+endif
@@ -2825,6 +2825,7 @@ tf_py_test(
         ":framework_combinations",
         ":framework_for_generated_wrappers",
         ":framework_test_lib",
+        ":lookup_ops",
        ":platform_test",
        ":random_ops",
        ":resource_variable_ops",
@@ -116,6 +116,33 @@ def _is_none_or_undef(value):
           or isinstance(value, variables.Undefined))
 
 
+def _verify_tf_condition(cond, tag):
+  """Ensures that the condition can be used in a TF control flow."""
+  extra_hint = 'to check for None, use `is not None`'
+  cond = ops.convert_to_tensor_v2(cond)
+
+  if cond.dtype != dtypes.bool:
+    raise ValueError(
+        'condition of {} expected to be `tf.bool` scalar, got {}'
+        '; to use as boolean Tensor, use `tf.cast`'
+        '; {}'.format(tag, cond, extra_hint))
+
+  if cond.shape is None or cond.shape.ndims is None:
+    # TODO(mdan): Consider an explicit size check, if not too slow.
+    cond = array_ops.reshape(cond, ())
+
+  elif cond.shape.ndims > 0:
+    known_dims = [d for d in cond.shape.as_list() if d is not None]
+    if np.prod(known_dims) > 1:
+      raise ValueError(
+          'condition of {} expected to be `tf.bool` scalar, got {}'
+          '; {}'.format(tag, cond, extra_hint))
+    else:
+      cond = array_ops.reshape(cond, ())
+
+  return cond
+
+
 def _verify_loop_init_vars(init_vars, symbol_names, first_iter_vars=None):
   """Ensures that all values in the state are valid to use in a TF loop.
@@ -1038,7 +1065,7 @@ def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts):
       loop_vars = loop_vars[1:]
 
     set_state(loop_vars)
-    return test()
+    return _verify_tf_condition(test(), 'while loop')
 
   def aug_body(*loop_vars):
     if require_one_iteration:
@@ -1141,6 +1168,8 @@ def if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts):
 def _tf_if_stmt(
     cond, body, orelse, get_state, set_state, symbol_names, nouts):
   """Overload of if_stmt that stages a TF cond."""
+  cond = _verify_tf_condition(cond, 'if statement')
+
   if not nouts:
     prev_get_state, prev_set_state = get_state, set_state
     # Control flow V1 wants at least one output.
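Taken together, these two hunks route every staged loop and `if` condition through _verify_tf_condition. A minimal sketch of the user-visible effect, using the public tf.function API; the function names countdown_ok and countdown_bad are illustrative and not part of this change:

    import tensorflow as tf

    @tf.function
    def countdown_ok(n):
      while n > 0:        # condition is a `tf.bool` scalar: accepted
        n -= 1
      return n

    @tf.function
    def countdown_bad(n):
      while n:            # condition is an int32 tensor: now rejected with
        n -= 1            # "condition of while loop expected to be `tf.bool` scalar"
      return n

    countdown_ok(tf.constant(3))     # returns a zero tensor
    # countdown_bad(tf.constant(3))  # raises ValueError; an explicit
    #                                # `while tf.cast(n, tf.bool):` would be accepted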
@@ -35,6 +35,7 @@ from tensorflow.python.autograph.utils import testing
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
@@ -46,6 +47,20 @@ from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.platform import test
 
 
+def _unranked_item(value):
+  rand_rank = random_ops.random_uniform(
+      shape=(), minval=3, maxval=4, dtype=dtypes.int32)
+  rand_shape = array_ops.ones([rand_rank], dtype=dtypes.int32)
+  return array_ops.fill(rand_shape, value)
+
+
+def _partial_shaped_bools():
+  rand_vect = math_ops.range(
+      random_ops.random_uniform(
+          shape=(), minval=2, maxval=3, dtype=dtypes.int32))
+  return array_ops.expand_dims_v2(rand_vect, 0) < 0
+
+
 class ForLoopTest(testing.AutoGraphTestCase):
 
   def test_tensor(self):
@@ -871,6 +886,60 @@ class WhileLoopTest(testing.AutoGraphTestCase):
     with self.assertRaisesRegex(ValueError, r"'s'.* shape \(1,\) after"):
       self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32))
 
+  def _fixed_while_loop(self, cond_fn):
+    def test_():
+      return cond_fn(s)
+
+    def body():
+      nonlocal s
+      s += 1
+
+    def set_state(loop_vars):
+      nonlocal s
+      s, = loop_vars
+
+    s = constant_op.constant(0)
+    control_flow.while_stmt(
+        test=test_,
+        body=body,
+        get_state=lambda: (s,),
+        set_state=set_state,
+        symbol_names=('s',),
+        opts={})
+    return s
+
+  def _assertFixedLoopResult(self, cond, expected):
+    def test_fn():
+      return self._fixed_while_loop(cond)
+    self.assertEqual(test_fn(), expected)
+
+  def test_tensor_legal_cond_scalar(self):
+    self._assertFixedLoopResult(lambda s: constant_op.constant(False), 0)
+    self._assertFixedLoopResult(lambda s: s < 2, 2)
+
+  def test_tensor_legal_cond_single_element_nd(self):
+    self._assertFixedLoopResult(lambda s: constant_op.constant([[False]]), 0)
+    self._assertFixedLoopResult(lambda s: _unranked_item(False), 0)
+
+  def _assertCondCheckFails(self, cond):
+    with self.assertRaisesRegex(
+        ValueError, 'condition of while loop expected to be `tf.bool`'):
+      self._fixed_while_loop(cond)
+
+  def test_tensor_illegal_cond_not_bool(self):
+    self._assertCondCheckFails(lambda s: constant_op.constant(1))
+    self._assertCondCheckFails(lambda s: s)
+
+  def test_tensor_illegal_cond_not_single_element(self):
+    self._assertCondCheckFails(lambda s: constant_op.constant([1, 2, 3]))
+    self._assertCondCheckFails(lambda s: constant_op.constant([True, False]))
+
+  def test_tensor_illegal_cond_not_single_element_dynamic_shape(self):
+    self._fixed_while_loop(lambda s: _partial_shaped_bools())
+    # TODO(mdan): This error is quite bad. Measure the cost of an assertion.
+    self.assertRaisesRuntime(
+        errors_impl.InvalidArgumentError, 'requested shape has 1')
+
+
 class IfStmtTest(testing.AutoGraphTestCase):
 
@@ -1065,6 +1134,62 @@ class IfStmtTest(testing.AutoGraphTestCase):
         TypeError, "'x' has dtype int32.*but.*float32"):
       self._basic_cond(lambda: 1, lambda: 1.0)
 
+  def _fixed_cond(self, cond_val):
+    def body():
+      nonlocal x
+      x = 1
+
+    def orelse():
+      nonlocal x
+      x = -1
+
+    def set_state(cond_vars):
+      nonlocal x
+      x, = cond_vars
+
+    x = 0
+    control_flow.if_stmt(
+        cond=cond_val,
+        body=body,
+        orelse=orelse,
+        get_state=lambda: (x,),
+        set_state=set_state,
+        symbol_names=('x',),
+        nouts=1)
+    return x
+
+  def _assertFixedCondResult(self, cond, expected):
+    def test_fn():
+      return self._fixed_cond(cond)
+    self.assertEqual(test_fn(), expected)
+
+  def test_tensor_legal_cond_scalar(self):
+    self._assertFixedCondResult(constant_op.constant(True), 1)
+    self._assertFixedCondResult(constant_op.constant(False), -1)
+
+  def test_tensor_legal_cond_single_element_nd(self):
+    self._assertFixedCondResult(constant_op.constant([[True]]), 1)
+    self._assertFixedCondResult(constant_op.constant([[False]]), -1)
+    self._assertFixedCondResult(_unranked_item(True), 1)
+    self._assertFixedCondResult(_unranked_item(False), -1)
+
+  def _assertCondCheckFails(self, cond):
+    with self.assertRaisesRegex(
+        ValueError, 'condition of if statement expected to be `tf.bool`'):
+      self._fixed_cond(cond)
+
+  def test_tensor_illegal_cond_not_bool(self):
+    self._assertCondCheckFails(constant_op.constant(1))
+
+  def test_tensor_illegal_cond_not_single_element(self):
+    self._assertCondCheckFails(constant_op.constant([1, 2, 3]))
+    self._assertCondCheckFails(constant_op.constant([True, False]))
+
+  def test_tensor_illegal_cond_not_single_element_dynamic_shape(self):
+    self._fixed_cond(_partial_shaped_bools())
+    # TODO(mdan): This error is quite bad. Measure the cost of an assertion.
+    self.assertRaisesRuntime(
+        errors_impl.InvalidArgumentError, 'requested shape has 1')
+
 if __name__ == '__main__':
   test.main()
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import re
+import sys
 import types
 import unittest
 
@@ -81,18 +82,29 @@ class AutoGraphTestCase(test.TestCase):
     @def_function.function(autograph=False)  # Testing autograph itself.
     def fn_wrapper():
       self.assertions = []
+      self.raises_cm = None
       self.graph_assertions = []
       self.trace_log = []
       fn()
       targets = [args for _, args in self.assertions]
       return targets
 
-    tensors = fn_wrapper()
+    try:
+      tensors = fn_wrapper()
 
-    for assertion in self.graph_assertions:
-      assertion(fn_wrapper.get_concrete_function().graph)
+      for assertion in self.graph_assertions:
+        assertion(fn_wrapper.get_concrete_function().graph)
+
+      actuals = self.evaluate(tensors)
+
+    except:  # pylint:disable=bare-except
+      if self.raises_cm is not None:
+        # Note: Yes, the Raises and function contexts cross.
+        self.raises_cm.__exit__(*sys.exc_info())
+        return
+      else:
+        raise
 
-    actuals = self.evaluate(tensors)
     for (assertion, _), values in zip(self.assertions, actuals):
       assertion(*values)
 
@@ -109,6 +121,7 @@ class AutoGraphTestCase(test.TestCase):
     super().setUp()
     self.variables = {}
     self.trace_log = []
+    self.raises_cm = None
     op_callbacks.add_op_callback(self._op_callback)
 
   def tearDown(self):
@@ -145,3 +158,9 @@
 
   def assertDictEqual(self, *args):
     self.assertions.append((super().assertDictEqual, list(args)))
+
+  def assertRaisesRuntime(self, *args):
+    if self.raises_cm is not None:
+      raise ValueError('cannot use more than one assertRaisesRuntime in a test')
+    self.raises_cm = self.assertRaisesRegex(*args)
+    self.raises_cm.__enter__()
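Taken with the try/except added to fn_wrapper above, assertRaisesRuntime arms an assertRaisesRegex context ahead of time and closes it only if the staged function raises while executing. A hedged restatement of the new dynamic-shape while-loop test, annotated to show where each half of the mechanism fires; it assumes the _fixed_while_loop and _partial_shaped_bools helpers added earlier in this diff are in scope:

    def test_tensor_illegal_cond_not_single_element_dynamic_shape(self):
      # Tracing succeeds: the condition's element count is unknown until run
      # time, so _verify_tf_condition falls back to a runtime reshape to ().
      self._fixed_while_loop(lambda s: _partial_shaped_bools())
      # Arms the expectation; fn_wrapper's except handler consumes it when the
      # staged graph raises InvalidArgumentError while evaluating the reshape.
      self.assertRaisesRuntime(
          errors_impl.InvalidArgumentError, 'requested shape has 1')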
@@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 10, 2)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 10, 3)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None
 
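For context on how this daily horizon bump is consumed downstream, here is a hedged sketch; forward_compatible and forward_compatibility_horizon are the existing helpers referenced in the comment above, while choose_op_version and the guarded strings are purely illustrative:

    from tensorflow.python.compat import compat

    def choose_op_version():
      # True once the horizon has passed the given date, i.e. once consumers
      # are assumed to run binaries at least that new.
      if compat.forward_compatible(2020, 10, 2):
        return 'new behavior'
      return 'legacy behavior'

    # Tests can temporarily push the horizon forward instead of waiting:
    with compat.forward_compatibility_horizon(2020, 10, 4):
      assert compat.forward_compatible(2020, 10, 3)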
Some files were not shown because too many files have changed in this diff.