[TF2XLA] experimental_get_compiler_ir
API: compiled IR at a given stage for a function for fixed input
``` In [1]: import tensorflow as tf In [2]: @tf.function(experimental_compile=True) ...: def f(x): ...: return x + 1 In [3]: a = tf.random.normal([10, 10]) In [6]: print(f.experimental_get_compiler_ir(a)(stage='hlo')) HloModule a_inference_f_13__.9 ENTRY %a_inference_f_13__.9 (arg0.1: f32[10,10]) -> f32[10,10] { %arg0.1 = f32[10,10]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"} %reshape.2 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %arg0.1) %constant.3 = f32[] constant(1), metadata={op_type="AddV2" op_name="add"} %broadcast.4 = f32[10,10]{1,0} broadcast(f32[] %constant.3), dimensions={}, metadata={op_type="AddV2" op_name="add"} %add.5 = f32[10,10]{1,0} add(f32[10,10]{1,0} %reshape.2, f32[10,10]{1,0} %broadcast.4), metadata={op_type="AddV2" op_name="add"} %reshape.6 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %add.5), metadata={op_name="XLA_Retvals"} %tuple.7 = (f32[10,10]{1,0}) tuple(f32[10,10]{1,0} %reshape.6), metadata={op_name="XLA_Retvals"} ROOT %get-tuple-element.8 = f32[10,10]{1,0} get-tuple-element((f32[10,10]{1,0}) %tuple.7), index=0, metadata={op_name="XLA_Retvals"} } ``` PiperOrigin-RevId: 332325845 Change-Id: Ifafaf629e5c4e4d7e0c42d18da4bcbb503c6ae02
This commit is contained in:
parent
5e007c2770
commit
eeeb6ca25f
@ -185,6 +185,8 @@
|
|||||||
* XLA Support:
|
* XLA Support:
|
||||||
* xla.experimental.compile is deprecated, use
|
* xla.experimental.compile is deprecated, use
|
||||||
`tf.function(experimental_compile=True)` instead
|
`tf.function(experimental_compile=True)` instead
|
||||||
|
* Added `tf.function.experimental_get_compiler_ir` which returns compiler IR
|
||||||
|
(currently 'hlo' and 'optimized_hlo') for given input for given function.
|
||||||
* <ADD RELEASE NOTES HERE>
|
* <ADD RELEASE NOTES HERE>
|
||||||
* Tracing and Debugging:
|
* Tracing and Debugging:
|
||||||
* <ADD RELEASE NOTES HERE>
|
* <ADD RELEASE NOTES HERE>
|
||||||
|
@ -377,6 +377,7 @@ tf_cuda_library(
|
|||||||
"//tensorflow/c/eager:tfe_op_internal",
|
"//tensorflow/c/eager:tfe_op_internal",
|
||||||
"//tensorflow/c/eager:tfe_tensorhandle_internal",
|
"//tensorflow/c/eager:tfe_tensorhandle_internal",
|
||||||
"//tensorflow/compiler/jit:flags",
|
"//tensorflow/compiler/jit:flags",
|
||||||
|
"//tensorflow/compiler/jit:get_compiler_ir",
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
@ -15,6 +15,7 @@ package_group(
|
|||||||
"//tensorflow/compiler/tf2xla:internal",
|
"//tensorflow/compiler/tf2xla:internal",
|
||||||
],
|
],
|
||||||
packages = [
|
packages = [
|
||||||
|
"//tensorflow/c/...",
|
||||||
"//tensorflow/compiler/tests/...",
|
"//tensorflow/compiler/tests/...",
|
||||||
"//tensorflow/python/...",
|
"//tensorflow/python/...",
|
||||||
],
|
],
|
||||||
@ -387,6 +388,53 @@ cc_library(
|
|||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "get_compiler_ir",
|
||||||
|
srcs = ["get_compiler_ir.cc"],
|
||||||
|
hdrs = ["get_compiler_ir.h"],
|
||||||
|
visibility = [
|
||||||
|
":internal",
|
||||||
|
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
|
||||||
|
"//tensorflow/core/common_runtime/eager:__pkg__",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":common",
|
||||||
|
":compilability_check_util",
|
||||||
|
":flags",
|
||||||
|
":xla_device_no_jit_rewrite_registration",
|
||||||
|
":xla_launch_util",
|
||||||
|
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||||
|
"//tensorflow/compiler/xla:statusor",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/common_runtime:core_cpu_internal",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
"@com_google_absl//absl/types:span",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Header-only version of "get_compiler_ir" library, for linking from the shared object
|
||||||
|
# without ODR violations.
|
||||||
|
cc_library(
|
||||||
|
name = "get_compiler_ir_hdrs_only",
|
||||||
|
hdrs = ["get_compiler_ir.h"],
|
||||||
|
visibility = [
|
||||||
|
":internal",
|
||||||
|
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
|
||||||
|
"//tensorflow/core/common_runtime/eager:__pkg__",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/compiler/xla:statusor",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
"@com_google_absl//absl/types:span",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "xla_kernel_creator",
|
name = "xla_kernel_creator",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
@ -84,6 +84,43 @@ Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Utility which searches for values in a sorted list by scanning over it once.
|
||||||
|
// No matter how many times ScanForValue is called, the list is scanned at most
|
||||||
|
// once. However, if a call to ScanForValue skips over a value, that value is
|
||||||
|
// not revisited in future calls to ScanForValue, so callers must take
|
||||||
|
// care to order their calls.
|
||||||
|
//
|
||||||
|
// Useful for merging multiple sorted lists in O(n) time.
|
||||||
|
class SinglePassSearch {
|
||||||
|
public:
|
||||||
|
// Creates a SinglePassSearch object that can be used to search in `values`.
|
||||||
|
// Does not take ownership of `values`. `values` must outlive this.
|
||||||
|
// `values` must be sorted.
|
||||||
|
explicit SinglePassSearch(absl::Span<int const> values)
|
||||||
|
: current_index_(0), values_(values) {}
|
||||||
|
|
||||||
|
// Scans forward in the vector looking for "value", updating the internal
|
||||||
|
// position in to the vector.
|
||||||
|
// Returns true iff the vector contains the given value at or after current
|
||||||
|
// position.
|
||||||
|
// Not thread-safe.
|
||||||
|
bool ScanForValue(int value) {
|
||||||
|
while (current_index_ < values_.size() &&
|
||||||
|
values_[current_index_] <= value) {
|
||||||
|
if (values_[current_index_] == value) {
|
||||||
|
current_index_++;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
current_index_++;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int current_index_;
|
||||||
|
const absl::Span<int const> values_;
|
||||||
|
};
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
RecursiveCompilabilityChecker::UncompilableNodesMap
|
RecursiveCompilabilityChecker::UncompilableNodesMap
|
||||||
@ -564,6 +601,44 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Computes the memory type of every argument of `fbody`: arguments listed in
// `constant_arg_indices` or `resource_arg_indices` are placed in host memory,
// all remaining arguments in device memory. Both index lists must be sorted.
tensorflow::MemoryTypeVector GetInputMemoryTypes(
    const tensorflow::FunctionBody* fbody,
    absl::Span<int const> constant_arg_indices,
    absl::Span<int const> resource_arg_indices) {
  const size_t num_args = fbody->arg_types.size();

  // Every argument starts out in device memory; the exceptions are patched in
  // below.
  tensorflow::MemoryTypeVector memory_types(num_args,
                                            tensorflow::DEVICE_MEMORY);

  // Functions can capture a large number of arguments (e.g. the backward pass
  // of ResNet50 takes in all 214 variables and a similar number of
  // activations), so each sorted index list is walked exactly once, in step
  // with the argument index, instead of being searched for every argument.
  SinglePassSearch resource_cursor(resource_arg_indices);
  SinglePassSearch constant_cursor(constant_arg_indices);

  for (size_t arg = 0; arg != num_args; ++arg) {
    const bool lives_on_host = resource_cursor.ScanForValue(arg) ||
                               constant_cursor.ScanForValue(arg);
    if (!lives_on_host) {
      continue;
    }
    // Resource handles and compile-time constants are expected in host
    // memory.
    memory_types[arg] = tensorflow::HOST_MEMORY;
  }

  return memory_types;
}
|
||||||
|
|
||||||
|
// Computes the memory type of every return value of `fbody`: DT_RESOURCE
// outputs are placed in host memory, everything else in device memory.
tensorflow::MemoryTypeVector GetOutputMemoryTypes(
    const tensorflow::FunctionBody* fbody) {
  const auto& return_types = fbody->ret_types;
  tensorflow::MemoryTypeVector memory_types(return_types.size(),
                                            tensorflow::DEVICE_MEMORY);

  size_t position = 0;
  for (const auto dtype : return_types) {
    if (dtype == tensorflow::DT_RESOURCE) {
      memory_types[position] = tensorflow::HOST_MEMORY;
    }
    ++position;
  }

  return memory_types;
}
|
||||||
|
|
||||||
static auto const ops_triggering_xla_compilation =
|
static auto const ops_triggering_xla_compilation =
|
||||||
new absl::flat_hash_set<std::string>{"XlaBroadcastHelper",
|
new absl::flat_hash_set<std::string>{"XlaBroadcastHelper",
|
||||||
"XlaConv",
|
"XlaConv",
|
||||||
|
@ -282,6 +282,41 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
|||||||
// set.
|
// set.
|
||||||
bool CanCreateXlaKernel(const NodeDef& node_def);
|
bool CanCreateXlaKernel(const NodeDef& node_def);
|
||||||
|
|
||||||
|
// Returns memory types for the input.
|
||||||
|
// `constant_arg_indices` and `resource_arg_indices` are sorted arrays of
|
||||||
|
// indices corresponding to constant and resource arguments respectively.
|
||||||
|
//
|
||||||
|
// One might wonder about the case where a compile-time constant argument
|
||||||
|
// (which must be in host memory) is also used as an input into an op,
|
||||||
|
// e.g. `Add`, that expects its inputs in device memory. Here is how it
|
||||||
|
// works now.
|
||||||
|
// First, what do we mean by "op expects an input in XYZ memory"?
|
||||||
|
// There are two types of "ops" here: the tf2xla kernel and the HLO
|
||||||
|
// computation it builds. The tf2xla kernel needs to retrieve the actual
|
||||||
|
// numeric value of the compile-time constant tensors, so it really expects
|
||||||
|
// them to be in host memory. However, for other inputs, it refers to them
|
||||||
|
// using xla::ComputationDataHandle, which is just a symbolic handle that
|
||||||
|
// xla::ComputationBuilder assigns. How does this handle get assigned for
|
||||||
|
// constant arguments? Even constant arguments get an _Arg node in the graph
|
||||||
|
// instantiated for Function compilation. The tf2xla kernel for constant _Arg
|
||||||
|
// nodes takes the constant value, converts it to XlaLiteral, and feeds it
|
||||||
|
// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
|
||||||
|
// constant XlaLiteral is included in the HLO graph, and subsequently, in
|
||||||
|
// the actual executable, which is copied to the device before being
|
||||||
|
// executed. Thus, when this executable runs, the constant is available in
|
||||||
|
// device memory.
|
||||||
|
tensorflow::MemoryTypeVector GetInputMemoryTypes(
|
||||||
|
const tensorflow::FunctionBody* fbody,
|
||||||
|
absl::Span<int const> constant_arg_indices,
|
||||||
|
absl::Span<int const> resource_arg_indices);
|
||||||
|
|
||||||
|
// Returns output memory types.
|
||||||
|
//
|
||||||
|
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
|
||||||
|
// in device memory except for resources.
|
||||||
|
tensorflow::MemoryTypeVector GetOutputMemoryTypes(
|
||||||
|
const tensorflow::FunctionBody* fbody);
|
||||||
|
|
||||||
// Check whether graph can trigger XLA compilation.
|
// Check whether graph can trigger XLA compilation.
|
||||||
bool CanTriggerXlaCompilation(const GraphDef& graph);
|
bool CanTriggerXlaCompilation(const GraphDef& graph);
|
||||||
|
|
||||||
|
114
tensorflow/compiler/jit/get_compiler_ir.cc
Normal file
114
tensorflow/compiler/jit/get_compiler_ir.cc
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/jit/get_compiler_ir.h"
|
||||||
|
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "absl/strings/str_format.h"
|
||||||
|
#include "tensorflow/compiler/jit/compilability_check_util.h"
|
||||||
|
#include "tensorflow/compiler/jit/defs.h"
|
||||||
|
#include "tensorflow/compiler/jit/flags.h"
|
||||||
|
#include "tensorflow/compiler/jit/xla_launch_util.h"
|
||||||
|
#include "tensorflow/compiler/jit/xla_platform_info.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||||
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
|
#include "tensorflow/core/framework/function.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/util/ptr_util.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Compiles function `func_name` for the concrete `inputs` on device `dev` and
// returns the textual IR at the requested `stage`:
//   - IrExportStage::HLO: the HLO module built from the computation produced
//     by XlaCompiler::CompileFunction.
//   - IrExportStage::OPTIMIZED_HLO: the HLO module of the executable obtained
//     from the device's XlaCompilationCache.
//
// Returns an error status if the function library runtime for `dev` cannot be
// found, if any compilation step fails, or if `stage` is not a known value.
xla::StatusOr<std::string> GetCompilerIr(
    IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
    absl::string_view func_name, Device* dev,
    absl::Span<const Tensor* const> inputs) {
  NameAttrList function;
  function.set_name(std::string{func_name});

  FunctionLibraryRuntime* flr = pflr->GetFLR(dev->name());
  if (flr == nullptr) {
    // `flr` is dereferenced below; fail with a status rather than crash.
    return errors::NotFound("No function library runtime found for device ",
                            dev->name());
  }
  ResourceMgr* rmgr = dev->resource_manager();

  const FunctionBody* fbody = nullptr;
  std::vector<int> constant_arg_indices;
  std::vector<int> resource_arg_indices;
  TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
      flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));

  // Resolve the resource-variable inputs and lock them before the compiler
  // arguments are built from them.
  std::vector<VariableInfo> variable_infos;
  TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
      rmgr, dev, inputs, resource_arg_indices, &variable_infos));
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));

  XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(dev);

  // Look up the compilation cache stored in the device's resource manager
  // under "xla_cache", creating it on first use.
  XlaCompilationCache* cache = nullptr;
  TF_RETURN_IF_ERROR(rmgr->LookupOrCreate<XlaCompilationCache>(
      rmgr->default_container(), "xla_cache", &cache,
      [&](XlaCompilationCache** cache_write_into) {
        return BuildXlaCompilationCache(dev, platform_info, cache_write_into);
      }));
  core::ScopedUnref cache_ref(cache);

  absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;

  XlaCompiler::Options options =
      GenerateCompilerOptions(*cache, *flr, dev,
                              /*stream=*/nullptr, platform_info,
                              /*has_ref_vars=*/false, &tf_allocator_adapter);

  XlaCompiler::CompileOptions compile_options;
  compile_options.always_return_tuple = false;
  compile_options.alias_resource_update = true;

  XlaCompiler compiler(options);

  xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
      XlaComputationLaunchContext::BuildXlaCompilerArguments(
          constant_arg_indices, inputs, variable_infos);
  TF_RETURN_IF_ERROR(args.status());

  switch (stage) {
    case IrExportStage::HLO: {
      XlaCompiler::CompilationResult result;
      TF_RETURN_IF_ERROR(
          compiler.CompileFunction(compile_options, function, *args, &result));

      // Rebuild an HloModule from the computation proto so that it can be
      // printed as text.
      TF_ASSIGN_OR_RETURN(xla::ProgramShape program_shape,
                          result.computation->GetProgramShape());
      xla::HloModuleConfig config(program_shape);
      TF_ASSIGN_OR_RETURN(
          std::unique_ptr<xla::HloModule> new_module,
          xla::HloModule::CreateFromProto(result.computation->proto(), config));

      return new_module->ToString();
    }
    case IrExportStage::OPTIMIZED_HLO: {
      const XlaCompiler::CompilationResult* compilation_result = nullptr;
      xla::LocalExecutable* executable = nullptr;
      TF_RETURN_IF_ERROR(
          cache->Compile(options, function, *args, compile_options,
                         XlaCompilationCache::CompileMode::kStrict,
                         &compilation_result, &executable));
      return executable->executable()->module().ToString();
    }
  }

  // Reached only if `stage` holds a value outside the enumerators handled
  // above; without this return, control would fall off the end of a non-void
  // function.
  return errors::InvalidArgument("Unknown compiler IR export stage: ",
                                 static_cast<int>(stage));
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
39
tensorflow/compiler/jit/get_compiler_ir.h
Normal file
39
tensorflow/compiler/jit/get_compiler_ir.h
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_COMPILER_JIT_GET_COMPILER_IR_H_
|
||||||
|
#define TENSORFLOW_COMPILER_JIT_GET_COMPILER_IR_H_
|
||||||
|
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "absl/types/span.h"
|
||||||
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class ProcessFunctionLibraryRuntime;
|
||||||
|
class Device;
|
||||||
|
class Tensor;
|
||||||
|
|
||||||
|
// Stage of compilation at which GetCompilerIr exports the IR: `HLO` is the
// module built from the compiled function's computation, `OPTIMIZED_HLO` is
// the module of the resulting executable.
enum class IrExportStage { HLO, OPTIMIZED_HLO };
|
||||||
|
|
||||||
|
// Returns the IR text selected by `stage` for function `func_name`, compiled
|
||||||
|
// through library runtime `pflr` on device `dev` with the given `inputs`.
|
||||||
|
xla::StatusOr<std::string> GetCompilerIr(
|
||||||
|
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
|
||||||
|
absl::string_view func_name, Device* dev,
|
||||||
|
absl::Span<const Tensor* const> inputs);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_JIT_GET_COMPILER_IR_H_
|
@ -30,47 +30,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/util/ptr_util.h"
|
#include "tensorflow/core/util/ptr_util.h"
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
// Utility which searches for values in a sorted list by scanning over it once.
|
|
||||||
// No matter how many times ScanForValue is called, the list is scanned at most
|
|
||||||
// once. However, if a call to ScanForValue skips over a value, that value is
|
|
||||||
// not revisited in future calls to ScanForValue, so callers must take
|
|
||||||
// care to order their calls.
|
|
||||||
//
|
|
||||||
// Useful for merging multiple sorted lists in O(n) time.
|
|
||||||
class SinglePassSearch {
|
|
||||||
public:
|
|
||||||
// Creates a SinglePassSearch object that can be used to search in `values`.
|
|
||||||
// Does not take ownership of `values`. `values` must outlive this.
|
|
||||||
// `values` must be sorted.
|
|
||||||
explicit SinglePassSearch(const std::vector<int>* values)
|
|
||||||
: current_index_(0), values_(values) {}
|
|
||||||
|
|
||||||
// Scans forward in the vector looking for "value", updating the internal
|
|
||||||
// position in to the vector.
|
|
||||||
// Returns true iff the vector contains the given value at or after current
|
|
||||||
// position.
|
|
||||||
// Not thread-safe.
|
|
||||||
bool ScanForValue(int value) {
|
|
||||||
while (current_index_ < values_->size() &&
|
|
||||||
(*values_)[current_index_] <= value) {
|
|
||||||
if ((*values_)[current_index_] == value) {
|
|
||||||
current_index_++;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
current_index_++;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
int current_index_;
|
|
||||||
const std::vector<int>* values_;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // end namespace
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
bool XlaKernelCreator::CanCreateKernel(
|
bool XlaKernelCreator::CanCreateKernel(
|
||||||
@ -130,52 +89,9 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
|
|||||||
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
|
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
|
||||||
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
|
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
|
||||||
|
|
||||||
// Set input and output memory types.
|
MemoryTypeVector input_memory_types =
|
||||||
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
|
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
|
||||||
// These indices are used only for optimization purposes. They allow us
|
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
|
||||||
// to loop over constant_arg_indices and resource_arg_indices only once
|
|
||||||
// while iterating over all the function arguments checking if it is a
|
|
||||||
// resource or a constant.
|
|
||||||
// The reason we optimized this code is because functions can have a lot of
|
|
||||||
// captured arguments. For example, the backward pass of ResNet50 takes in all
|
|
||||||
// 214 variables and a similar number of activations.
|
|
||||||
SinglePassSearch constants_search(&constant_arg_indices);
|
|
||||||
SinglePassSearch resources_search(&resource_arg_indices);
|
|
||||||
for (size_t i = 0; i < fbody->arg_types.size(); ++i) {
|
|
||||||
if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
|
|
||||||
// Compile-time constants and resource handles are expected to be in
|
|
||||||
// host memory.
|
|
||||||
input_memory_types[i] = HOST_MEMORY;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// One might wonder, about the case where a compile-time constant argument
|
|
||||||
// (which must be in host memory) is also used as an input into an op,
|
|
||||||
// e.g. Add, that expects its inputs in device memory. Here is how it
|
|
||||||
// works now.
|
|
||||||
// First, what do we mean by "op expects an input in XYZ memory"?
|
|
||||||
// There are two types of "ops" here: the tf2xla kernel and the HLO
|
|
||||||
// computation it builds. The tf2xla kernel needs to retrieve the actual
|
|
||||||
// numeric value of the compile-time constant tensors, so it really expects
|
|
||||||
// them to be on in host memory. However, for other inputs, it refers to them
|
|
||||||
// using xla::ComputationDataHandle, which is just a symbolic handle that
|
|
||||||
// xla::ComputationBuilder assigns. How does this handle gets assigned for
|
|
||||||
// constant arguments? Even constant arguments get an _Arg node in the graph
|
|
||||||
// instantiated for Function compilation. The tf2xla kernel for constant _Arg
|
|
||||||
// nodes takes the constant value, converts it to XlaLiteral, and feeds it
|
|
||||||
// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
|
|
||||||
// constant XlaLiteral is included in the HLO graph, and subsequently, in
|
|
||||||
// the actual executable, which is copied to the device before being
|
|
||||||
// executed. Thus, when this executable runs, the constant is available in
|
|
||||||
// device memory.
|
|
||||||
|
|
||||||
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
|
|
||||||
// in device memory except for resources.
|
|
||||||
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
|
|
||||||
for (size_t i = 0; i < fbody->ret_types.size(); ++i) {
|
|
||||||
if (fbody->ret_types[i] == DT_RESOURCE) {
|
|
||||||
output_memory_types[i] = HOST_MEMORY;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the kernel.
|
// Create the kernel.
|
||||||
Device* dev = flr->device();
|
Device* dev = flr->device();
|
||||||
|
@ -6047,6 +6047,7 @@ filegroup(
|
|||||||
"//tensorflow/c:checkpoint_reader", # checkpoint_reader
|
"//tensorflow/c:checkpoint_reader", # checkpoint_reader
|
||||||
"//tensorflow/c:python_api", # tf_session
|
"//tensorflow/c:python_api", # tf_session
|
||||||
"//tensorflow/c:tf_status_helper", # tfe
|
"//tensorflow/c:tf_status_helper", # tfe
|
||||||
|
"//tensorflow/compiler/jit:get_compiler_ir", #tfe
|
||||||
"//tensorflow/compiler/jit:flags", #tfe
|
"//tensorflow/compiler/jit:flags", #tfe
|
||||||
"//tensorflow/compiler/mlir/python:mlir", # mlir
|
"//tensorflow/compiler/mlir/python:mlir", # mlir
|
||||||
"//tensorflow/core/common_runtime:device", # device_lib, tfe, tf_session
|
"//tensorflow/core/common_runtime:device", # device_lib, tfe, tf_session
|
||||||
@ -7557,6 +7558,8 @@ tf_python_pybind_extension(
|
|||||||
"//third_party/python_runtime:headers",
|
"//third_party/python_runtime:headers",
|
||||||
"//tensorflow/c/experimental/saved_model/core:pywrap_required_hdrs",
|
"//tensorflow/c/experimental/saved_model/core:pywrap_required_hdrs",
|
||||||
"//tensorflow/compiler/jit:flags_headers_only",
|
"//tensorflow/compiler/jit:flags_headers_only",
|
||||||
|
"//tensorflow/compiler/jit:get_compiler_ir_hdrs_only",
|
||||||
|
"//tensorflow/c/eager:tfe_tensorhandle_internal",
|
||||||
"//tensorflow/core/common_runtime:core_cpu_headers_lib",
|
"//tensorflow/core/common_runtime:core_cpu_headers_lib",
|
||||||
"//tensorflow/core:framework_headers_lib",
|
"//tensorflow/core:framework_headers_lib",
|
||||||
"//tensorflow/core:lib_headers_for_pybind",
|
"//tensorflow/core:lib_headers_for_pybind",
|
||||||
|
@ -1494,6 +1494,10 @@ class Context(object):
|
|||||||
|
|
||||||
self._virtual_device_map[dev] = virtual_devices
|
self._virtual_device_map[dev] = virtual_devices
|
||||||
|
|
||||||
|
def get_compiler_ir(self, function_name, args, stage="hlo"):
  """Returns the compiler IR for `function_name` called with `args`.

  Forwards to `pywrap_tfe.TF_GetCompilerIr` using this context's handle and
  current device name.

  Args:
    function_name: Name of the function to get the IR for.
    args: Inputs the function is compiled for.
    stage: Stage at which the IR is exported; defaults to "hlo".

  Returns:
    Whatever `pywrap_tfe.TF_GetCompilerIr` returns for the given stage.
  """
  handle = self._context_handle
  device = self.device_name
  return pywrap_tfe.TF_GetCompilerIr(handle, function_name, stage, device,
                                     args)
|
||||||
|
|
||||||
@deprecated(
|
@deprecated(
|
||||||
None, "XLA:CPU and XLA:GPU devices are deprecated", warn_once=True)
|
None, "XLA:CPU and XLA:GPU devices are deprecated", warn_once=True)
|
||||||
def enable_xla_devices(self):
|
def enable_xla_devices(self):
|
||||||
|
@ -914,6 +914,76 @@ class Function(object):
|
|||||||
return function_lib.defun(fn_with_cond)(canon_args, canon_kwds,
|
return function_lib.defun(fn_with_cond)(canon_args, canon_kwds,
|
||||||
filtered_flat_args)
|
filtered_flat_args)
|
||||||
|
|
||||||
|
def experimental_get_compiler_ir(self, *args, **kwargs):
  """Returns compiler IR for the compiled function.

  This API is intended *only* for debugging as there are no guarantees on
  backwards compatibility of returned IR or the allowed values of `stage`.

  Args:
    *args: Arguments used for compilation; same arguments as used for calling
      the function. Need to be eager tensors.
    **kwargs: Keyword arguments used for compilation.

  Returns:
    Function callable with the stage at which the compiler IR should be
    serialized. Allowed values for the `stage` are `hlo` and `optimized_hlo`.
    When called, the returned function returns string representation of the
    compiler IR at a given stage.

    For example, for

    ```python
    @tf.function(experimental_compile=True)
    def f(x):
      return x + 1

    f.experimental_get_compiler_ir(tf.random.normal([10, 10]))(stage='hlo')
    ```

    the output is:

    ```
    HloModule a_inference_f_13__.9

    ENTRY %a_inference_f_13__.9 (arg0.1: f32[10,10]) -> f32[10,10] {
      %arg0.1 = f32[10,10]{1,0} parameter(0), parameter_replication={false}
      %reshape.2 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %arg0.1)
      %constant.3 = f32[] constant(1)
      %broadcast.4 = f32[10,10]{1,0} broadcast(f32[] %constant.3)
      %add.5 = f32[10,10]{1,0} add(f32[10,10]{1,0} %reshape.2,
                                   f32[10,10]{1,0} %broadcast.4)
      %reshape.6 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %add.5)
      %tuple.7 = (f32[10,10]{1,0}) tuple(f32[10,10]{1,0} %reshape.6)
      ROOT %get-tuple-element.8 = f32[10,10]{1,0}
        get-tuple-element((f32[10,10]{1,0}) %tuple.7), index=0
    }
    ```

  Raises:
    ValueError: If an invalid `stage` is selected or if applied to a function
      which is not compiled (`experimental_compile=True` is not set).
    TypeError: When called with input in graph mode.
  """
  context.ensure_initialized()
  if not self._experimental_compile:
    raise ValueError(
        "Compiler IR can only be returned for functions marked with "
        "experimental_compile=True")

  concrete_fn = self.get_concrete_function(*args, **kwargs)
  fn_name = concrete_fn.name

  # Canonicalize the inputs the same way as for a regular call, so that the
  # tensors handed to the compiler line up with the concrete function's
  # signature.
  # pylint: disable=protected-access
  canon_args, _, _, _ = (
      concrete_fn._function_spec.canonicalize_function_inputs(
          *args, **kwargs))
  # pylint: enable=protected-access

  # The captured inputs are appended after the explicit arguments. `stage` is
  # left unbound so the caller picks it when invoking the returned callable.
  return functools.partial(
      context.context().get_compiler_ir,
      function_name=fn_name,
      args=list(canon_args) + concrete_fn.captured_inputs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def python_function(self):
|
def python_function(self):
|
||||||
"""The python function wrapped in this tf.function."""
|
"""The python function wrapped in this tf.function."""
|
||||||
|
@ -526,6 +526,82 @@ class DefFunctionTest(xla_test.XLATestCase):
|
|||||||
b.backing_device) if on_gpu else 0
|
b.backing_device) if on_gpu else 0
|
||||||
self.assertEqual(initial_usage, final_usage)
|
self.assertEqual(initial_usage, final_usage)
|
||||||
|
|
||||||
|
def testGetCompilerIrConstants(self):
  """Checks that a constant permutation argument shows up in optimized HLO."""
  device_name = self.device
  if 'tpu' in device_name.lower():
    self.skipTest('TPU generates different HLO')

  with ops.device('device:{}:0'.format(device_name)):

    @def_function.function(experimental_compile=True)
    def f(a, b):
      return array_ops.transpose(a, b)

    tensor = array_ops.ones([3, 4, 3], dtype=dtypes.float32)
    permutation = constant_op.constant([0, 2, 1], dtype=dtypes.int32)

    get_ir = f.experimental_get_compiler_ir(tensor, permutation)
    optimized_hlo = get_ir(stage='optimized_hlo')
    self.assertIn('{1,2,0}', optimized_hlo)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('TODO(b/168732524): MLIR bridge does not '
|
||||||
|
' optimize single-element tuples to scalars')
|
||||||
|
def testGetCompilerIrResourceVars(self):
|
||||||
|
with ops.device('device:{}:0'.format(self.device)):
|
||||||
|
|
||||||
|
v = variables.Variable([3.1, 3.2])
|
||||||
|
|
||||||
|
@def_function.function(experimental_compile=True)
|
||||||
|
def f(a, b):
|
||||||
|
v.assign_add(a * b)
|
||||||
|
|
||||||
|
a = random_ops.random_normal([2])
|
||||||
|
b = random_ops.random_normal([2])
|
||||||
|
|
||||||
|
self.assertIn('input_output_alias={ {}: (2, {}, may-alias) }',
|
||||||
|
f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo'))
|
||||||
|
|
||||||
|
def testGetCompilerIrNotCompiled(self):
|
||||||
|
with ops.device('device:{}:0'.format(self.device)):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def f(x):
|
||||||
|
return x + 1
|
||||||
|
|
||||||
|
a = random_ops.random_normal([10, 10])
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
'marked with experimental_compile'):
|
||||||
|
f.experimental_get_compiler_ir(a)()
|
||||||
|
|
||||||
|
def testGetCompilerIrNested(self):
|
||||||
|
with ops.device('device:{}:0'.format(self.device)):
|
||||||
|
|
||||||
|
@def_function.function(experimental_compile=True)
|
||||||
|
def fn(x, a):
|
||||||
|
return x + a
|
||||||
|
|
||||||
|
@def_function.function(experimental_compile=False)
|
||||||
|
def fn2(x, a):
|
||||||
|
fn.experimental_get_compiler_ir(x, a)()
|
||||||
|
return fn(x, a)
|
||||||
|
|
||||||
|
inputs = constant_op.constant([1, 2, 2, 3, 3])
|
||||||
|
with self.assertRaisesRegex(TypeError, '"Graph" tensor'):
|
||||||
|
fn2(inputs, 1)
|
||||||
|
|
||||||
|
def testGetCompilerIrKwargs(self):
|
||||||
|
with ops.device('device:{}:0'.format(self.device)):
|
||||||
|
|
||||||
|
v = variables.Variable([0.1, 0.1])
|
||||||
|
|
||||||
|
@def_function.function(experimental_compile=True)
|
||||||
|
def f(a, b):
|
||||||
|
return (a + b) * v
|
||||||
|
|
||||||
|
a = constant_op.constant([1.1, 1.1])
|
||||||
|
b = constant_op.constant([2.2, 2.2])
|
||||||
|
|
||||||
|
self.assertIn('multiply',
|
||||||
|
f.experimental_get_compiler_ir(b=a, a=b)(stage='hlo'))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
|
@ -28,9 +28,11 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
#include "tensorflow/c/eager/dlpack.h"
|
#include "tensorflow/c/eager/dlpack.h"
|
||||||
|
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
#include "tensorflow/compiler/jit/flags.h"
|
#include "tensorflow/compiler/jit/flags.h"
|
||||||
|
#include "tensorflow/compiler/jit/get_compiler_ir.h"
|
||||||
#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
|
#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
|
||||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||||
#include "tensorflow/python/lib/core/py_exception_registry.h"
|
#include "tensorflow/python/lib/core/py_exception_registry.h"
|
||||||
@ -285,6 +287,74 @@ static py::object TFE_ClearScalarCache() {
|
|||||||
return py::none();
|
return py::none();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns compiler IR for a given function.
|
||||||
|
static std::string TFE_GetCompilerIr(py::handle& ctx,
|
||||||
|
const char* concrete_function_name,
|
||||||
|
const char* stage, const char* device_name,
|
||||||
|
py::handle& inputs) {
|
||||||
|
EagerContext* context = ContextFromInterface(
|
||||||
|
reinterpret_cast<ImmediateExecutionContext*>(InputTFE_Context(ctx)));
|
||||||
|
|
||||||
|
std::string s_stage(stage);
|
||||||
|
IrExportStage selected_stage = [&] {
|
||||||
|
if (s_stage == "hlo") {
|
||||||
|
return IrExportStage::HLO;
|
||||||
|
} else if (s_stage == "optimized_hlo") {
|
||||||
|
return IrExportStage::OPTIMIZED_HLO;
|
||||||
|
} else {
|
||||||
|
ThrowValueError(
|
||||||
|
absl::StrFormat("Invalid stage selected: '%s'. Valid values are: "
|
||||||
|
"'hlo', 'optimized_hlo'",
|
||||||
|
s_stage)
|
||||||
|
.c_str());
|
||||||
|
}
|
||||||
|
}();
|
||||||
|
|
||||||
|
TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(inputs);
|
||||||
|
|
||||||
|
std::vector<const Tensor*> input_tensors;
|
||||||
|
for (TFE_TensorHandle* tensor_handle : handles) {
|
||||||
|
AbstractTensorHandle* abstract_tensor_handle = unwrap(tensor_handle);
|
||||||
|
TensorHandle* th = TensorHandleFromInterface(abstract_tensor_handle);
|
||||||
|
|
||||||
|
const Tensor* t;
|
||||||
|
Status st = th->Tensor(&t);
|
||||||
|
if (!st.ok()) {
|
||||||
|
ThrowValueError(
|
||||||
|
absl::StrFormat("Could not resolve tensor: '%s'", st.error_message())
|
||||||
|
.c_str());
|
||||||
|
}
|
||||||
|
input_tensors.push_back(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
DeviceNameUtils::ParsedName input_device_name;
|
||||||
|
if (!DeviceNameUtils::ParseFullOrLocalName(device_name, &input_device_name)) {
|
||||||
|
ThrowValueError(
|
||||||
|
absl::StrFormat("Failed parsing device name: '%s'", device_name)
|
||||||
|
.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Device*> devices = context->local_device_mgr()->ListDevices();
|
||||||
|
auto selected_device = absl::c_find_if(devices, [&](const Device* d) {
|
||||||
|
return DeviceNameUtils::AreCompatibleDevNames(input_device_name,
|
||||||
|
d->parsed_name());
|
||||||
|
});
|
||||||
|
if (selected_device == devices.end()) {
|
||||||
|
ThrowValueError("No matching device found");
|
||||||
|
}
|
||||||
|
|
||||||
|
xla::StatusOr<std::string> hlo_text =
|
||||||
|
GetCompilerIr(selected_stage, context->pflr(), concrete_function_name,
|
||||||
|
*selected_device, input_tensors);
|
||||||
|
|
||||||
|
if (!hlo_text.ok()) {
|
||||||
|
ThrowValueError(absl::StrFormat("Failed getting HLO text: '%s'",
|
||||||
|
hlo_text.status().error_message())
|
||||||
|
.c_str());
|
||||||
|
}
|
||||||
|
return *hlo_text;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -513,6 +583,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
|||||||
m.def("TF_SetXlaConstantFoldingDisabled", &TF_SetXlaConstantFoldingDisabled);
|
m.def("TF_SetXlaConstantFoldingDisabled", &TF_SetXlaConstantFoldingDisabled);
|
||||||
m.def("TF_GetXlaConstantFoldingDisabled", &TF_GetXlaConstantFoldingDisabled);
|
m.def("TF_GetXlaConstantFoldingDisabled", &TF_GetXlaConstantFoldingDisabled);
|
||||||
m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
|
m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
|
||||||
|
m.def("TF_GetCompilerIr", &tensorflow::TFE_GetCompilerIr);
|
||||||
|
|
||||||
// MLIR Logic
|
// MLIR Logic
|
||||||
m.def("TF_IsMlirBridgeEnabled", [] {
|
m.def("TF_IsMlirBridgeEnabled", [] {
|
||||||
|
Loading…
Reference in New Issue
Block a user