STT-tensorflow/tensorflow/python/framework/python_api_info.h
Edward Loper 30b69242f8 Class to hold precomputed information about a TensorFlow Python API, to allow the API to be executed rapidly.
PiperOrigin-RevId: 341499017
Change-Id: I875ea89efcd86a7fe9e2f8fcefab1cbd3aa2c0e9
2020-11-09 15:53:35 -08:00

299 lines
12 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_PYTHON_UTIL_PYTHON_API_INFO_H_
#define TENSORFLOW_PYTHON_UTIL_PYTHON_API_INFO_H_
#include <Python.h>
#include <map>
#include <string>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/python/framework/op_def_util.h"
#include "tensorflow/python/framework/python_tensor_converter.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
namespace tensorflow {
// Precomputed information about a TensorFlow Python API.
//
// PythonAPIInfo records information about a single TensorFlow Python API,
// in order to allow calls to the API to be executed more efficiently.  This
// information includes:
//
// * The name of the API.  (E.g. "tf.math.add")
//
// * The name of the registered op that implements the API, if applicable
//   (e.g. "AddV2").
//
// * Information about the API's parameters.  Parameters are divided into two
//   "kinds": inputs and attributes.  An *input* is a parameter that
//   expects a Tensor or list of Tensors, and it is described by an `ArgDef`.
//   An *attribute* is a parameter that expects any other value type, and it is
//   described by an `AttrDef`.
//
// * Default values for the API's attribute parameters.
//
// * Information about "inferred attributes" -- attributes whose values are
//   inferred from `input` parameters.  There are two kinds of inferred
//   attributes: Tensor dtypes, which are inferred from tensor and list(tensor)
//   parameters; and list lengths, which are inferred from list(tensor)
//   parameters.
class PythonAPIInfo {
 public:
  // The index of a parameter in the canonicalized parameter list.  The
  // canonicalized parameter list includes inputs and attributes (but does
  // not include inferred attributes).  `-1` is used for inferred attributes.
  using ParamIndex = int;

  // Information about a parameter that expects a non-Tensor value.
  struct Attribute {
    ParamIndex index;  // -1 if this is an inferred attribute
    AttributeType type;
    const char* name;    // Interned python string
    int inferred_index;  // index to store attribute in InferredAttributes
  };

  // Information about a parameter that expects a Tensor or list(Tensor).
  // Additional information about tensor parameters is stored in types
  // defined below, in order to simplify dtype/length inference:
  //   * InputWithFixedDType: inputs with fixed dtypes.
  //   * InputsWithTypeAttr: groups inputs that use a type_attr for dtype.
  //   * InputsWithTypeListAttr: groups inputs that use a type_list_attr.
  //   * InputsWithNumberAttr: groups inputs by a number_attr for length.
  struct Input {
    ParamIndex index;
    bool is_list;
  };

  // Information about a Tensor parameter w/ fixed dtype.
  struct InputWithFixedDType {
    DataType dtype;
    ParamIndex index;
    bool is_list;
  };

  // Information about Tensor parameters whose DType is specified by a single
  // `type_attr` attribute.
  struct InputsWithTypeAttr {
    Attribute* type_attr;                        // not owned.
    DataType default_dtype;                      // DT_INVALID if no default.
    std::vector<ParamIndex> tensor_params;       // single-tensor inputs.
    std::vector<ParamIndex> tensor_list_params;  // list(tensor) inputs.
    // NOTE(review): presumably the dtypes accepted for `type_attr`; confirm
    // against the code that populates this field.
    std::vector<DataType> ok_dtypes;
  };

  // Information about Tensor parameters whose DType is specified by a single
  // `type_list_attr` attribute.
  struct InputsWithTypeListAttr {
    Attribute* type_list_attr;                   // not owned.
    std::vector<DataType> default_dtypes;        // empty if no default.
    std::vector<ParamIndex> tensor_list_params;  // list(tensor) inputs.
    // NOTE(review): presumably the dtypes accepted for `type_list_attr`;
    // confirm against the code that populates this field.
    std::vector<DataType> ok_dtypes;
  };

