STT-tensorflow/tensorflow/compiler/tf2xla/graph_compiler_util.h
Yunxing Dai 70d8aa322c Automatically set up user aliasing in tf2xla when a resource update is presented.
- When a resource update is presented, automatically alias the input and output.
- Also fix an issue where the input/output proto config is overwritten.

PiperOrigin-RevId: 294984983
Change-Id: I45e96513dfeaa91f523db63837355b698bd2fb85
2020-02-13 13:22:20 -08:00

52 lines
2.2 KiB
C++

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_UTIL_H_
#include <memory>
#include <unordered_map>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/graph.pb.h"
namespace tensorflow {

// Fills in `*xla_args` with one XlaCompiler::Argument per _Arg node found in
// `graph`.
//
// The _Arg nodes are presumably the ones introduced by InitGraph (declared
// below), which rewrites each feed into an _Arg node.
//
// Returns a non-OK Status if the arguments cannot be derived from the graph.
//
// NOTE(review): whether `*xla_args` is appended to or overwritten, and which
// _Arg node attributes are read, is decided in the implementation file and is
// not visible here. Confirm before relying on either.
Status CreateXlaArgs(const Graph& graph,
                     std::vector<XlaCompiler::Argument>* xla_args);

// Populates `*xla_args` using the information in the tf2xla `config`.
//
// Unlike CreateXlaArgs, this function returns void, so it has no way to report
// a failure to the caller. Callers should make sure `config` is valid
// beforehand.
//
// NOTE(review): which parts of `config` are consumed, and how the result
// interacts with arguments already produced by CreateXlaArgs, is not visible
// from this header. Confirm in the implementation file.
void PopulateXlaArgs(const tf2xla::Config& config,
                     std::vector<XlaCompiler::Argument>* xla_args);

// InitGraph creates a graph based on `graph_def`. The result may then be
// converted to an xla::XlaComputation via ConvertGraphToXla.
//
// The graph is rewritten with _Arg and _Retval nodes, which represent the
// inputs and outputs of the function that will be compiled:
//  - Each feed id in `config` causes a new _Arg node to be created. All
//    existing edges leaving the named node's output index are collected first,
//    then rewritten to leave that _Arg node instead.
//  - Each fetch id in `config` causes a new _Retval node to be created, with a
//    new edge from the named node's output index to that _Retval node.
//
// On success, `*graph` owns the newly created graph. On failure, a non-OK
// Status is returned.
//
// NOTE(review): the state of `*graph` after a failure is not specified here.
// Do not rely on it.
Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
                 std::unique_ptr<Graph>* graph);

}  // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_UTIL_H_