[TF2XLA] experimental_get_compiler_ir
API: compiled IR at a given stage for a function for fixed input
``` In [1]: import tensorflow as tf In [2]: @tf.function(experimental_compile=True) ...: def f(x): ...: return x + 1 In [3]: a = tf.random.normal([10, 10]) In [6]: print(f.experimental_get_compiler_ir(a, stage='hlo')) HloModule a_inference_f_13__.9 ENTRY %a_inference_f_13__.9 (arg0.1: f32[10,10]) -> f32[10,10] { %arg0.1 = f32[10,10]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"} %reshape.2 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %arg0.1) %constant.3 = f32[] constant(1), metadata={op_type="AddV2" op_name="add"} %broadcast.4 = f32[10,10]{1,0} broadcast(f32[] %constant.3), dimensions={}, metadata={op_type="AddV2" op_name="add"} %add.5 = f32[10,10]{1,0} add(f32[10,10]{1,0} %reshape.2, f32[10,10]{1,0} %broadcast.4), metadata={op_type="AddV2" op_name="add"} %reshape.6 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %add.5), metadata={op_name="XLA_Retvals"} %tuple.7 = (f32[10,10]{1,0}) tuple(f32[10,10]{1,0} %reshape.6), metadata={op_name="XLA_Retvals"} ROOT %get-tuple-element.8 = f32[10,10]{1,0} get-tuple-element((f32[10,10]{1,0}) %tuple.7), index=0, metadata={op_name="XLA_Retvals"} } ``` PiperOrigin-RevId: 332325845 Change-Id: Ifafaf629e5c4e4d7e0c42d18da4bcbb503c6ae02
This commit is contained in:
parent
5e007c2770
commit
eeeb6ca25f
@ -185,6 +185,8 @@
|
||||
* XLA Support:
|
||||
* xla.experimental.compile is deprecated, use
|
||||
`tf.function(experimental_compile=True)` instead
|
||||
* Added `tf.function.experimental_get_compiler_ir` which returns compiler IR
|
||||
(currently 'hlo' and 'optimized_hlo') for given input for given function.
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* Tracing and Debugging:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
@ -377,6 +377,7 @@ tf_cuda_library(
|
||||
"//tensorflow/c/eager:tfe_op_internal",
|
||||
"//tensorflow/c/eager:tfe_tensorhandle_internal",
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
"//tensorflow/compiler/jit:get_compiler_ir",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -15,6 +15,7 @@ package_group(
|
||||
"//tensorflow/compiler/tf2xla:internal",
|
||||
],
|
||||
packages = [
|
||||
"//tensorflow/c/...",
|
||||
"//tensorflow/compiler/tests/...",
|
||||
"//tensorflow/python/...",
|
||||
],
|
||||
@ -387,6 +388,53 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "get_compiler_ir",
|
||||
srcs = ["get_compiler_ir.cc"],
|
||||
hdrs = ["get_compiler_ir.h"],
|
||||
visibility = [
|
||||
":internal",
|
||||
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
|
||||
"//tensorflow/core/common_runtime/eager:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":common",
|
||||
":compilability_check_util",
|
||||
":flags",
|
||||
":xla_device_no_jit_rewrite_registration",
|
||||
":xla_launch_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/common_runtime:core_cpu_internal",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Header-only version of "flags" library, for linking from the shared object
|
||||
# without ODR violations.
|
||||
cc_library(
|
||||
name = "get_compiler_ir_hdrs_only",
|
||||
hdrs = ["get_compiler_ir.h"],
|
||||
visibility = [
|
||||
":internal",
|
||||
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
|
||||
"//tensorflow/core/common_runtime/eager:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_kernel_creator",
|
||||
srcs = [
|
||||
|
@ -84,6 +84,43 @@ Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Utility which searches for values in a sorted list by scanning over it once.
|
||||
// No matter how many times ScanForValue is called, the list is scanned at most
|
||||
// once. However, if a call to ScanForValue skips over a value, that value is
|
||||
// not revisited in future calls to ScanForValue, so callers must take
|
||||
// care to order their calls.
|
||||
//
|
||||
// Useful for merging multiple sorted lists in O(n) time.
|
||||
class SinglePassSearch {
|
||||
public:
|
||||
// Creates a SinglePassSearch object that can be used to search in `values`.
|
||||
// Does not take ownership of `values`. `values` must outlive this.
|
||||
// `values` must be sorted.
|
||||
explicit SinglePassSearch(absl::Span<int const> values)
|
||||
: current_index_(0), values_(values) {}
|
||||
|
||||
// Scans forward in the vector looking for "value", updating the internal
|
||||
// position in to the vector.
|
||||
// Returns true iff the vector contains the given value at or after current
|
||||
// position.
|
||||
// Not thread-safe.
|
||||
bool ScanForValue(int value) {
|
||||
while (current_index_ < values_.size() &&
|
||||
values_[current_index_] <= value) {
|
||||
if (values_[current_index_] == value) {
|
||||
current_index_++;
|
||||
return true;
|
||||
}
|
||||
current_index_++;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
int current_index_;
|
||||
const absl::Span<int const> values_;
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap
|
||||
@ -564,6 +601,44 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::MemoryTypeVector GetInputMemoryTypes(
|
||||
const tensorflow::FunctionBody* fbody,
|
||||
absl::Span<int const> constant_arg_indices,
|
||||
absl::Span<int const> resource_arg_indices) {
|
||||
// Set input and output memory types.
|
||||
tensorflow::MemoryTypeVector input_memory_types(fbody->arg_types.size(),
|
||||
tensorflow::DEVICE_MEMORY);
|
||||
// These indices are used only for optimization purposes. They allow us
|
||||
// to loop over constant_arg_indices and resource_arg_indices only once
|
||||
// while iterating over all the function arguments checking if it is a
|
||||
// resource or a constant.
|
||||
// The reason we optimized this code is because functions can have a lot of
|
||||
// captured arguments. For example, the backward pass of ResNet50 takes in all
|
||||
// 214 variables and a similar number of activations.
|
||||
SinglePassSearch constants_search(constant_arg_indices);
|
||||
SinglePassSearch resources_search(resource_arg_indices);
|
||||
for (size_t i = 0; i < fbody->arg_types.size(); ++i) {
|
||||
if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
|
||||
// Compile-time constants and resource handles are expected to be in
|
||||
// host memory.
|
||||
input_memory_types[i] = tensorflow::HOST_MEMORY;
|
||||
}
|
||||
}
|
||||
return input_memory_types;
|
||||
}
|
||||
|
||||
tensorflow::MemoryTypeVector GetOutputMemoryTypes(
|
||||
const tensorflow::FunctionBody* fbody) {
|
||||
tensorflow::MemoryTypeVector output_memory_types(fbody->ret_types.size(),
|
||||
tensorflow::DEVICE_MEMORY);
|
||||
for (size_t i = 0; i < fbody->ret_types.size(); ++i) {
|
||||
if (fbody->ret_types[i] == tensorflow::DT_RESOURCE) {
|
||||
output_memory_types[i] = tensorflow::HOST_MEMORY;
|
||||
}
|
||||
}
|
||||
return output_memory_types;
|
||||
}
|
||||
|
||||
static auto const ops_triggering_xla_compilation =
|
||||
new absl::flat_hash_set<std::string>{"XlaBroadcastHelper",
|
||||
"XlaConv",
|
||||
|
@ -282,6 +282,41 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||
// set.
|
||||
bool CanCreateXlaKernel(const NodeDef& node_def);
|
||||
|
||||
// Returns memory types for the input.
|
||||
// `constant_arg_indices` and `resource_arg_indices` are sorted arrays of
|
||||
// indices corresponding to constant and resource arguments respectively.
|
||||
//
|
||||
// One might wonder, about the case where a compile-time constant argument
|
||||
// (which must be in host memory) is also used as an input into an op,
|
||||
// e.g. `Add`, that expects its inputs in device memory. Here is how it
|
||||
// works now.
|
||||
// First, what do we mean by "op expects an input in XYZ memory"?
|
||||
// There are two types of "ops" here: the tf2xla kernel and the HLO
|
||||
// computation it builds. The tf2xla kernel needs to retrieve the actual
|
||||
// numeric value of the compile-time constant tensors, so it really expects
|
||||
// them to be on in host memory. However, for other inputs, it refers to them
|
||||
// using xla::ComputationDataHandle, which is just a symbolic handle that
|
||||
// xla::ComputationBuilder assigns. How does this handle gets assigned for
|
||||
// constant arguments? Even constant arguments get an _Arg node in the graph
|
||||
// instantiated for Function compilation. The tf2xla kernel for constant _Arg
|
||||
// nodes takes the constant value, converts it to XlaLiteral, and feeds it
|
||||
// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
|
||||
// constant XlaLiteral is included in the HLO graph, and subsequently, in
|
||||
// the actual executable, which is copied to the device before being
|
||||
// executed. Thus, when this executable runs, the constant is available in
|
||||
// device memory.
|
||||
tensorflow::MemoryTypeVector GetInputMemoryTypes(
|
||||
const tensorflow::FunctionBody* fbody,
|
||||
absl::Span<int const> constant_arg_indices,
|
||||
absl::Span<int const> resource_arg_indices);
|
||||
|
||||
// Returns output memory types.
|
||||
//
|
||||
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
|
||||
// in device memory except for resources.
|
||||
tensorflow::MemoryTypeVector GetOutputMemoryTypes(
|
||||
const tensorflow::FunctionBody* fbody);
|
||||
|
||||
// Check whether graph can trigger XLA compilation.
|
||||
bool CanTriggerXlaCompilation(const GraphDef& graph);
|
||||
|
||||
|
114
tensorflow/compiler/jit/get_compiler_ir.cc
Normal file
114
tensorflow/compiler/jit/get_compiler_ir.cc
Normal file
@ -0,0 +1,114 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/get_compiler_ir.h"
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "tensorflow/compiler/jit/compilability_check_util.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/xla_launch_util.h"
|
||||
#include "tensorflow/compiler/jit/xla_platform_info.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
xla::StatusOr<std::string> GetCompilerIr(
|
||||
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
|
||||
absl::string_view func_name, Device* dev,
|
||||
absl::Span<const Tensor* const> inputs) {
|
||||
NameAttrList function;
|
||||
function.set_name(std::string{func_name});
|
||||
|
||||
FunctionLibraryRuntime* flr = pflr->GetFLR(dev->name());
|
||||
ResourceMgr* rmgr = dev->resource_manager();
|
||||
|
||||
const FunctionBody* fbody = nullptr;
|
||||
std::vector<int> constant_arg_indices;
|
||||
std::vector<int> resource_arg_indices;
|
||||
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
|
||||
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
|
||||
|
||||
MemoryTypeVector input_memory_types =
|
||||
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
|
||||
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
|
||||
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
|
||||
rmgr, dev, inputs, resource_arg_indices, &variable_infos));
|
||||
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
|
||||
|
||||
XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(dev);
|
||||
|
||||
XlaCompilationCache* cache;
|
||||
TF_RETURN_IF_ERROR(rmgr->LookupOrCreate<XlaCompilationCache>(
|
||||
rmgr->default_container(), "xla_cache", &cache,
|
||||
[&](XlaCompilationCache** cache_write_into) {
|
||||
return BuildXlaCompilationCache(dev, platform_info, cache_write_into);
|
||||
}));
|
||||
core::ScopedUnref cache_ref(cache);
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
|
||||
XlaCompiler::Options options =
|
||||
GenerateCompilerOptions(*cache, *flr, dev,
|
||||
/*stream=*/nullptr, platform_info,
|
||||
/*has_ref_vars=*/false, &tf_allocator_adapter);
|
||||
|
||||
XlaCompiler::CompileOptions compile_options;
|
||||
compile_options.always_return_tuple = false;
|
||||
compile_options.alias_resource_update = true;
|
||||
|
||||
XlaCompiler compiler(options);
|
||||
|
||||
xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
|
||||
XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
constant_arg_indices, inputs, variable_infos);
|
||||
TF_RETURN_IF_ERROR(args.status());
|
||||
|
||||
switch (stage) {
|
||||
case IrExportStage::HLO: {
|
||||
XlaCompiler::CompilationResult result;
|
||||
TF_RETURN_IF_ERROR(
|
||||
compiler.CompileFunction(compile_options, function, *args, &result));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(xla::ProgramShape program_shape,
|
||||
result.computation->GetProgramShape());
|
||||
xla::HloModuleConfig config(program_shape);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<xla::HloModule> new_module,
|
||||
xla::HloModule::CreateFromProto(result.computation->proto(), config));
|
||||
|
||||
return new_module->ToString();
|
||||
}
|
||||
case IrExportStage::OPTIMIZED_HLO: {
|
||||
const XlaCompiler::CompilationResult* compilation_result = nullptr;
|
||||
xla::LocalExecutable* executable = nullptr;
|
||||
TF_RETURN_IF_ERROR(
|
||||
cache->Compile(options, function, *args, compile_options,
|
||||
XlaCompilationCache::CompileMode::kStrict,
|
||||
&compilation_result, &executable));
|
||||
return executable->executable()->module().ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
39
tensorflow/compiler/jit/get_compiler_ir.h
Normal file
39
tensorflow/compiler/jit/get_compiler_ir.h
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_GET_COMPILER_IR_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_GET_COMPILER_IR_H_
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class ProcessFunctionLibraryRuntime;
|
||||
class Device;
|
||||
class Tensor;
|
||||
|
||||
enum class IrExportStage { HLO, OPTIMIZED_HLO };
|
||||
|
||||
// Returns HLO text for a given function `func_name` using library runtime
|
||||
// `runtime` on a device `dev` with given `inputs`.
|
||||
xla::StatusOr<std::string> GetCompilerIr(
|
||||
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
|
||||
absl::string_view func_name, Device* dev,
|
||||
absl::Span<const Tensor* const> inputs);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_GET_COMPILER_IR_H_
|
@ -30,47 +30,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// Utility which searches for values in a sorted list by scanning over it once.
|
||||
// No matter how many times ScanForValue is called, the list is scanned at most
|
||||
// once. However, if a call to ScanForValue skips over a value, that value is
|
||||
// not revisited in future calls to ScanForValue, so callers must take
|
||||
// care to order their calls.
|
||||
//
|
||||
// Useful for merging multiple sorted lists in O(n) time.
|
||||
class SinglePassSearch {
|
||||
public:
|
||||
// Creates a SinglePassSearch object that can be used to search in `values`.
|
||||
// Does not take ownership of `values`. `values` must outlive this.
|
||||
// `values` must be sorted.
|
||||
explicit SinglePassSearch(const std::vector<int>* values)
|
||||
: current_index_(0), values_(values) {}
|
||||
|
||||
// Scans forward in the vector looking for "value", updating the internal
|
||||
// position in to the vector.
|
||||
// Returns true iff the vector contains the given value at or after current
|
||||
// position.
|
||||
// Not thread-safe.
|
||||
bool ScanForValue(int value) {
|
||||
while (current_index_ < values_->size() &&
|
||||
(*values_)[current_index_] <= value) {
|
||||
if ((*values_)[current_index_] == value) {
|
||||
current_index_++;
|
||||
return true;
|
||||
}
|
||||
current_index_++;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
int current_index_;
|
||||
const std::vector<int>* values_;
|
||||
};
|
||||
|
||||
} // end namespace
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
bool XlaKernelCreator::CanCreateKernel(
|
||||
@ -130,52 +89,9 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
|
||||
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
|
||||
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
|
||||
|
||||
// Set input and output memory types.
|
||||
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
|
||||
// These indices are used only for optimization purposes. They allow us
|
||||
// to loop over constant_arg_indices and resource_arg_indices only once
|
||||
// while iterating over all the function arguments checking if it is a
|
||||
// resource or a constant.
|
||||
// The reason we optimized this code is because functions can have a lot of
|
||||
// captured arguments. For example, the backward pass of ResNet50 takes in all
|
||||
// 214 variables and a similar number of activations.
|
||||
SinglePassSearch constants_search(&constant_arg_indices);
|
||||
SinglePassSearch resources_search(&resource_arg_indices);
|
||||
for (size_t i = 0; i < fbody->arg_types.size(); ++i) {
|
||||
if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
|
||||
// Compile-time constants and resource handles are expected to be in
|
||||
// host memory.
|
||||
input_memory_types[i] = HOST_MEMORY;
|
||||
}
|
||||
}
|
||||
// One might wonder, about the case where a compile-time constant argument
|
||||
// (which must be in host memory) is also used as an input into an op,
|
||||
// e.g. Add, that expects its inputs in device memory. Here is how it
|
||||
// works now.
|
||||
// First, what do we mean by "op expects an input in XYZ memory"?
|
||||
// There are two types of "ops" here: the tf2xla kernel and the HLO
|
||||
// computation it builds. The tf2xla kernel needs to retrieve the actual
|
||||
// numeric value of the compile-time constant tensors, so it really expects
|
||||
// them to be on in host memory. However, for other inputs, it refers to them
|
||||
// using xla::ComputationDataHandle, which is just a symbolic handle that
|
||||
// xla::ComputationBuilder assigns. How does this handle gets assigned for
|
||||
// constant arguments? Even constant arguments get an _Arg node in the graph
|
||||
// instantiated for Function compilation. The tf2xla kernel for constant _Arg
|
||||
// nodes takes the constant value, converts it to XlaLiteral, and feeds it
|
||||
// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
|
||||
// constant XlaLiteral is included in the HLO graph, and subsequently, in
|
||||
// the actual executable, which is copied to the device before being
|
||||
// executed. Thus, when this executable runs, the constant is available in
|
||||
// device memory.
|
||||
|
||||
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
|
||||
// in device memory except for resources.
|
||||
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
|
||||
for (size_t i = 0; i < fbody->ret_types.size(); ++i) {
|
||||
if (fbody->ret_types[i] == DT_RESOURCE) {
|
||||
output_memory_types[i] = HOST_MEMORY;
|
||||
}
|
||||
}
|
||||
MemoryTypeVector input_memory_types =
|
||||
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
|
||||
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
|
||||
|
||||
// Create the kernel.
|
||||
Device* dev = flr->device();
|
||||
|
@ -6047,6 +6047,7 @@ filegroup(
|
||||
"//tensorflow/c:checkpoint_reader", # checkpoint_reader
|
||||
"//tensorflow/c:python_api", # tf_session
|
||||
"//tensorflow/c:tf_status_helper", # tfe
|
||||
"//tensorflow/compiler/jit:get_compiler_ir", #tfe
|
||||
"//tensorflow/compiler/jit:flags", #tfe
|
||||
"//tensorflow/compiler/mlir/python:mlir", # mlir
|
||||
"//tensorflow/core/common_runtime:device", # device_lib, tfe, tf_session
|
||||
@ -7557,6 +7558,8 @@ tf_python_pybind_extension(
|
||||
"//third_party/python_runtime:headers",
|
||||
"//tensorflow/c/experimental/saved_model/core:pywrap_required_hdrs",
|
||||
"//tensorflow/compiler/jit:flags_headers_only",
|
||||
"//tensorflow/compiler/jit:get_compiler_ir_hdrs_only",
|
||||
"//tensorflow/c/eager:tfe_tensorhandle_internal",
|
||||
"//tensorflow/core/common_runtime:core_cpu_headers_lib",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//tensorflow/core:lib_headers_for_pybind",
|
||||
|
@ -1494,6 +1494,10 @@ class Context(object):
|
||||
|
||||
self._virtual_device_map[dev] = virtual_devices
|
||||
|
||||
def get_compiler_ir(self, function_name, args, stage="hlo"):
|
||||
return pywrap_tfe.TF_GetCompilerIr(self._context_handle, function_name,
|
||||
stage, self.device_name, args)
|
||||
|
||||
@deprecated(
|
||||
None, "XLA:CPU and XLA:GPU devices are deprecated", warn_once=True)
|
||||
def enable_xla_devices(self):
|
||||
|
@ -914,6 +914,76 @@ class Function(object):
|
||||
return function_lib.defun(fn_with_cond)(canon_args, canon_kwds,
|
||||
filtered_flat_args)
|
||||
|
||||
def experimental_get_compiler_ir(self, *args, **kwargs):
|
||||
"""Returns compiler IR for the compiled function.
|
||||
|
||||
This API is intended *only* for debugging as there are no guarantees on
|
||||
backwards compatibility of returned IR or the allowed values of `stage`.
|
||||
|
||||
Args:
|
||||
*args: Arguments used for compilation; same arguments as used for calling
|
||||
the function. Need to be eager tensors.
|
||||
**kwargs: Keyword arguments used for compilation.
|
||||
|
||||
Returns:
|
||||
Function callable with the stage at which the compiler IR should be
|
||||
serialized. Allowed values for the `stage` are `hlo` and `optimized_hlo`.
|
||||
When called, the returned function returns string representation of the
|
||||
compiler IR at a given stage.
|
||||
|
||||
For example, for
|
||||
|
||||
```python
|
||||
@tf.function(experimental_compile=True)
|
||||
def f(x):
|
||||
return x + 1
|
||||
|
||||
f.experimental_get_compiler_ir(tf.random.normal([10, 10])(stage='hlo')
|
||||
```
|
||||
|
||||
the output is:
|
||||
|
||||
```
|
||||
HloModule a_inference_f_13__.9
|
||||
|
||||
ENTRY %a_inference_f_13__.9 (arg0.1: f32[10,10]) -> f32[10,10] {
|
||||
%arg0.1 = f32[10,10]{1,0} parameter(0), parameter_replication={false}
|
||||
%reshape.2 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %arg0.1)
|
||||
%constant.3 = f32[] constant(1)
|
||||
%broadcast.4 = f32[10,10]{1,0} broadcast(f32[] %constant.3)
|
||||
%add.5 = f32[10,10]{1,0} add(f32[10,10]{1,0} %reshape.2,
|
||||
f32[10,10]{1,0} %broadcast.4)
|
||||
%reshape.6 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %add.5)
|
||||
%tuple.7 = (f32[10,10]{1,0}) tuple(f32[10,10]{1,0} %reshape.6)
|
||||
ROOT %get-tuple-element.8 = f32[10,10]{1,0}
|
||||
get-tuple-element((f32[10,10]{1,0}) %tuple.7), index=0
|
||||
}
|
||||
```
|
||||
|
||||
Raises:
|
||||
ValueError: If an invalid `stage` is selected or if applied to a function
|
||||
which is not compiled (`experimental_compile=True` is not set).
|
||||
TypeError: When called with input in graph mode.
|
||||
"""
|
||||
context.ensure_initialized()
|
||||
if not self._experimental_compile:
|
||||
raise ValueError(
|
||||
"Compiler IR can only be returned for functions marked with "
|
||||
"experimental_compile=True")
|
||||
|
||||
concrete_fn = self.get_concrete_function(*args, **kwargs)
|
||||
fn_name = concrete_fn.name
|
||||
|
||||
# pylint: disable=protected-access
|
||||
canon_args, _, _, _ = \
|
||||
concrete_fn._function_spec.canonicalize_function_inputs(
|
||||
*args, **kwargs)
|
||||
|
||||
return functools.partial(
|
||||
context.context().get_compiler_ir,
|
||||
function_name=fn_name,
|
||||
args=list(canon_args) + concrete_fn.captured_inputs)
|
||||
|
||||
@property
|
||||
def python_function(self):
|
||||
"""The python function wrapped in this tf.function."""
|
||||
|
@ -526,6 +526,82 @@ class DefFunctionTest(xla_test.XLATestCase):
|
||||
b.backing_device) if on_gpu else 0
|
||||
self.assertEqual(initial_usage, final_usage)
|
||||
|
||||
def testGetCompilerIrConstants(self):
|
||||
if 'tpu' in self.device.lower():
|
||||
self.skipTest('TPU generates different HLO')
|
||||
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def f(a, b):
|
||||
return array_ops.transpose(a, b)
|
||||
|
||||
a = array_ops.ones([3, 4, 3], dtype=dtypes.float32)
|
||||
b = constant_op.constant([0, 2, 1], dtype=dtypes.int32)
|
||||
|
||||
self.assertIn('{1,2,0}',
|
||||
f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo'))
|
||||
|
||||
@test_util.disable_mlir_bridge('TODO(b/168732524): MLIR bridge does not '
|
||||
' optimize single-element tuples to scalars')
|
||||
def testGetCompilerIrResourceVars(self):
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
|
||||
v = variables.Variable([3.1, 3.2])
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def f(a, b):
|
||||
v.assign_add(a * b)
|
||||
|
||||
a = random_ops.random_normal([2])
|
||||
b = random_ops.random_normal([2])
|
||||
|
||||
self.assertIn('input_output_alias={ {}: (2, {}, may-alias) }',
|
||||
f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo'))
|
||||
|
||||
def testGetCompilerIrNotCompiled(self):
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
|
||||
@def_function.function
|
||||
def f(x):
|
||||
return x + 1
|
||||
|
||||
a = random_ops.random_normal([10, 10])
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'marked with experimental_compile'):
|
||||
f.experimental_get_compiler_ir(a)()
|
||||
|
||||
def testGetCompilerIrNested(self):
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def fn(x, a):
|
||||
return x + a
|
||||
|
||||
@def_function.function(experimental_compile=False)
|
||||
def fn2(x, a):
|
||||
fn.experimental_get_compiler_ir(x, a)()
|
||||
return fn(x, a)
|
||||
|
||||
inputs = constant_op.constant([1, 2, 2, 3, 3])
|
||||
with self.assertRaisesRegex(TypeError, '"Graph" tensor'):
|
||||
fn2(inputs, 1)
|
||||
|
||||
def testGetCompilerIrKwargs(self):
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
|
||||
v = variables.Variable([0.1, 0.1])
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def f(a, b):
|
||||
return (a + b) * v
|
||||
|
||||
a = constant_op.constant([1.1, 1.1])
|
||||
b = constant_op.constant([2.2, 2.2])
|
||||
|
||||
self.assertIn('multiply',
|
||||
f.experimental_get_compiler_ir(b=a, a=b)(stage='hlo'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
@ -28,9 +28,11 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/dlpack.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/get_compiler_ir.h"
|
||||
#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
|
||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||
#include "tensorflow/python/lib/core/py_exception_registry.h"
|
||||
@ -285,6 +287,74 @@ static py::object TFE_ClearScalarCache() {
|
||||
return py::none();
|
||||
}
|
||||
|
||||
// Returns compiler IR for a given function.
|
||||
static std::string TFE_GetCompilerIr(py::handle& ctx,
|
||||
const char* concrete_function_name,
|
||||
const char* stage, const char* device_name,
|
||||
py::handle& inputs) {
|
||||
EagerContext* context = ContextFromInterface(
|
||||
reinterpret_cast<ImmediateExecutionContext*>(InputTFE_Context(ctx)));
|
||||
|
||||
std::string s_stage(stage);
|
||||
IrExportStage selected_stage = [&] {
|
||||
if (s_stage == "hlo") {
|
||||
return IrExportStage::HLO;
|
||||
} else if (s_stage == "optimized_hlo") {
|
||||
return IrExportStage::OPTIMIZED_HLO;
|
||||
} else {
|
||||
ThrowValueError(
|
||||
absl::StrFormat("Invalid stage selected: '%s'. Valid values are: "
|
||||
"'hlo', 'optimized_hlo'",
|
||||
s_stage)
|
||||
.c_str());
|
||||
}
|
||||
}();
|
||||
|
||||
TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(inputs);
|
||||
|
||||
std::vector<const Tensor*> input_tensors;
|
||||
for (TFE_TensorHandle* tensor_handle : handles) {
|
||||
AbstractTensorHandle* abstract_tensor_handle = unwrap(tensor_handle);
|
||||
TensorHandle* th = TensorHandleFromInterface(abstract_tensor_handle);
|
||||
|
||||
const Tensor* t;
|
||||
Status st = th->Tensor(&t);
|
||||
if (!st.ok()) {
|
||||
ThrowValueError(
|
||||
absl::StrFormat("Could not resolve tensor: '%s'", st.error_message())
|
||||
.c_str());
|
||||
}
|
||||
input_tensors.push_back(t);
|
||||
}
|
||||
|
||||
DeviceNameUtils::ParsedName input_device_name;
|
||||
if (!DeviceNameUtils::ParseFullOrLocalName(device_name, &input_device_name)) {
|
||||
ThrowValueError(
|
||||
absl::StrFormat("Failed parsing device name: '%s'", device_name)
|
||||
.c_str());
|
||||
}
|
||||
|
||||
std::vector<Device*> devices = context->local_device_mgr()->ListDevices();
|
||||
auto selected_device = absl::c_find_if(devices, [&](const Device* d) {
|
||||
return DeviceNameUtils::AreCompatibleDevNames(input_device_name,
|
||||
d->parsed_name());
|
||||
});
|
||||
if (selected_device == devices.end()) {
|
||||
ThrowValueError("No matching device found");
|
||||
}
|
||||
|
||||
xla::StatusOr<std::string> hlo_text =
|
||||
GetCompilerIr(selected_stage, context->pflr(), concrete_function_name,
|
||||
*selected_device, input_tensors);
|
||||
|
||||
if (!hlo_text.ok()) {
|
||||
ThrowValueError(absl::StrFormat("Failed getting HLO text: '%s'",
|
||||
hlo_text.status().error_message())
|
||||
.c_str());
|
||||
}
|
||||
return *hlo_text;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
namespace {
|
||||
@ -513,6 +583,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
m.def("TF_SetXlaConstantFoldingDisabled", &TF_SetXlaConstantFoldingDisabled);
|
||||
m.def("TF_GetXlaConstantFoldingDisabled", &TF_GetXlaConstantFoldingDisabled);
|
||||
m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
|
||||
m.def("TF_GetCompilerIr", &tensorflow::TFE_GetCompilerIr);
|
||||
|
||||
// MLIR Logic
|
||||
m.def("TF_IsMlirBridgeEnabled", [] {
|
||||
|
Loading…
Reference in New Issue
Block a user