STT-tensorflow/tensorflow/compiler/tf2xla/const_analysis.h
George Karpenkov f3dcd9dc11 Support interprocedural constant meta-information propagation for compilation
This CL does two things:

1) Supports inter-procedural constant information propagation, across
PartitionedCall and StatefulPartitionedCall.

2) Done naively, (1) leads to an exponential number of calls, as each function
would be re-inlined for each (indirect) caller.
To address this performance issue, we cache the indices of the arguments that
must be constant and attach that information to the Graph object.

This might require some clarification:

a) Caching in a map passed in by the caller would not work, as duplicating the
constant propagation for each top-level caller is still prohibitively expensive.

b) Caching in a global object would not work, as graphs are created and
destroyed during transformations.

c) Caching this meta-information on a `Graph` object has the added benefit that
we no longer repeat the same constant propagation many times (many
compilation passes call BackwardsConstAnalysis, and previously all of this work
was repeated).
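A minimal sketch of why a per-function cache avoids the exponential blow-up, using a hypothetical toy model rather than TensorFlow's `Graph` or `FunctionLibraryRuntime`. The names `Function`, `CallSite`, `Library`, `ConstArgs`, `BuildChain` and `analysis_runs` are illustrative assumptions, and the call graph is assumed to be acyclic. Each function stores its result in `const_arg_cache`, loosely mirroring the idea of attaching const-argument indices to the object being analyzed. In a chain where every function calls the next one twice, the analysis body runs once per function instead of once per call path.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical toy model; not TensorFlow's Graph or FunctionLibraryRuntime.
// A call site forwards caller argument `caller_arg` to callee parameter
// `callee_arg`.
struct CallSite {
  std::string callee;
  std::size_t caller_arg;
  std::size_t callee_arg;
};

struct Function {
  // One entry per argument: true if the body itself needs it to be constant.
  std::vector<bool> intrinsically_const;
  std::vector<CallSite> calls;
  // Result cached on the function object itself.
  mutable std::optional<std::vector<bool>> const_arg_cache;
};

using Library = std::map<std::string, Function>;

int analysis_runs = 0;  // counts how often the uncached body executes

// Returns which arguments of `name` must be compile-time constants,
// assuming the call graph is acyclic.
std::vector<bool> ConstArgs(const Library& lib, const std::string& name) {
  const Function& f = lib.at(name);
  if (f.const_arg_cache) return *f.const_arg_cache;  // reuse earlier result
  ++analysis_runs;
  std::vector<bool> result = f.intrinsically_const;
  for (const CallSite& call : f.calls) {
    assert(call.caller_arg < result.size());
    // Forwarding a value into a parameter that must be constant makes the
    // forwarded caller argument constant too.
    if (ConstArgs(lib, call.callee)[call.callee_arg]) {
      result[call.caller_arg] = true;
    }
  }
  f.const_arg_cache = result;
  return result;
}

// Builds f0 -> f1 -> ... -> f<depth>, where each function calls the next one
// twice (once per argument) and only the leaf's argument 0 is intrinsically
// constant.
Library BuildChain(int depth) {
  Library lib;
  for (int i = 0; i <= depth; ++i) {
    Function f;
    f.intrinsically_const = {i == depth, false};
    if (i < depth) {
      const std::string next = "f" + std::to_string(i + 1);
      f.calls.push_back({next, 0, 0});
      f.calls.push_back({next, 1, 0});
    }
    lib["f" + std::to_string(i)] = f;
  }
  return lib;
}
```

Without the `const_arg_cache` check, analyzing `BuildChain(20)` from `f0` would run the body 2^21 - 1 times (function `fi` is reached along 2^i paths); with the check, each of the 21 functions is analyzed exactly once.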

PiperOrigin-RevId: 303860413
Change-Id: I78f92ca1487fc952044e5ac6526dcaa5b50d5f21
2020-03-30 17:51:05 -07:00


/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

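#ifndef TENSORFLOW_COMPILER_TF2XLA_CONST_ANALYSIS_H_
#define TENSORFLOW_COMPILER_TF2XLA_CONST_ANALYSIS_H_

#include <functional>
#include <vector>

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
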
// Backwards dataflow analysis that finds nodes in a graph that must be
// compile-time constants for us to be able to lower the graph to XLA.
//
// If `compile_time_const_arg_indices` is not null, then
// `(*compile_time_const_arg_indices)[i]` is set to true if argument `i` of `g`
// must be a compile-time constant.
//
// If `compile_time_const_nodes` is not null, then
// `(*compile_time_const_nodes)[id]` is set to true if the node of `g` with that
// id must be a compile-time constant.
//
// If `edge_filter_input` is non-null, only propagate const-ness along edges for
// which `edge_filter_input` returns true.
Status BackwardsConstAnalysis(
    const Graph& g, std::vector<bool>* compile_time_const_arg_indices,
    std::vector<bool>* compile_time_const_nodes,
    FunctionLibraryRuntime* flib_runtime,
    std::function<bool(const Edge&)> edge_filter_input = nullptr);

// Given an op kernel and a function library runtime, returns the indices of all
// inputs that need to be compile-time constants.
Status GetCompileTimeConstInputs(const OpKernel* op_kernel,
                                 std::vector<int>* const_input_idxs,
                                 FunctionLibraryRuntime* flib_runtime);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_CONST_ANALYSIS_H_
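The backwards dataflow described for `BackwardsConstAnalysis` can be illustrated with a small sketch on a hypothetical toy graph; `ToyGraph`, `ToyNode`, `ToyEdge` and `ToyBackwardsConstAnalysis` are not TensorFlow types. It assumes nodes are stored in topological order with `nodes[i].id == i`, requires non-null output vectors, and uses a simplified rule: a producer must be constant if the consuming op requires that input slot to be constant, or if the consumer itself must be constant. The real analysis likely handles more cases than this sketch does.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical toy types; not TensorFlow's Graph/Node/Edge.
struct ToyEdge {
  int src;        // id of the producing node
  int dst;        // id of the consuming node
  int dst_input;  // input slot on the consuming node
};

struct ToyNode {
  int id;
  int arg_index = -1;                  // >= 0 if this node is a function argument
  std::vector<int> const_input_slots;  // input slots the op requires to be constant
};

struct ToyGraph {
  std::vector<ToyNode> nodes;  // assumed topologically sorted, nodes[i].id == i
  std::vector<ToyEdge> edges;
  int num_args = 0;
};

// Marks every node (and argument) whose value must be known at compile time.
void ToyBackwardsConstAnalysis(
    const ToyGraph& g, std::vector<bool>* const_args,
    std::vector<bool>* const_nodes,
    std::function<bool(const ToyEdge&)> edge_filter = nullptr) {
  const_nodes->assign(g.nodes.size(), false);
  const_args->assign(g.num_args, false);
  // Reverse topological order: all consumers of a node are visited first,
  // so its const-ness is final by the time we reach it.
  for (auto it = g.nodes.rbegin(); it != g.nodes.rend(); ++it) {
    const ToyNode& n = *it;
    const bool node_const = (*const_nodes)[n.id];
    if (node_const && n.arg_index >= 0) (*const_args)[n.arg_index] = true;
    for (const ToyEdge& e : g.edges) {
      if (e.dst != n.id) continue;
      if (edge_filter && !edge_filter(e)) continue;  // skip filtered-out edges
      const bool slot_const =
          std::find(n.const_input_slots.begin(), n.const_input_slots.end(),
                    e.dst_input) != n.const_input_slots.end();
      if (node_const || slot_const) (*const_nodes)[e.src] = true;
    }
  }
}
```

Visiting nodes in reverse topological order means every consumer of a node has been processed before the node itself, so a single pass is enough for an acyclic graph. The optional `edge_filter` plays the same role as `edge_filter_input`: const-ness simply stops flowing across edges it rejects.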