  // Information about Tensor-list parameters whose length is specified by a
  // single `int` attribute.
  struct InputsWithNumberAttr {
    Attribute* number_attr;                      // not owned.
    int64 default_length;                        // -1 for no default.
    std::vector<ParamIndex> tensor_list_params;  // list(tensor) inputs.
  };

  // Structure used to return inferred attribute values.
  //   * types[i] is the inferred value for inferred_type_attrs()[i]
  //   * type_lists[i] is the inferred value for inferred_type_list_attrs()[i]
  //   * lengths[i] is the inferred value for inferred_length_attrs()[i]
  struct InferredAttributes {
    std::vector<DataType> types;
    std::vector<std::vector<DataType>> type_lists;
    std::vector<int64> lengths;
  };

  // Constructs a new PythonAPIInfo.
  //
  // Note: One of the `Initialize()` functions must be called before the
  // `PythonAPIInfo` is used.
  //
  // Args:
  //   api_name: The fully-qualified name of the python API (e.g., tf.math.sum).
  explicit PythonAPIInfo(const std::string& api_name);

  // Initializes this PythonAPIInfo.
  //
  // Args:
  //   op_def: Contains information about the parameters.
  //   param_names: The argument names for the python API, in canonical order.
  //     Taken by value, so the vector is copied on each call.
  //     NOTE(review): switching to a const reference would avoid the copy but
  //     must be done together with the out-of-line definition.
  //   defaults_tuple: Tuple containing default values for the parameters,
  //     right-aligned with `param_names` -- i.e., `defaults[-i]` is the default
  //     for `param_names[-i]`.
  Status Initialize(const OpDef& op_def, std::vector<std::string> param_names,
                    PyObject* defaults_tuple);

  // Initialize this PythonAPIInfo based on the registered OpDef for the given
  // operation.
  //
  // Args:
  //   op_name: The registered name of the operation (e.g. "AddV2").
  Status InitializeFromRegisteredOp(const std::string& op_name);

  // Initializes this PythonAPIInfo based on a set of parameter specifications.
  //
  // Args:
  //   input_specs: Mapping from parameter name to specification string for
  //     each input (parameter that expects a tensor value).
  //   attr_specs: Mapping from parameter name to specification string for
  //     each attribute (parameter that expects a non-tensor value).
  //   param_names: The argument names for the python API, in canonical order.
  //     Taken by value, so the vector is copied on each call.
  //   defaults_tuple: Tuple containing default values for the parameters,
  //     right-aligned with `param_names` -- i.e., `defaults[-i]` is the default
  //     for `param_names[-i]`.
  //
  // Note: the `name` parameter should not be included in `input_specs` or
  // `attr_specs`.
  Status InitializeFromParamSpecs(
      const std::map<std::string, std::string>& input_specs,
      const std::map<std::string, std::string>& attr_specs,
      std::vector<std::string> param_names, PyObject* defaults_tuple);

  // The name of the API that is described by this PythonAPIInfo.
  const char* api_name() const { return api_name_; }

  // The ordered names of the canonical parameters that this API expects.
  const std::vector<const char*>& param_names() const { return param_names_; }

  // A Python tuple containing the default values for parameters.  This is
  // right-aligned with `param_names` -- i.e., `defaults[-i]` is the default
  // for `param_names[-i]`.
  //
  // The pointer comes straight from `defaults_tuple_.get()`; this accessor
  // does not touch the reference count, so callers should treat the result
  // as borrowed from this object.
  const PyObject* defaults_tuple() const { return defaults_tuple_.get(); }

  // Information about the attribute (non-tensor) parameters for this API.
  const std::vector<Attribute>& attributes() const { return attributes_; }

  // Information about the input (tensor) parameters for this API.
  const std::vector<Input>& inputs() const { return inputs_; }
  const std::vector<InputWithFixedDType>& inputs_with_fixed_dtype() const {
    return inputs_with_fixed_dtype_;
  }
  const std::vector<InputsWithTypeAttr>& inputs_with_type_attrs() const {
    return inputs_with_type_attrs_;
  }
  const std::vector<InputsWithTypeListAttr>& inputs_with_type_list_attrs()
      const {
    return inputs_with_type_list_attrs_;
  }
  const std::vector<InputsWithNumberAttr>& inputs_with_number_attrs() const {
    return inputs_with_number_attrs_;
  }

