[TF2XLA] experimental_get_compiler_ir API: compiler IR at a given stage for a function with fixed inputs

```
In [1]: import tensorflow as tf
In [2]: @tf.function(experimental_compile=True)
   ...: def f(x):
   ...:   return x + 1
In [3]: a = tf.random.normal([10, 10])
In [6]: print(f.experimental_get_compiler_ir(a)(stage='hlo'))
HloModule a_inference_f_13__.9

ENTRY %a_inference_f_13__.9 (arg0.1: f32[10,10]) -> f32[10,10] {
  %arg0.1 = f32[10,10]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %reshape.2 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %arg0.1)
  %constant.3 = f32[] constant(1), metadata={op_type="AddV2" op_name="add"}
  %broadcast.4 = f32[10,10]{1,0} broadcast(f32[] %constant.3), dimensions={}, metadata={op_type="AddV2" op_name="add"}
  %add.5 = f32[10,10]{1,0} add(f32[10,10]{1,0} %reshape.2, f32[10,10]{1,0} %broadcast.4), metadata={op_type="AddV2" op_name="add"}
  %reshape.6 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %add.5), metadata={op_name="XLA_Retvals"}
  %tuple.7 = (f32[10,10]{1,0}) tuple(f32[10,10]{1,0} %reshape.6), metadata={op_name="XLA_Retvals"}
  ROOT %get-tuple-element.8 = f32[10,10]{1,0} get-tuple-element((f32[10,10]{1,0}) %tuple.7), index=0, metadata={op_name="XLA_Retvals"}
}
```

PiperOrigin-RevId: 332325845
Change-Id: Ifafaf629e5c4e4d7e0c42d18da4bcbb503c6ae02
George Karpenkov (2020-09-17 15:10:16 -07:00), committed by TensorFlower Gardener
parent 5e007c2770
commit eeeb6ca25f
13 changed files with 541 additions and 87 deletions
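
The same returned callable also accepts the post-optimization stage; a minimal sketch reusing `f` and `a` from the example above (output omitted here, as optimized HLO varies by backend):

```
In [7]: print(f.experimental_get_compiler_ir(a)(stage='optimized_hlo'))
```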


@@ -185,6 +185,8 @@
* XLA Support:
* xla.experimental.compile is deprecated, use
`tf.function(experimental_compile=True)` instead
* Added `tf.function.experimental_get_compiler_ir` which returns the compiler IR
(currently 'hlo' and 'optimized_hlo') for a given function with given inputs.
* <ADD RELEASE NOTES HERE>
* Tracing and Debugging:
* <ADD RELEASE NOTES HERE>


@@ -377,6 +377,7 @@ tf_cuda_library(
"//tensorflow/c/eager:tfe_op_internal",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
"//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/jit:get_compiler_ir",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",


@@ -15,6 +15,7 @@ package_group(
"//tensorflow/compiler/tf2xla:internal",
],
packages = [
"//tensorflow/c/...",
"//tensorflow/compiler/tests/...",
"//tensorflow/python/...",
],
@@ -387,6 +388,53 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "get_compiler_ir",
srcs = ["get_compiler_ir.cc"],
hdrs = ["get_compiler_ir.h"],
visibility = [
":internal",
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
"//tensorflow/core/common_runtime/eager:__pkg__",
],
deps = [
":common",
":compilability_check_util",
":flags",
":xla_device_no_jit_rewrite_registration",
":xla_launch_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime:core_cpu_internal",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
],
alwayslink = 1,
)
# Header-only version of the "get_compiler_ir" library, for linking from the
# shared object without ODR violations.
cc_library(
name = "get_compiler_ir_hdrs_only",
hdrs = ["get_compiler_ir.h"],
visibility = [
":internal",
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
"//tensorflow/core/common_runtime/eager:__pkg__",
],
deps = [
"//tensorflow/compiler/xla:statusor",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "xla_kernel_creator",
srcs = [


@@ -84,6 +84,43 @@ Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name,
return Status::OK();
}
// Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is
// not revisited in future calls to ScanForValue, so callers must take
// care to order their calls.
//
// Useful for merging multiple sorted lists in O(n) time.
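//
// For example, with values = {1, 3, 7}: ScanForValue(5) returns false and
// advances past 1 and 3, so a later ScanForValue(3) also returns false even
// though 3 is present, while a subsequent ScanForValue(7) still returns true.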
class SinglePassSearch {
public:
// Creates a SinglePassSearch object that can be used to search in `values`.
// Does not take ownership of `values`. `values` must outlive this.
// `values` must be sorted.
explicit SinglePassSearch(absl::Span<int const> values)
: current_index_(0), values_(values) {}
// Scans forward in the vector looking for "value", updating the internal
// position in the vector.
// Returns true iff the vector contains the given value at or after current
// position.
// Not thread-safe.
bool ScanForValue(int value) {
while (current_index_ < values_.size() &&
values_[current_index_] <= value) {
if (values_[current_index_] == value) {
current_index_++;
return true;
}
current_index_++;
}
return false;
}
private:
int current_index_;
const absl::Span<int const> values_;
};
} // anonymous namespace
RecursiveCompilabilityChecker::UncompilableNodesMap
@@ -564,6 +601,44 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
return Status::OK();
}
tensorflow::MemoryTypeVector GetInputMemoryTypes(
const tensorflow::FunctionBody* fbody,
absl::Span<int const> constant_arg_indices,
absl::Span<int const> resource_arg_indices) {
// Set input and output memory types.
tensorflow::MemoryTypeVector input_memory_types(fbody->arg_types.size(),
tensorflow::DEVICE_MEMORY);
// These indices are used only for optimization purposes. They allow us
// to loop over constant_arg_indices and resource_arg_indices only once
// while iterating over all the function arguments checking if it is a
// resource or a constant.
// The reason we optimized this code is because functions can have a lot of
// captured arguments. For example, the backward pass of ResNet50 takes in all
// 214 variables and a similar number of activations.
SinglePassSearch constants_search(constant_arg_indices);
SinglePassSearch resources_search(resource_arg_indices);
for (size_t i = 0; i < fbody->arg_types.size(); ++i) {
if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
// Compile-time constants and resource handles are expected to be in
// host memory.
input_memory_types[i] = tensorflow::HOST_MEMORY;
}
}
return input_memory_types;
}
tensorflow::MemoryTypeVector GetOutputMemoryTypes(
const tensorflow::FunctionBody* fbody) {
tensorflow::MemoryTypeVector output_memory_types(fbody->ret_types.size(),
tensorflow::DEVICE_MEMORY);
for (size_t i = 0; i < fbody->ret_types.size(); ++i) {
if (fbody->ret_types[i] == tensorflow::DT_RESOURCE) {
output_memory_types[i] = tensorflow::HOST_MEMORY;
}
}
return output_memory_types;
}
static auto const ops_triggering_xla_compilation =
new absl::flat_hash_set<std::string>{"XlaBroadcastHelper",
"XlaConv",


@@ -282,6 +282,41 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
// set.
bool CanCreateXlaKernel(const NodeDef& node_def);
// Returns memory types for the input.
// `constant_arg_indices` and `resource_arg_indices` are sorted arrays of
// indices corresponding to constant and resource arguments respectively.
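// Constant and resource arguments are mapped to HOST_MEMORY; all other
// arguments to DEVICE_MEMORY. For example, for four arguments where index 1
// is a compile-time constant and index 3 is a resource, the result is
// {DEVICE_MEMORY, HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY}.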
//
// One might wonder about the case where a compile-time constant argument
// (which must be in host memory) is also used as an input into an op,
// e.g. `Add`, that expects its inputs in device memory. Here is how it
// works now.
// First, what do we mean by "op expects an input in XYZ memory"?
// There are two types of "ops" here: the tf2xla kernel and the HLO
// computation it builds. The tf2xla kernel needs to retrieve the actual
// numeric value of the compile-time constant tensors, so it really expects
// them to be in host memory. However, for other inputs, it refers to them
// using xla::ComputationDataHandle, which is just a symbolic handle that
// xla::ComputationBuilder assigns. How does this handle get assigned for
// constant arguments? Even constant arguments get an _Arg node in the graph
// instantiated for Function compilation. The tf2xla kernel for constant _Arg
// nodes takes the constant value, converts it to XlaLiteral, and feeds it
// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
// constant XlaLiteral is included in the HLO graph, and subsequently, in
// the actual executable, which is copied to the device before being
// executed. Thus, when this executable runs, the constant is available in
// device memory.
tensorflow::MemoryTypeVector GetInputMemoryTypes(
const tensorflow::FunctionBody* fbody,
absl::Span<int const> constant_arg_indices,
absl::Span<int const> resource_arg_indices);
// Returns output memory types.
//
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
// in device memory except for resources.
tensorflow::MemoryTypeVector GetOutputMemoryTypes(
const tensorflow::FunctionBody* fbody);
// Check whether graph can trigger XLA compilation.
bool CanTriggerXlaCompilation(const GraphDef& graph);


@@ -0,0 +1,114 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/get_compiler_ir.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
xla::StatusOr<std::string> GetCompilerIr(
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
absl::string_view func_name, Device* dev,
absl::Span<const Tensor* const> inputs) {
NameAttrList function;
function.set_name(std::string{func_name});
FunctionLibraryRuntime* flr = pflr->GetFLR(dev->name());
ResourceMgr* rmgr = dev->resource_manager();
const FunctionBody* fbody = nullptr;
std::vector<int> constant_arg_indices;
std::vector<int> resource_arg_indices;
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
MemoryTypeVector input_memory_types =
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
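// Resolve and lock any resource-variable inputs so that their current shapes
// and dtypes can be used when building the XLA compiler arguments below.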
std::vector<VariableInfo> variable_infos;
TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
rmgr, dev, inputs, resource_arg_indices, &variable_infos));
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(dev);
XlaCompilationCache* cache;
TF_RETURN_IF_ERROR(rmgr->LookupOrCreate<XlaCompilationCache>(
rmgr->default_container(), "xla_cache", &cache,
[&](XlaCompilationCache** cache_write_into) {
return BuildXlaCompilationCache(dev, platform_info, cache_write_into);
}));
core::ScopedUnref cache_ref(cache);
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
XlaCompiler::Options options =
GenerateCompilerOptions(*cache, *flr, dev,
/*stream=*/nullptr, platform_info,
/*has_ref_vars=*/false, &tf_allocator_adapter);
XlaCompiler::CompileOptions compile_options;
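// Emit a non-tuple result for single-output computations and allow resource
// updates to alias their corresponding input buffers.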
compile_options.always_return_tuple = false;
compile_options.alias_resource_update = true;
XlaCompiler compiler(options);
xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_arg_indices, inputs, variable_infos);
TF_RETURN_IF_ERROR(args.status());
switch (stage) {
case IrExportStage::HLO: {
XlaCompiler::CompilationResult result;
TF_RETURN_IF_ERROR(
compiler.CompileFunction(compile_options, function, *args, &result));
TF_ASSIGN_OR_RETURN(xla::ProgramShape program_shape,
result.computation->GetProgramShape());
xla::HloModuleConfig config(program_shape);
TF_ASSIGN_OR_RETURN(
std::unique_ptr<xla::HloModule> new_module,
xla::HloModule::CreateFromProto(result.computation->proto(), config));
return new_module->ToString();
}
case IrExportStage::OPTIMIZED_HLO: {
const XlaCompiler::CompilationResult* compilation_result = nullptr;
xla::LocalExecutable* executable = nullptr;
TF_RETURN_IF_ERROR(
cache->Compile(options, function, *args, compile_options,
XlaCompilationCache::CompileMode::kStrict,
&compilation_result, &executable));
return executable->executable()->module().ToString();
}
}
}
} // namespace tensorflow


@@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_GET_COMPILER_IR_H_
#define TENSORFLOW_COMPILER_JIT_GET_COMPILER_IR_H_
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace tensorflow {
class ProcessFunctionLibraryRuntime;
class Device;
class Tensor;
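// Stage at which the IR is exported: HLO as produced by the tf2xla bridge, or
// OPTIMIZED_HLO after XLA's optimization pipeline has run.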
enum class IrExportStage { HLO, OPTIMIZED_HLO };
// Returns HLO text for a given function `func_name` using the process function
// library runtime `pflr` on a device `dev` with given `inputs`.
xla::StatusOr<std::string> GetCompilerIr(
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
absl::string_view func_name, Device* dev,
absl::Span<const Tensor* const> inputs);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_GET_COMPILER_IR_H_


@@ -30,47 +30,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/ptr_util.h"
namespace {
// Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is
// not revisited in future calls to ScanForValue, so callers must take
// care to order their calls.
//
// Useful for merging multiple sorted lists in O(n) time.
class SinglePassSearch {
public:
// Creates a SinglePassSearch object that can be used to search in `values`.
// Does not take ownership of `values`. `values` must outlive this.
// `values` must be sorted.
explicit SinglePassSearch(const std::vector<int>* values)
: current_index_(0), values_(values) {}
// Scans forward in the vector looking for "value", updating the internal
// position in to the vector.
// Returns true iff the vector contains the given value at or after current
// position.
// Not thread-safe.
bool ScanForValue(int value) {
while (current_index_ < values_->size() &&
(*values_)[current_index_] <= value) {
if ((*values_)[current_index_] == value) {
current_index_++;
return true;
}
current_index_++;
}
return false;
}
private:
int current_index_;
const std::vector<int>* values_;
};
} // end namespace
namespace tensorflow {
bool XlaKernelCreator::CanCreateKernel(
@@ -130,52 +89,9 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
// Set input and output memory types.
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
// These indices are used only for optimization purposes. They allow us
// to loop over constant_arg_indices and resource_arg_indices only once
// while iterating over all the function arguments checking if it is a
// resource or a constant.
// The reason we optimized this code is because functions can have a lot of
// captured arguments. For example, the backward pass of ResNet50 takes in all
// 214 variables and a similar number of activations.
SinglePassSearch constants_search(&constant_arg_indices);
SinglePassSearch resources_search(&resource_arg_indices);
for (size_t i = 0; i < fbody->arg_types.size(); ++i) {
if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
// Compile-time constants and resource handles are expected to be in
// host memory.
input_memory_types[i] = HOST_MEMORY;
}
}
// One might wonder, about the case where a compile-time constant argument
// (which must be in host memory) is also used as an input into an op,
// e.g. Add, that expects its inputs in device memory. Here is how it
// works now.
// First, what do we mean by "op expects an input in XYZ memory"?
// There are two types of "ops" here: the tf2xla kernel and the HLO
// computation it builds. The tf2xla kernel needs to retrieve the actual
// numeric value of the compile-time constant tensors, so it really expects
// them to be on in host memory. However, for other inputs, it refers to them
// using xla::ComputationDataHandle, which is just a symbolic handle that
// xla::ComputationBuilder assigns. How does this handle gets assigned for
// constant arguments? Even constant arguments get an _Arg node in the graph
// instantiated for Function compilation. The tf2xla kernel for constant _Arg
// nodes takes the constant value, converts it to XlaLiteral, and feeds it
// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
// constant XlaLiteral is included in the HLO graph, and subsequently, in
// the actual executable, which is copied to the device before being
// executed. Thus, when this executable runs, the constant is available in
// device memory.
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
// in device memory except for resources.
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
for (size_t i = 0; i < fbody->ret_types.size(); ++i) {
if (fbody->ret_types[i] == DT_RESOURCE) {
output_memory_types[i] = HOST_MEMORY;
}
}
MemoryTypeVector input_memory_types =
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
// Create the kernel.
Device* dev = flr->device();


@@ -6047,6 +6047,7 @@ filegroup(
"//tensorflow/c:checkpoint_reader", # checkpoint_reader
"//tensorflow/c:python_api", # tf_session
"//tensorflow/c:tf_status_helper", # tfe
"//tensorflow/compiler/jit:get_compiler_ir", #tfe
"//tensorflow/compiler/jit:flags", #tfe
"//tensorflow/compiler/mlir/python:mlir", # mlir
"//tensorflow/core/common_runtime:device", # device_lib, tfe, tf_session
@@ -7557,6 +7558,8 @@ tf_python_pybind_extension(
"//third_party/python_runtime:headers",
"//tensorflow/c/experimental/saved_model/core:pywrap_required_hdrs",
"//tensorflow/compiler/jit:flags_headers_only",
"//tensorflow/compiler/jit:get_compiler_ir_hdrs_only",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
"//tensorflow/core/common_runtime:core_cpu_headers_lib",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:lib_headers_for_pybind",


@@ -1494,6 +1494,10 @@ class Context(object):
self._virtual_device_map[dev] = virtual_devices
def get_compiler_ir(self, function_name, args, stage="hlo"):
return pywrap_tfe.TF_GetCompilerIr(self._context_handle, function_name,
stage, self.device_name, args)
@deprecated(
None, "XLA:CPU and XLA:GPU devices are deprecated", warn_once=True)
def enable_xla_devices(self):


@@ -914,6 +914,76 @@ class Function(object):
return function_lib.defun(fn_with_cond)(canon_args, canon_kwds,
filtered_flat_args)
def experimental_get_compiler_ir(self, *args, **kwargs):
"""Returns compiler IR for the compiled function.
This API is intended *only* for debugging as there are no guarantees on
backwards compatibility of returned IR or the allowed values of `stage`.
Args:
*args: Arguments used for compilation; same arguments as used for calling
the function. Need to be eager tensors.
**kwargs: Keyword arguments used for compilation.
Returns:
A callable taking the `stage` at which the compiler IR should be
serialized; allowed values for `stage` are `hlo` and `optimized_hlo`.
When called, the returned function returns the string representation of
the compiler IR at the given stage.
For example, for
```python
@tf.function(experimental_compile=True)
def f(x):
return x + 1
f.experimental_get_compiler_ir(tf.random.normal([10, 10]))(stage='hlo')
```
the output is:
```
HloModule a_inference_f_13__.9
ENTRY %a_inference_f_13__.9 (arg0.1: f32[10,10]) -> f32[10,10] {
%arg0.1 = f32[10,10]{1,0} parameter(0), parameter_replication={false}
%reshape.2 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %arg0.1)
%constant.3 = f32[] constant(1)
%broadcast.4 = f32[10,10]{1,0} broadcast(f32[] %constant.3)
%add.5 = f32[10,10]{1,0} add(f32[10,10]{1,0} %reshape.2,
f32[10,10]{1,0} %broadcast.4)
%reshape.6 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %add.5)
%tuple.7 = (f32[10,10]{1,0}) tuple(f32[10,10]{1,0} %reshape.6)
ROOT %get-tuple-element.8 = f32[10,10]{1,0}
get-tuple-element((f32[10,10]{1,0}) %tuple.7), index=0
}
```
Raises:
ValueError: If an invalid `stage` is selected or if applied to a function
which is not compiled (`experimental_compile=True` is not set).
TypeError: When called with input in graph mode.
"""
context.ensure_initialized()
if not self._experimental_compile:
raise ValueError(
"Compiler IR can only be returned for functions marked with "
"experimental_compile=True")
concrete_fn = self.get_concrete_function(*args, **kwargs)
fn_name = concrete_fn.name
# pylint: disable=protected-access
canon_args, _, _, _ = \
concrete_fn._function_spec.canonicalize_function_inputs(
*args, **kwargs)
return functools.partial(
context.context().get_compiler_ir,
function_name=fn_name,
args=list(canon_args) + concrete_fn.captured_inputs)
@property
def python_function(self):
"""The python function wrapped in this tf.function."""


@@ -526,6 +526,82 @@ class DefFunctionTest(xla_test.XLATestCase):
b.backing_device) if on_gpu else 0
self.assertEqual(initial_usage, final_usage)
def testGetCompilerIrConstants(self):
if 'tpu' in self.device.lower():
self.skipTest('TPU generates different HLO')
with ops.device('device:{}:0'.format(self.device)):
@def_function.function(experimental_compile=True)
def f(a, b):
return array_ops.transpose(a, b)
a = array_ops.ones([3, 4, 3], dtype=dtypes.float32)
b = constant_op.constant([0, 2, 1], dtype=dtypes.int32)
self.assertIn('{1,2,0}',
f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo'))
@test_util.disable_mlir_bridge('TODO(b/168732524): MLIR bridge does not '
'optimize single-element tuples to scalars')
def testGetCompilerIrResourceVars(self):
with ops.device('device:{}:0'.format(self.device)):
v = variables.Variable([3.1, 3.2])
@def_function.function(experimental_compile=True)
def f(a, b):
v.assign_add(a * b)
a = random_ops.random_normal([2])
b = random_ops.random_normal([2])
self.assertIn('input_output_alias={ {}: (2, {}, may-alias) }',
f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo'))
def testGetCompilerIrNotCompiled(self):
with ops.device('device:{}:0'.format(self.device)):
@def_function.function
def f(x):
return x + 1
a = random_ops.random_normal([10, 10])
with self.assertRaisesRegex(ValueError,
'marked with experimental_compile'):
f.experimental_get_compiler_ir(a)()
def testGetCompilerIrNested(self):
with ops.device('device:{}:0'.format(self.device)):
@def_function.function(experimental_compile=True)
def fn(x, a):
return x + a
@def_function.function(experimental_compile=False)
def fn2(x, a):
fn.experimental_get_compiler_ir(x, a)()
return fn(x, a)
inputs = constant_op.constant([1, 2, 2, 3, 3])
with self.assertRaisesRegex(TypeError, '"Graph" tensor'):
fn2(inputs, 1)
def testGetCompilerIrKwargs(self):
with ops.device('device:{}:0'.format(self.device)):
v = variables.Variable([0.1, 0.1])
@def_function.function(experimental_compile=True)
def f(a, b):
return (a + b) * v
a = constant_op.constant([1.1, 1.1])
b = constant_op.constant([2.2, 2.2])
self.assertIn('multiply',
f.experimental_get_compiler_ir(b=a, a=b)(stage='hlo'))
if __name__ == '__main__':
ops.enable_eager_execution()
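
Not exercised by the tests above, but implemented in the pybind wrapper further below: an unknown `stage` string is rejected. A hypothetical illustration of how that surfaces at the Python level (not part of this change):

```
import tensorflow as tf

@tf.function(experimental_compile=True)
def f(x):
  return x + 1

# Raises ValueError: Invalid stage selected: 'foo'. Valid values are:
# 'hlo', 'optimized_hlo'
f.experimental_get_compiler_ir(tf.constant([1.0]))(stage='foo')
```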


@@ -28,9 +28,11 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/dlpack.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/get_compiler_ir.h"
#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/py_exception_registry.h"
@@ -285,6 +287,74 @@ static py::object TFE_ClearScalarCache() {
return py::none();
}
// Returns compiler IR for a given function.
static std::string TFE_GetCompilerIr(py::handle& ctx,
const char* concrete_function_name,
const char* stage, const char* device_name,
py::handle& inputs) {
EagerContext* context = ContextFromInterface(
reinterpret_cast<ImmediateExecutionContext*>(InputTFE_Context(ctx)));
std::string s_stage(stage);
IrExportStage selected_stage = [&] {
if (s_stage == "hlo") {
return IrExportStage::HLO;
} else if (s_stage == "optimized_hlo") {
return IrExportStage::OPTIMIZED_HLO;
} else {
ThrowValueError(
absl::StrFormat("Invalid stage selected: '%s'. Valid values are: "
"'hlo', 'optimized_hlo'",
s_stage)
.c_str());
}
}();
TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(inputs);
std::vector<const Tensor*> input_tensors;
for (TFE_TensorHandle* tensor_handle : handles) {
AbstractTensorHandle* abstract_tensor_handle = unwrap(tensor_handle);
TensorHandle* th = TensorHandleFromInterface(abstract_tensor_handle);
const Tensor* t;
Status st = th->Tensor(&t);
if (!st.ok()) {
ThrowValueError(
absl::StrFormat("Could not resolve tensor: '%s'", st.error_message())
.c_str());
}
input_tensors.push_back(t);
}
DeviceNameUtils::ParsedName input_device_name;
if (!DeviceNameUtils::ParseFullOrLocalName(device_name, &input_device_name)) {
ThrowValueError(
absl::StrFormat("Failed parsing device name: '%s'", device_name)
.c_str());
}
std::vector<Device*> devices = context->local_device_mgr()->ListDevices();
auto selected_device = absl::c_find_if(devices, [&](const Device* d) {
return DeviceNameUtils::AreCompatibleDevNames(input_device_name,
d->parsed_name());
});
if (selected_device == devices.end()) {
ThrowValueError("No matching device found");
}
xla::StatusOr<std::string> hlo_text =
GetCompilerIr(selected_stage, context->pflr(), concrete_function_name,
*selected_device, input_tensors);
if (!hlo_text.ok()) {
ThrowValueError(absl::StrFormat("Failed getting HLO text: '%s'",
hlo_text.status().error_message())
.c_str());
}
return *hlo_text;
}
} // namespace tensorflow
namespace {
@@ -513,6 +583,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
m.def("TF_SetXlaConstantFoldingDisabled", &TF_SetXlaConstantFoldingDisabled);
m.def("TF_GetXlaConstantFoldingDisabled", &TF_GetXlaConstantFoldingDisabled);
m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
m.def("TF_GetCompilerIr", &tensorflow::TFE_GetCompilerIr);
// MLIR Logic
m.def("TF_IsMlirBridgeEnabled", [] {