  // Names of inferred attributes.
  const std::vector<const char*>& inferred_type_attrs() const {
    return inferred_type_attrs_;
  }
  const std::vector<const char*>& inferred_type_list_attrs() const {
    return inferred_type_list_attrs_;
  }
  const std::vector<const char*>& inferred_length_attrs() const {
    return inferred_length_attrs_;
  }

  // Returns a string summarizing the internal state of this PythonAPIInfo.
  std::string DebugInfo() const;

 private:
  // Adds an entry to the attributes_ vector based on the given `AttrDef`.
  //
  // If `attr_def` describes a type attribute, then adds a value to
  // inputs_with_type_attrs_ or inputs_with_type_list_attrs_ (to record any
  // tensor inputs that use this dtype).
  //
  // If `attr_def` describes an int attribute, then adds a value to
  // inputs_with_number_attrs_ (to record any tensor inputs that use this
  // value as a list length).
  Status InitializeAttribute(
      const OpDef::AttrDef& attr_def,
      const std::map<std::string, ParamIndex>& param_name_to_index);

  // Adds an entry to the inputs_ vector based on the given `ArgDef`.
  //
  // If `arg_def` has a fixed dtype, then adds a value to
  // `inputs_with_fixed_dtype_`.
  //
  // If `arg_def`'s dtype is described by a `type` attr, then updates the
  // appropriate value in `inputs_with_type_attrs_` with information about the
  // `arg_def`.
  //
  // If `arg_def`'s dtype is described by a `list(type)` attr, then updates the
  // appropriate value in `inputs_with_type_list_attrs_` with information about
  // the `arg_def`.
  Status InitializeInput(
      const OpDef::ArgDef& arg_def,
      const std::map<std::string, ParamIndex>& param_name_to_index);

  // Checks that the OpDef used to initialize this PythonAPIInfo
  // had an AttrDef or ArgDef specification for each parameter.
  Status CheckParamNames() const;

  // Searches inputs_with_type_attrs_ for an input with the given name.
  InputsWithTypeAttr* FindInputsWithTypeAttr(const std::string& name);

  // Searches inputs_with_type_list_attrs_ for an input with the given name.
  InputsWithTypeListAttr* FindInputsWithTypeListAttr(const std::string& name);

  // Searches inputs_with_number_attrs_ for an input with the given name.
  InputsWithNumberAttr* FindInputsWithNumberAttr(const std::string& name);

  // NOTE(review): undocumented in the original.  Judging by the signature,
  // this presumably infers the list-length attributes from `params` and
  // writes them to `inferred_length_attrs`; the meaning of the bool result
  // (which callers must not ignore, per ABSL_MUST_USE_RESULT) should be
  // confirmed against the definition.
  ABSL_MUST_USE_RESULT
  bool InferLengthAttributes(const absl::Span<PyObject*> params,
                             std::vector<int64>& inferred_length_attrs) const;

  // ==========================================================================
  // Member Variables
  // ==========================================================================

  // The name of the API that is described by this PythonAPIInfo.
  // (Interned python string).
  const char* api_name_;

  // The names of the parameters that this API expects.
  // (Interned python strings.)
  std::vector<const char*> param_names_;

  // Tuple containing default values for the parameters, right-aligned with
  // `param_names` -- i.e., `defaults[-i]` is the default for `param_names[-i]`.
  Safe_PyObjectPtr defaults_tuple_;

  // Information about the non-tensor-valued parameters that this API expects.
  std::vector<Attribute> attributes_;

  // Information about the tensor-valued parameters that this API expects.
  std::vector<Input> inputs_;
  std::vector<InputWithFixedDType> inputs_with_fixed_dtype_;
  std::vector<InputsWithTypeAttr> inputs_with_type_attrs_;
  std::vector<InputsWithTypeListAttr> inputs_with_type_list_attrs_;
  std::vector<InputsWithNumberAttr> inputs_with_number_attrs_;

  // Names of inferred attributes.  (Interned python strings.)
  std::vector<const char*> inferred_type_attrs_;
  std::vector<const char*> inferred_type_list_attrs_;
  std::vector<const char*> inferred_length_attrs_;
};
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_UTIL_PYTHON_API_INFO_H_