Qualify uses of std::string
PiperOrigin-RevId: 324252304 Change-Id: I82a93ab00c2cba05ea0b7fe4fc3709f0bec30d93
This commit is contained in:
parent
7f3772b7b8
commit
a697dbc604
@ -36,7 +36,7 @@ class NameAttrList;
|
|||||||
|
|
||||||
// A human-readable rendering of attr_value, that is more concise than a
|
// A human-readable rendering of attr_value, that is more concise than a
|
||||||
// text-format proto.
|
// text-format proto.
|
||||||
string SummarizeAttrValue(const AttrValue& attr_value);
|
std::string SummarizeAttrValue(const AttrValue& attr_value);
|
||||||
|
|
||||||
// Generates an error if attr_value doesn't have the indicated attr type.
|
// Generates an error if attr_value doesn't have the indicated attr type.
|
||||||
Status AttrValueHasType(const AttrValue& attr_value, StringPiece type);
|
Status AttrValueHasType(const AttrValue& attr_value, StringPiece type);
|
||||||
@ -51,7 +51,7 @@ Status AttrValueHasType(const AttrValue& attr_value, StringPiece type);
|
|||||||
bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out);
|
bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out);
|
||||||
|
|
||||||
// Sets *out based on the type of value.
|
// Sets *out based on the type of value.
|
||||||
void SetAttrValue(const string& value, AttrValue* out);
|
void SetAttrValue(const std::string& value, AttrValue* out);
|
||||||
void SetAttrValue(const tstring& value, AttrValue* out);
|
void SetAttrValue(const tstring& value, AttrValue* out);
|
||||||
void SetAttrValue(const char* value, AttrValue* out);
|
void SetAttrValue(const char* value, AttrValue* out);
|
||||||
void SetAttrValue(StringPiece value, AttrValue* out);
|
void SetAttrValue(StringPiece value, AttrValue* out);
|
||||||
|
@ -237,7 +237,7 @@ class DeviceBase {
|
|||||||
// Unimplemented by default
|
// Unimplemented by default
|
||||||
virtual const DeviceAttributes& attributes() const;
|
virtual const DeviceAttributes& attributes() const;
|
||||||
virtual int NumaNode() const { return attributes().locality().numa_node(); }
|
virtual int NumaNode() const { return attributes().locality().numa_node(); }
|
||||||
virtual const string& name() const;
|
virtual const std::string& name() const;
|
||||||
|
|
||||||
// Materializes the given TensorProto into 'tensor' stored in Device
|
// Materializes the given TensorProto into 'tensor' stored in Device
|
||||||
// memory. Most devices will want to override this.
|
// memory. Most devices will want to override this.
|
||||||
|
@ -114,9 +114,9 @@ class FunctionDefHelper {
|
|||||||
|
|
||||||
// Constructs an AttrValue.func given the "name" and "attrs".
|
// Constructs an AttrValue.func given the "name" and "attrs".
|
||||||
static AttrValueWrapper FunctionRef(
|
static AttrValueWrapper FunctionRef(
|
||||||
const string& name,
|
const std::string& name,
|
||||||
gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs);
|
gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs);
|
||||||
static AttrValueWrapper FunctionRef(const string& name) {
|
static AttrValueWrapper FunctionRef(const std::string& name) {
|
||||||
return FunctionRef(name, {});
|
return FunctionRef(name, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -127,11 +127,11 @@ class FunctionDefHelper {
|
|||||||
// When constructing a NodeDef, the first entry in ret is used as
|
// When constructing a NodeDef, the first entry in ret is used as
|
||||||
// the node name, the remaining values are ignored.
|
// the node name, the remaining values are ignored.
|
||||||
std::vector<string> ret;
|
std::vector<string> ret;
|
||||||
string op;
|
std::string op;
|
||||||
std::vector<string> arg;
|
std::vector<string> arg;
|
||||||
std::vector<std::pair<string, AttrValueWrapper>> attr;
|
std::vector<std::pair<string, AttrValueWrapper>> attr;
|
||||||
std::vector<string> dep;
|
std::vector<string> dep;
|
||||||
string device;
|
std::string device;
|
||||||
|
|
||||||
NodeDef ToNodeDef() const;
|
NodeDef ToNodeDef() const;
|
||||||
};
|
};
|
||||||
@ -143,7 +143,7 @@ class FunctionDefHelper {
|
|||||||
// - `control_ret_def` holds a mapping from the function control
|
// - `control_ret_def` holds a mapping from the function control
|
||||||
// output names to the nodes from `node_def`.
|
// output names to the nodes from `node_def`.
|
||||||
static FunctionDef Create(
|
static FunctionDef Create(
|
||||||
const string& function_name, gtl::ArraySlice<string> in_def,
|
const std::string& function_name, gtl::ArraySlice<string> in_def,
|
||||||
gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
|
gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
|
||||||
gtl::ArraySlice<Node> node_def,
|
gtl::ArraySlice<Node> node_def,
|
||||||
gtl::ArraySlice<std::pair<string, string>> ret_def,
|
gtl::ArraySlice<std::pair<string, string>> ret_def,
|
||||||
@ -153,7 +153,7 @@ class FunctionDefHelper {
|
|||||||
// function encoding (node_name:output_name[:output_index]).
|
// function encoding (node_name:output_name[:output_index]).
|
||||||
// - `ret_def` holds a mapping from the function output names from `out_def`
|
// - `ret_def` holds a mapping from the function output names from `out_def`
|
||||||
// to the node outputs from `node_def`.
|
// to the node outputs from `node_def`.
|
||||||
static FunctionDef Create(const string& function_name,
|
static FunctionDef Create(const std::string& function_name,
|
||||||
gtl::ArraySlice<string> in_def,
|
gtl::ArraySlice<string> in_def,
|
||||||
gtl::ArraySlice<string> out_def,
|
gtl::ArraySlice<string> out_def,
|
||||||
gtl::ArraySlice<string> attr_def,
|
gtl::ArraySlice<string> attr_def,
|
||||||
@ -161,7 +161,7 @@ class FunctionDefHelper {
|
|||||||
gtl::ArraySlice<std::pair<string, string>> ret_def);
|
gtl::ArraySlice<std::pair<string, string>> ret_def);
|
||||||
|
|
||||||
// TODO(josh11b): Get rid of these and transition to the one above.
|
// TODO(josh11b): Get rid of these and transition to the one above.
|
||||||
static FunctionDef Define(const string& function_name,
|
static FunctionDef Define(const std::string& function_name,
|
||||||
gtl::ArraySlice<string> arg_def,
|
gtl::ArraySlice<string> arg_def,
|
||||||
gtl::ArraySlice<string> ret_def,
|
gtl::ArraySlice<string> ret_def,
|
||||||
gtl::ArraySlice<string> attr_def,
|
gtl::ArraySlice<string> attr_def,
|
||||||
@ -175,7 +175,7 @@ class FunctionDefHelper {
|
|||||||
|
|
||||||
// Helpers to construct a constant scalar.
|
// Helpers to construct a constant scalar.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static Node Const(const string& name, const T& val) {
|
static Node Const(const std::string& name, const T& val) {
|
||||||
Node n = {{name}, "Const"};
|
Node n = {{name}, "Const"};
|
||||||
const DataType dtype = DataTypeToEnum<T>::value;
|
const DataType dtype = DataTypeToEnum<T>::value;
|
||||||
n.attr.push_back({"dtype", dtype});
|
n.attr.push_back({"dtype", dtype});
|
||||||
@ -186,7 +186,7 @@ class FunctionDefHelper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static Node Const(const string& name, gtl::ArraySlice<T> vals) {
|
static Node Const(const std::string& name, gtl::ArraySlice<T> vals) {
|
||||||
Node n = {{name}, "Const"};
|
Node n = {{name}, "Const"};
|
||||||
const DataType dtype = DataTypeToEnum<T>::value;
|
const DataType dtype = DataTypeToEnum<T>::value;
|
||||||
n.attr.push_back({"dtype", dtype});
|
n.attr.push_back({"dtype", dtype});
|
||||||
@ -207,7 +207,7 @@ inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) {
|
|||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(
|
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(
|
||||||
const string& val) {
|
const std::string& val) {
|
||||||
InitFromString(val);
|
InitFromString(val);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -251,13 +251,13 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
|
|||||||
// Particularly, it may not include all information presented in
|
// Particularly, it may not include all information presented in
|
||||||
// "func_def" (e.g., comments, description of the function arguments,
|
// "func_def" (e.g., comments, description of the function arguments,
|
||||||
// etc.)
|
// etc.)
|
||||||
string DebugString(const FunctionDef& func_def);
|
std::string DebugString(const FunctionDef& func_def);
|
||||||
string DebugString(const GraphDef& instantiated_func_def);
|
std::string DebugString(const GraphDef& instantiated_func_def);
|
||||||
string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes);
|
std::string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes);
|
||||||
|
|
||||||
// Returns a debug string for a top level graph (the main program and
|
// Returns a debug string for a top level graph (the main program and
|
||||||
// its supporting functions defined in its library).
|
// its supporting functions defined in its library).
|
||||||
string DebugStringWhole(const GraphDef& gdef);
|
std::string DebugStringWhole(const GraphDef& gdef);
|
||||||
|
|
||||||
// Returns true if f1 == f2. Compares all fields, including descriptions. Order
|
// Returns true if f1 == f2. Compares all fields, including descriptions. Order
|
||||||
// of NodeDefs doesn't matter.
|
// of NodeDefs doesn't matter.
|
||||||
@ -360,14 +360,14 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
|
|||||||
delete;
|
delete;
|
||||||
|
|
||||||
// Returns True if the library contains `func`, False otherwise.
|
// Returns True if the library contains `func`, False otherwise.
|
||||||
bool Contains(const string& func) const;
|
bool Contains(const std::string& func) const;
|
||||||
|
|
||||||
// Returns nullptr if "func" is not defined in "lib_def". Otherwise,
|
// Returns nullptr if "func" is not defined in "lib_def". Otherwise,
|
||||||
// returns its definition proto.
|
// returns its definition proto.
|
||||||
//
|
//
|
||||||
// NB: This function returns a borrowed pointer, which can be invalidated by a
|
// NB: This function returns a borrowed pointer, which can be invalidated by a
|
||||||
// subsequent call to `ReplaceFunction()` with the given name.
|
// subsequent call to `ReplaceFunction()` with the given name.
|
||||||
const FunctionDef* Find(const string& func) const TF_LOCKS_EXCLUDED(mu_);
|
const FunctionDef* Find(const std::string& func) const TF_LOCKS_EXCLUDED(mu_);
|
||||||
|
|
||||||
// Adds function definition 'fdef' to this function library.
|
// Adds function definition 'fdef' to this function library.
|
||||||
// Returns status 'ok' on success, or error otherwise. This is a no-op if
|
// Returns status 'ok' on success, or error otherwise. This is a no-op if
|
||||||
@ -388,7 +388,7 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
|
|||||||
// a non-OK status if "func" was not found in the library, OK otherwise.
|
// a non-OK status if "func" was not found in the library, OK otherwise.
|
||||||
// Please be careful when replacing function: make sure all previous pointers
|
// Please be careful when replacing function: make sure all previous pointers
|
||||||
// returned by `Find()` are no longer in use.
|
// returned by `Find()` are no longer in use.
|
||||||
Status ReplaceFunction(const string& func, const FunctionDef& fdef)
|
Status ReplaceFunction(const std::string& func, const FunctionDef& fdef)
|
||||||
TF_LOCKS_EXCLUDED(mu_);
|
TF_LOCKS_EXCLUDED(mu_);
|
||||||
|
|
||||||
// Replaces the gradient corresponding to `grad.function_name()`. Returns
|
// Replaces the gradient corresponding to `grad.function_name()`. Returns
|
||||||
@ -401,7 +401,7 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
|
|||||||
// Please be careful when removing function: make sure there are no other
|
// Please be careful when removing function: make sure there are no other
|
||||||
// nodes using the function, and all previous pointers returned by `Find()`
|
// nodes using the function, and all previous pointers returned by `Find()`
|
||||||
// are no longer in use.
|
// are no longer in use.
|
||||||
Status RemoveFunction(const string& func) TF_LOCKS_EXCLUDED(mu_);
|
Status RemoveFunction(const std::string& func) TF_LOCKS_EXCLUDED(mu_);
|
||||||
|
|
||||||
// Adds the functions and gradients in 'other' to this function library.
|
// Adds the functions and gradients in 'other' to this function library.
|
||||||
// Duplicate functions and gradients are ignored.
|
// Duplicate functions and gradients are ignored.
|
||||||
@ -417,7 +417,8 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
|
|||||||
// If the gradient function for 'func' is specified explicitly in
|
// If the gradient function for 'func' is specified explicitly in
|
||||||
// the library, returns the gradient function name. Otherwise,
|
// the library, returns the gradient function name. Otherwise,
|
||||||
// returns an empty string.
|
// returns an empty string.
|
||||||
string FindGradient(const string& func) const TF_LOCKS_EXCLUDED(mu_);
|
std::string FindGradient(const std::string& func) const
|
||||||
|
TF_LOCKS_EXCLUDED(mu_);
|
||||||
|
|
||||||
// OpRegistryInterface method. Useful for constructing a Graph.
|
// OpRegistryInterface method. Useful for constructing a Graph.
|
||||||
//
|
//
|
||||||
@ -427,26 +428,27 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
|
|||||||
//
|
//
|
||||||
// NB: This function outputs a borrowed pointer, which can be invalidated by a
|
// NB: This function outputs a borrowed pointer, which can be invalidated by a
|
||||||
// subsequent call to `ReplaceFunction()` with the given name.
|
// subsequent call to `ReplaceFunction()` with the given name.
|
||||||
Status LookUp(const string& op_type_name,
|
Status LookUp(const std::string& op_type_name,
|
||||||
const OpRegistrationData** op_reg_data) const override
|
const OpRegistrationData** op_reg_data) const override
|
||||||
TF_LOCKS_EXCLUDED(mu_);
|
TF_LOCKS_EXCLUDED(mu_);
|
||||||
|
|
||||||
// Generates new function name with the specified prefix that is unique
|
// Generates new function name with the specified prefix that is unique
|
||||||
// across this library.
|
// across this library.
|
||||||
string UniqueFunctionName(StringPiece prefix) const TF_LOCKS_EXCLUDED(mu_);
|
std::string UniqueFunctionName(StringPiece prefix) const
|
||||||
|
TF_LOCKS_EXCLUDED(mu_);
|
||||||
|
|
||||||
// Given a node def 'ndef', inspects attributes of the callee
|
// Given a node def 'ndef', inspects attributes of the callee
|
||||||
// function to derive the attribute 'value' for 'attr'. Returns OK
|
// function to derive the attribute 'value' for 'attr'. Returns OK
|
||||||
// iff the attribute is given by the function's definition.
|
// iff the attribute is given by the function's definition.
|
||||||
// TODO(irving): Remove; keep only the const Node& version.
|
// TODO(irving): Remove; keep only the const Node& version.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status GetAttr(const NodeDef& ndef, const string& attr, T* value) const;
|
Status GetAttr(const NodeDef& ndef, const std::string& attr, T* value) const;
|
||||||
|
|
||||||
// Given a node, inspects attributes of the callee function to derive the
|
// Given a node, inspects attributes of the callee function to derive the
|
||||||
// attribute 'value' for 'attr'. Returns OK iff the attribute is given by the
|
// attribute 'value' for 'attr'. Returns OK iff the attribute is given by the
|
||||||
// function's definition.
|
// function's definition.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status GetAttr(const Node& node, const string& attr, T* value) const;
|
Status GetAttr(const Node& node, const std::string& attr, T* value) const;
|
||||||
|
|
||||||
// Returns a proto representation of the state of this function library.
|
// Returns a proto representation of the state of this function library.
|
||||||
FunctionDefLibrary ToProto() const TF_LOCKS_EXCLUDED(mu_);
|
FunctionDefLibrary ToProto() const TF_LOCKS_EXCLUDED(mu_);
|
||||||
@ -475,7 +477,7 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
|
|||||||
// name `func` already exists in this function library, and has the same
|
// name `func` already exists in this function library, and has the same
|
||||||
// implementation as in `other`. If the implementations conflict, an invalid
|
// implementation as in `other`. If the implementations conflict, an invalid
|
||||||
// argument error is returned.
|
// argument error is returned.
|
||||||
Status CopyFunctionDefFrom(const string& func,
|
Status CopyFunctionDefFrom(const std::string& func,
|
||||||
const FunctionLibraryDefinition& other)
|
const FunctionLibraryDefinition& other)
|
||||||
TF_LOCKS_EXCLUDED(mu_);
|
TF_LOCKS_EXCLUDED(mu_);
|
||||||
|
|
||||||
@ -491,7 +493,7 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
|
|||||||
|
|
||||||
std::shared_ptr<FunctionDefAndOpRegistration> FindHelper(
|
std::shared_ptr<FunctionDefAndOpRegistration> FindHelper(
|
||||||
const string& func) const TF_SHARED_LOCKS_REQUIRED(mu_);
|
const string& func) const TF_SHARED_LOCKS_REQUIRED(mu_);
|
||||||
string FindGradientHelper(const string& func) const
|
std::string FindGradientHelper(const std::string& func) const
|
||||||
TF_SHARED_LOCKS_REQUIRED(mu_);
|
TF_SHARED_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
Status AddHelper(std::shared_ptr<FunctionDefAndOpRegistration> registration,
|
Status AddHelper(std::shared_ptr<FunctionDefAndOpRegistration> registration,
|
||||||
@ -518,12 +520,13 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
|
|||||||
// Remove `func` from the library. Returns non-OK Status unless `func` is in
|
// Remove `func` from the library. Returns non-OK Status unless `func` is in
|
||||||
// the library. This should only be called when there is a guarantee that the
|
// the library. This should only be called when there is a guarantee that the
|
||||||
// function being removed hasn't been retrieved with `Find`.
|
// function being removed hasn't been retrieved with `Find`.
|
||||||
Status RemoveFunctionHelper(const string& func)
|
Status RemoveFunctionHelper(const std::string& func)
|
||||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
// Remove gradient of function `func` from the library. Returns non-OK Status
|
// Remove gradient of function `func` from the library. Returns non-OK Status
|
||||||
// unless `func` has a gradient.
|
// unless `func` has a gradient.
|
||||||
Status RemoveGradient(const string& func) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
Status RemoveGradient(const std::string& func)
|
||||||
|
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
mutable mutex mu_;
|
mutable mutex mu_;
|
||||||
const OpRegistryInterface* const default_registry_;
|
const OpRegistryInterface* const default_registry_;
|
||||||
@ -566,7 +569,7 @@ class FunctionLibraryRuntime {
|
|||||||
// The canonical device name of the device on which the function
|
// The canonical device name of the device on which the function
|
||||||
// should be instantiated. If empty, the function will be
|
// should be instantiated. If empty, the function will be
|
||||||
// instantiated on the local device.
|
// instantiated on the local device.
|
||||||
string target;
|
std::string target;
|
||||||
|
|
||||||
// Should the function be instantiated as a multi-device function?
|
// Should the function be instantiated as a multi-device function?
|
||||||
bool is_multi_device_function = false;
|
bool is_multi_device_function = false;
|
||||||
@ -640,13 +643,13 @@ class FunctionLibraryRuntime {
|
|||||||
// `state_handle` will have the same handle and share the same
|
// `state_handle` will have the same handle and share the same
|
||||||
// state (in stateful kernels); and two functions with different
|
// state (in stateful kernels); and two functions with different
|
||||||
// values for `state_handle` will have independent state.
|
// values for `state_handle` will have independent state.
|
||||||
string state_handle;
|
std::string state_handle;
|
||||||
|
|
||||||
// This interface is EXPERIMENTAL and subject to change.
|
// This interface is EXPERIMENTAL and subject to change.
|
||||||
//
|
//
|
||||||
// Instantiates the function using an executor of the given type. If empty,
|
// Instantiates the function using an executor of the given type. If empty,
|
||||||
// the default TensorFlow executor will be used.
|
// the default TensorFlow executor will be used.
|
||||||
string executor_type;
|
std::string executor_type;
|
||||||
|
|
||||||
// If true, the runtime will attempt to create kernels for the function at
|
// If true, the runtime will attempt to create kernels for the function at
|
||||||
// instantiation time, rather than on the first run. This can be used to
|
// instantiation time, rather than on the first run. This can be used to
|
||||||
@ -680,10 +683,10 @@ class FunctionLibraryRuntime {
|
|||||||
bool include_optimized_graph_in_debug_string = false;
|
bool include_optimized_graph_in_debug_string = false;
|
||||||
};
|
};
|
||||||
typedef uint64 Handle;
|
typedef uint64 Handle;
|
||||||
virtual Status Instantiate(const string& function_name, AttrSlice attrs,
|
virtual Status Instantiate(const std::string& function_name, AttrSlice attrs,
|
||||||
const InstantiateOptions& options,
|
const InstantiateOptions& options,
|
||||||
Handle* handle) = 0;
|
Handle* handle) = 0;
|
||||||
Status Instantiate(const string& function_name, AttrSlice attrs,
|
Status Instantiate(const std::string& function_name, AttrSlice attrs,
|
||||||
Handle* handle) {
|
Handle* handle) {
|
||||||
auto opts = absl::make_unique<InstantiateOptions>();
|
auto opts = absl::make_unique<InstantiateOptions>();
|
||||||
return Instantiate(function_name, attrs, *opts, handle);
|
return Instantiate(function_name, attrs, *opts, handle);
|
||||||
@ -738,7 +741,7 @@ class FunctionLibraryRuntime {
|
|||||||
|
|
||||||
// Parameters for remote function execution.
|
// Parameters for remote function execution.
|
||||||
bool remote_execution = false;
|
bool remote_execution = false;
|
||||||
string source_device = ""; // Fully specified device name.
|
std::string source_device = ""; // Fully specified device name.
|
||||||
|
|
||||||
// Allocator attributes specifying where the args are / rets should be put.
|
// Allocator attributes specifying where the args are / rets should be put.
|
||||||
// These should either be {} or match the length of args / retvals. If {},
|
// These should either be {} or match the length of args / retvals. If {},
|
||||||
@ -758,7 +761,7 @@ class FunctionLibraryRuntime {
|
|||||||
bool run_all_kernels_inline = false;
|
bool run_all_kernels_inline = false;
|
||||||
|
|
||||||
// Returns a human readable representation of this.
|
// Returns a human readable representation of this.
|
||||||
string DebugString() const;
|
std::string DebugString() const;
|
||||||
};
|
};
|
||||||
typedef std::function<void(const Status&)> DoneCallback;
|
typedef std::function<void(const Status&)> DoneCallback;
|
||||||
virtual void Run(const Options& opts, Handle handle,
|
virtual void Run(const Options& opts, Handle handle,
|
||||||
@ -786,7 +789,7 @@ class FunctionLibraryRuntime {
|
|||||||
// NOTE(mrry): This method assumes that the runtime is associated with a
|
// NOTE(mrry): This method assumes that the runtime is associated with a
|
||||||
// default function library, and looks up `function_name` in that library.
|
// default function library, and looks up `function_name` in that library.
|
||||||
// It does not support overriding the function library.
|
// It does not support overriding the function library.
|
||||||
virtual bool IsStateful(const string& function_name) const = 0;
|
virtual bool IsStateful(const std::string& function_name) const = 0;
|
||||||
|
|
||||||
// Returns the device on which the function executes.
|
// Returns the device on which the function executes.
|
||||||
virtual Device* device() = 0;
|
virtual Device* device() = 0;
|
||||||
@ -817,7 +820,7 @@ class FunctionLibraryRuntime {
|
|||||||
|
|
||||||
// Returns a debug string showing the definition of the function of
|
// Returns a debug string showing the definition of the function of
|
||||||
// 'handle'.
|
// 'handle'.
|
||||||
virtual string DebugString(Handle handle) = 0;
|
virtual std::string DebugString(Handle handle) = 0;
|
||||||
|
|
||||||
// Returns the graph version number.
|
// Returns the graph version number.
|
||||||
virtual int graph_def_version() const = 0;
|
virtual int graph_def_version() const = 0;
|
||||||
@ -847,13 +850,13 @@ class FunctionLibraryRuntime {
|
|||||||
// `ExecutorFactory::GetFactory()`) that will be used based on the given
|
// `ExecutorFactory::GetFactory()`) that will be used based on the given
|
||||||
// dynamic `options` and static `attrs`. If none is specified, this method
|
// dynamic `options` and static `attrs`. If none is specified, this method
|
||||||
// will return an empty string, which leaves the decision up to the runtime.
|
// will return an empty string, which leaves the decision up to the runtime.
|
||||||
static string ExecutorType(const InstantiateOptions& options,
|
static std::string ExecutorType(const InstantiateOptions& options,
|
||||||
AttrSlice attrs);
|
AttrSlice attrs);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Returns the device of the `arg_index`-th function input. Update
|
// Returns the device of the `arg_index`-th function input. Update
|
||||||
// `composite_devices` if the input device is a composite device.
|
// `composite_devices` if the input device is a composite device.
|
||||||
string GetFunctionResourceInputDevice(
|
std::string GetFunctionResourceInputDevice(
|
||||||
const Tensor& input, const int arg_index, const FunctionDef& function_def,
|
const Tensor& input, const int arg_index, const FunctionDef& function_def,
|
||||||
absl::flat_hash_map<string, std::vector<string>>* composite_devices);
|
absl::flat_hash_map<string, std::vector<string>>* composite_devices);
|
||||||
|
|
||||||
@ -864,9 +867,10 @@ string GetFunctionResourceInputDevice(
|
|||||||
// space. But it may be change as the implementation
|
// space. But it may be change as the implementation
|
||||||
// evolves. Therefore, it should not be persisted or compared across
|
// evolves. Therefore, it should not be persisted or compared across
|
||||||
// address spaces.
|
// address spaces.
|
||||||
string Canonicalize(const string& funcname, AttrSlice attrs,
|
std::string Canonicalize(
|
||||||
const FunctionLibraryRuntime::InstantiateOptions& options);
|
const std::string& funcname, AttrSlice attrs,
|
||||||
string Canonicalize(const string& funcname, AttrSlice attrs);
|
const FunctionLibraryRuntime::InstantiateOptions& options);
|
||||||
|
std::string Canonicalize(const std::string& funcname, AttrSlice attrs);
|
||||||
|
|
||||||
const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
|
const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
|
||||||
const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;
|
const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;
|
||||||
@ -907,8 +911,8 @@ class DistributedFunctionLibraryRuntime {
|
|||||||
// local `handle` is filled for the instantiated function data and can be used
|
// local `handle` is filled for the instantiated function data and can be used
|
||||||
// for subsequent run function calls on the remote target.
|
// for subsequent run function calls on the remote target.
|
||||||
virtual void Instantiate(
|
virtual void Instantiate(
|
||||||
const string& function_name, const FunctionLibraryDefinition& lib_def,
|
const std::string& function_name,
|
||||||
AttrSlice attrs,
|
const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
|
||||||
const FunctionLibraryRuntime::InstantiateOptions& options,
|
const FunctionLibraryRuntime::InstantiateOptions& options,
|
||||||
FunctionLibraryRuntime::LocalHandle* handle,
|
FunctionLibraryRuntime::LocalHandle* handle,
|
||||||
FunctionLibraryRuntime::DoneCallback done) = 0;
|
FunctionLibraryRuntime::DoneCallback done) = 0;
|
||||||
@ -1022,11 +1026,11 @@ Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
|
|||||||
namespace gradient {
|
namespace gradient {
|
||||||
// Register a gradient creator for the "op".
|
// Register a gradient creator for the "op".
|
||||||
typedef std::function<Status(const AttrSlice& attrs, FunctionDef*)> Creator;
|
typedef std::function<Status(const AttrSlice& attrs, FunctionDef*)> Creator;
|
||||||
bool RegisterOp(const string& op, Creator func);
|
bool RegisterOp(const std::string& op, Creator func);
|
||||||
|
|
||||||
// Returns OK the gradient creator for the "op" is found (may be
|
// Returns OK the gradient creator for the "op" is found (may be
|
||||||
// nullptr if REGISTER_OP_NO_GRADIENT is used.
|
// nullptr if REGISTER_OP_NO_GRADIENT is used.
|
||||||
Status GetOpGradientCreator(const string& op, Creator* creator);
|
Status GetOpGradientCreator(const std::string& op, Creator* creator);
|
||||||
}; // namespace gradient
|
}; // namespace gradient
|
||||||
|
|
||||||
// Declare explicit instantiations of GetAttr
|
// Declare explicit instantiations of GetAttr
|
||||||
|
@ -52,14 +52,14 @@ class LogMemory {
|
|||||||
UNKNOWN_STEP_ID = -6,
|
UNKNOWN_STEP_ID = -6,
|
||||||
};
|
};
|
||||||
|
|
||||||
static const string kLogMemoryLabel;
|
static const std::string kLogMemoryLabel;
|
||||||
|
|
||||||
// Test to see if memory logging is enabled. For now, logging is
|
// Test to see if memory logging is enabled. For now, logging is
|
||||||
// enabled whenever VLOG_IS_ON(1) for the log_memory module.
|
// enabled whenever VLOG_IS_ON(1) for the log_memory module.
|
||||||
static bool IsEnabled();
|
static bool IsEnabled();
|
||||||
|
|
||||||
// Log the beginning of a step.
|
// Log the beginning of a step.
|
||||||
static void RecordStep(int64 step_id, const string& handle);
|
static void RecordStep(int64 step_id, const std::string& handle);
|
||||||
|
|
||||||
// Log a tensor buffer allocation. The name indicates which kernel
|
// Log a tensor buffer allocation. The name indicates which kernel
|
||||||
// made the allocation. If the allocation is made through an
|
// made the allocation. If the allocation is made through an
|
||||||
@ -67,8 +67,8 @@ class LogMemory {
|
|||||||
// otherwise step_id is one of the SpecialStepIds defined in
|
// otherwise step_id is one of the SpecialStepIds defined in
|
||||||
// op_kernel.h, e.g. Op Kernel construction or an optimization pass
|
// op_kernel.h, e.g. Op Kernel construction or an optimization pass
|
||||||
// such as constant folding.
|
// such as constant folding.
|
||||||
static void RecordTensorAllocation(const string& kernel_name, int64 step_id,
|
static void RecordTensorAllocation(const std::string& kernel_name,
|
||||||
const Tensor& tensor);
|
int64 step_id, const Tensor& tensor);
|
||||||
|
|
||||||
// Log a tensor buffer deallocation. The deallocation is triggered
|
// Log a tensor buffer deallocation. The deallocation is triggered
|
||||||
// when the buffer's refcount falls to zero, and the tracking
|
// when the buffer's refcount falls to zero, and the tracking
|
||||||
@ -77,10 +77,10 @@ class LogMemory {
|
|||||||
// corresponding tensor previously passed in to
|
// corresponding tensor previously passed in to
|
||||||
// RecordTensorAllocation.
|
// RecordTensorAllocation.
|
||||||
static void RecordTensorDeallocation(int64 allocation_id,
|
static void RecordTensorDeallocation(int64 allocation_id,
|
||||||
const string& allocator_name);
|
const std::string& allocator_name);
|
||||||
|
|
||||||
// Log the use of a tensor as an output from a kernel.
|
// Log the use of a tensor as an output from a kernel.
|
||||||
static void RecordTensorOutput(const string& kernel_name, int64 step_id,
|
static void RecordTensorOutput(const std::string& kernel_name, int64 step_id,
|
||||||
int index, const Tensor& tensor);
|
int index, const Tensor& tensor);
|
||||||
|
|
||||||
// Log a "raw" allocation, which is just a buffer sized in
|
// Log a "raw" allocation, which is just a buffer sized in
|
||||||
@ -92,7 +92,7 @@ class LogMemory {
|
|||||||
// is executing, otherwise step_id is one of the SpecialStepIds
|
// is executing, otherwise step_id is one of the SpecialStepIds
|
||||||
// defined in op_kernel.h, e.g. Op Kernel construction or an
|
// defined in op_kernel.h, e.g. Op Kernel construction or an
|
||||||
// optimization pass such as constant folding.
|
// optimization pass such as constant folding.
|
||||||
static void RecordRawAllocation(const string& operation, int64 step_id,
|
static void RecordRawAllocation(const std::string& operation, int64 step_id,
|
||||||
size_t num_bytes, void* ptr,
|
size_t num_bytes, void* ptr,
|
||||||
Allocator* allocator);
|
Allocator* allocator);
|
||||||
|
|
||||||
@ -101,7 +101,7 @@ class LogMemory {
|
|||||||
// enqueued using the buffer. A deferred deallocation should always
|
// enqueued using the buffer. A deferred deallocation should always
|
||||||
// be followed by a matching non-deferred deallocation when the
|
// be followed by a matching non-deferred deallocation when the
|
||||||
// buffer is actually returned and can be reused.
|
// buffer is actually returned and can be reused.
|
||||||
static void RecordRawDeallocation(const string& operation, int64 step_id,
|
static void RecordRawDeallocation(const std::string& operation, int64 step_id,
|
||||||
void* ptr, Allocator* allocator,
|
void* ptr, Allocator* allocator,
|
||||||
bool deferred);
|
bool deferred);
|
||||||
};
|
};
|
||||||
|
@ -62,16 +62,16 @@ extern const char* const kColocationGroupPrefix;
|
|||||||
// The parameter `max_inputs_in_summary` specifies how many inputs at most to
|
// The parameter `max_inputs_in_summary` specifies how many inputs at most to
|
||||||
// serialize in the output (in order not to get a string which is overly large).
|
// serialize in the output (in order not to get a string which is overly large).
|
||||||
// The value `-1` specifies that all inputs will be shown.
|
// The value `-1` specifies that all inputs will be shown.
|
||||||
string SummarizeNodeDef(const NodeDef& node_def,
|
std::string SummarizeNodeDef(const NodeDef& node_def,
|
||||||
int max_inputs_in_summary = -1);
|
int max_inputs_in_summary = -1);
|
||||||
string SummarizeAttrs(const NodeDef& node_def);
|
std::string SummarizeAttrs(const NodeDef& node_def);
|
||||||
string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device);
|
std::string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device);
|
||||||
|
|
||||||
// Produces a formatted string pattern from the node which can uniquely identify
|
// Produces a formatted string pattern from the node which can uniquely identify
|
||||||
// this node upstream to produce an informative error message. The pattern
|
// this node upstream to produce an informative error message. The pattern
|
||||||
// followed is: {{node <node_name>}}
|
// followed is: {{node <node_name>}}
|
||||||
string FormatNodeDefForError(const NodeDef& node_def);
|
std::string FormatNodeDefForError(const NodeDef& node_def);
|
||||||
string FormatNodeDefForError(
|
std::string FormatNodeDefForError(
|
||||||
StringPiece node_name, bool has_experimental_debug_info,
|
StringPiece node_name, bool has_experimental_debug_info,
|
||||||
const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
|
const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
|
||||||
|
|
||||||
@ -148,7 +148,7 @@ class AttrSlice {
|
|||||||
// Returns the attr with attr_name if found. Otherwise, returns
|
// Returns the attr with attr_name if found. Otherwise, returns
|
||||||
// nullptr.
|
// nullptr.
|
||||||
const AttrValue* Find(StringPiece attr_name) const;
|
const AttrValue* Find(StringPiece attr_name) const;
|
||||||
const AttrValue* FindByString(const string& attr_name) const;
|
const AttrValue* FindByString(const std::string& attr_name) const;
|
||||||
|
|
||||||
// Returns the attr_value for attr_name if found. Otherwise, returns a
|
// Returns the attr_value for attr_name if found. Otherwise, returns a
|
||||||
// NotFound status.
|
// NotFound status.
|
||||||
@ -157,8 +157,8 @@ class AttrSlice {
|
|||||||
// Helper class to avoid allocations in EqualAttrs.
|
// Helper class to avoid allocations in EqualAttrs.
|
||||||
// TODO(irving): Will go away once NodeInfo is used.
|
// TODO(irving): Will go away once NodeInfo is used.
|
||||||
struct Scratch {
|
struct Scratch {
|
||||||
string a;
|
std::string a;
|
||||||
string b;
|
std::string b;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Check if all attrs and attr values match. Does not take defaults into
|
// Check if all attrs and attr values match. Does not take defaults into
|
||||||
@ -175,13 +175,13 @@ class AttrSlice {
|
|||||||
// If this AttrSlice has an attached NodeDef, summarize it. This is for
|
// If this AttrSlice has an attached NodeDef, summarize it. This is for
|
||||||
// error messages only: we intentionally do not provide direct access to the
|
// error messages only: we intentionally do not provide direct access to the
|
||||||
// NodeDef, since it is not always there.
|
// NodeDef, since it is not always there.
|
||||||
string SummarizeNode() const;
|
std::string SummarizeNode() const;
|
||||||
|
|
||||||
// Iteration over all attrs
|
// Iteration over all attrs
|
||||||
AttrValueMap::const_iterator begin() const { return attrs_->begin(); }
|
AttrValueMap::const_iterator begin() const { return attrs_->begin(); }
|
||||||
AttrValueMap::const_iterator end() const { return attrs_->end(); }
|
AttrValueMap::const_iterator end() const { return attrs_->end(); }
|
||||||
|
|
||||||
string DebugString() const;
|
std::string DebugString() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const NodeDef* ndef_;
|
const NodeDef* ndef_;
|
||||||
@ -195,7 +195,7 @@ bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name);
|
|||||||
// attr with attr_name is found in node_def, or the attr does not have
|
// attr with attr_name is found in node_def, or the attr does not have
|
||||||
// a matching type, a non-ok status will be returned.
|
// a matching type, a non-ok status will be returned.
|
||||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||||
string* value); // type: "string"
|
std::string* value); // type: "string"
|
||||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||||
tstring* value); // type: "tstring"
|
tstring* value); // type: "tstring"
|
||||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||||
@ -266,7 +266,7 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
|||||||
// attr with attr_name is found in node_def, or the attr does not have
|
// attr with attr_name is found in node_def, or the attr does not have
|
||||||
// a matching type, false is returned.
|
// a matching type, false is returned.
|
||||||
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||||
string* value); // type: "string"
|
std::string* value); // type: "string"
|
||||||
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||||
int64* value); // type: "int"
|
int64* value); // type: "int"
|
||||||
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||||
@ -309,7 +309,8 @@ bool TryGetNodeAttr(
|
|||||||
// If no attr with attr_name is found in node_def, or the attr does not have
|
// If no attr with attr_name is found in node_def, or the attr does not have
|
||||||
// a matching type, a reference to an empty string is returned.
|
// a matching type, a reference to an empty string is returned.
|
||||||
// REQUIRES: Must not use the returned value beyond the lifetime of node_def.
|
// REQUIRES: Must not use the returned value beyond the lifetime of node_def.
|
||||||
const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name);
|
const std::string& GetNodeAttrString(const AttrSlice& attrs,
|
||||||
|
StringPiece attr_name);
|
||||||
|
|
||||||
// Specialization to parse an attribute directly into a Padding enum.
|
// Specialization to parse an attribute directly into a Padding enum.
|
||||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||||
|
@ -45,11 +45,12 @@ class OpRegistryInterface {
|
|||||||
// Returns an error status and sets *op_reg_data to nullptr if no OpDef is
|
// Returns an error status and sets *op_reg_data to nullptr if no OpDef is
|
||||||
// registered under that name, otherwise returns the registered OpDef.
|
// registered under that name, otherwise returns the registered OpDef.
|
||||||
// Caller must not delete the returned pointer.
|
// Caller must not delete the returned pointer.
|
||||||
virtual Status LookUp(const string& op_type_name,
|
virtual Status LookUp(const std::string& op_type_name,
|
||||||
const OpRegistrationData** op_reg_data) const = 0;
|
const OpRegistrationData** op_reg_data) const = 0;
|
||||||
|
|
||||||
// Shorthand for calling LookUp to get the OpDef.
|
// Shorthand for calling LookUp to get the OpDef.
|
||||||
Status LookUpOpDef(const string& op_type_name, const OpDef** op_def) const;
|
Status LookUpOpDef(const std::string& op_type_name,
|
||||||
|
const OpDef** op_def) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
// The standard implementation of OpRegistryInterface, along with a
|
// The standard implementation of OpRegistryInterface, along with a
|
||||||
@ -71,11 +72,11 @@ class OpRegistry : public OpRegistryInterface {
|
|||||||
|
|
||||||
void Register(const OpRegistrationDataFactory& op_data_factory);
|
void Register(const OpRegistrationDataFactory& op_data_factory);
|
||||||
|
|
||||||
Status LookUp(const string& op_type_name,
|
Status LookUp(const std::string& op_type_name,
|
||||||
const OpRegistrationData** op_reg_data) const override;
|
const OpRegistrationData** op_reg_data) const override;
|
||||||
|
|
||||||
// Returns OpRegistrationData* of registered op type, else returns nullptr.
|
// Returns OpRegistrationData* of registered op type, else returns nullptr.
|
||||||
const OpRegistrationData* LookUp(const string& op_type_name) const;
|
const OpRegistrationData* LookUp(const std::string& op_type_name) const;
|
||||||
|
|
||||||
// Fills *ops with all registered OpDefs (except those with names
|
// Fills *ops with all registered OpDefs (except those with names
|
||||||
// starting with '_' if include_internal == false) sorted in
|
// starting with '_' if include_internal == false) sorted in
|
||||||
@ -84,7 +85,7 @@ class OpRegistry : public OpRegistryInterface {
|
|||||||
|
|
||||||
// Returns ASCII-format OpList for all registered OpDefs (except
|
// Returns ASCII-format OpList for all registered OpDefs (except
|
||||||
// those with names starting with '_' if include_internal == false).
|
// those with names starting with '_' if include_internal == false).
|
||||||
string DebugString(bool include_internal) const;
|
std::string DebugString(bool include_internal) const;
|
||||||
|
|
||||||
// A singleton available at startup.
|
// A singleton available at startup.
|
||||||
static OpRegistry* Global();
|
static OpRegistry* Global();
|
||||||
@ -153,7 +154,7 @@ class OpRegistry : public OpRegistryInterface {
|
|||||||
Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory)
|
Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory)
|
||||||
const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
const OpRegistrationData* LookUpSlow(const string& op_type_name) const;
|
const OpRegistrationData* LookUpSlow(const std::string& op_type_name) const;
|
||||||
|
|
||||||
mutable mutex mu_;
|
mutable mutex mu_;
|
||||||
// Functions in deferred_ may only be called with mu_ held.
|
// Functions in deferred_ may only be called with mu_ held.
|
||||||
@ -179,11 +180,11 @@ class OpListOpRegistry : public OpRegistryInterface {
|
|||||||
// Does not take ownership of op_list, *op_list must outlive *this.
|
// Does not take ownership of op_list, *op_list must outlive *this.
|
||||||
explicit OpListOpRegistry(const OpList* op_list);
|
explicit OpListOpRegistry(const OpList* op_list);
|
||||||
~OpListOpRegistry() override;
|
~OpListOpRegistry() override;
|
||||||
Status LookUp(const string& op_type_name,
|
Status LookUp(const std::string& op_type_name,
|
||||||
const OpRegistrationData** op_reg_data) const override;
|
const OpRegistrationData** op_reg_data) const override;
|
||||||
|
|
||||||
// Returns OpRegistrationData* of op type in list, else returns nullptr.
|
// Returns OpRegistrationData* of op type in list, else returns nullptr.
|
||||||
const OpRegistrationData* LookUp(const string& op_type_name) const;
|
const OpRegistrationData* LookUp(const std::string& op_type_name) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Values are owned.
|
// Values are owned.
|
||||||
@ -225,15 +226,15 @@ template <>
|
|||||||
class OpDefBuilderWrapper<true> {
|
class OpDefBuilderWrapper<true> {
|
||||||
public:
|
public:
|
||||||
explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {}
|
explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {}
|
||||||
OpDefBuilderWrapper<true>& Attr(string spec) {
|
OpDefBuilderWrapper<true>& Attr(std::string spec) {
|
||||||
builder_.Attr(std::move(spec));
|
builder_.Attr(std::move(spec));
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
OpDefBuilderWrapper<true>& Input(string spec) {
|
OpDefBuilderWrapper<true>& Input(std::string spec) {
|
||||||
builder_.Input(std::move(spec));
|
builder_.Input(std::move(spec));
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
OpDefBuilderWrapper<true>& Output(string spec) {
|
OpDefBuilderWrapper<true>& Output(std::string spec) {
|
||||||
builder_.Output(std::move(spec));
|
builder_.Output(std::move(spec));
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
@ -259,11 +260,11 @@ class OpDefBuilderWrapper<true> {
|
|||||||
builder_.SetAllowsUninitializedInput();
|
builder_.SetAllowsUninitializedInput();
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
OpDefBuilderWrapper<true>& Deprecated(int version, string explanation) {
|
OpDefBuilderWrapper<true>& Deprecated(int version, std::string explanation) {
|
||||||
builder_.Deprecated(version, std::move(explanation));
|
builder_.Deprecated(version, std::move(explanation));
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
OpDefBuilderWrapper<true>& Doc(string text) {
|
OpDefBuilderWrapper<true>& Doc(std::string text) {
|
||||||
builder_.Doc(std::move(text));
|
builder_.Doc(std::move(text));
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
@ -53,7 +53,7 @@ struct OpRegistrationData {
|
|||||||
class OpDefBuilder {
|
class OpDefBuilder {
|
||||||
public:
|
public:
|
||||||
// Constructs an OpDef with just the name field set.
|
// Constructs an OpDef with just the name field set.
|
||||||
explicit OpDefBuilder(string op_name);
|
explicit OpDefBuilder(std::string op_name);
|
||||||
|
|
||||||
// Adds an attr to this OpDefBuilder (and returns *this). The spec has
|
// Adds an attr to this OpDefBuilder (and returns *this). The spec has
|
||||||
// format "<name>:<type>" or "<name>:<type>=<default>"
|
// format "<name>:<type>" or "<name>:<type>=<default>"
|
||||||
@ -86,7 +86,7 @@ class OpDefBuilder {
|
|||||||
// * Ability to restrict the type of the tensor like the existing
|
// * Ability to restrict the type of the tensor like the existing
|
||||||
// restrictions for type attrs.
|
// restrictions for type attrs.
|
||||||
// Perhaps by linking the type of the tensor to a type attr?
|
// Perhaps by linking the type of the tensor to a type attr?
|
||||||
OpDefBuilder& Attr(string spec);
|
OpDefBuilder& Attr(std::string spec);
|
||||||
|
|
||||||
// Adds an input or output to this OpDefBuilder (and returns *this).
|
// Adds an input or output to this OpDefBuilder (and returns *this).
|
||||||
// The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)"
|
// The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)"
|
||||||
@ -103,8 +103,8 @@ class OpDefBuilder {
|
|||||||
// in the spec?
|
// in the spec?
|
||||||
// TODO(josh11b): SparseInput() and SparseOutput() matching the Python
|
// TODO(josh11b): SparseInput() and SparseOutput() matching the Python
|
||||||
// handling?
|
// handling?
|
||||||
OpDefBuilder& Input(string spec);
|
OpDefBuilder& Input(std::string spec);
|
||||||
OpDefBuilder& Output(string spec);
|
OpDefBuilder& Output(std::string spec);
|
||||||
|
|
||||||
// Turns on the indicated boolean flag in this OpDefBuilder (and
|
// Turns on the indicated boolean flag in this OpDefBuilder (and
|
||||||
// returns *this).
|
// returns *this).
|
||||||
@ -114,7 +114,7 @@ class OpDefBuilder {
|
|||||||
OpDefBuilder& SetAllowsUninitializedInput();
|
OpDefBuilder& SetAllowsUninitializedInput();
|
||||||
|
|
||||||
// Deprecate the op at a certain GraphDef version.
|
// Deprecate the op at a certain GraphDef version.
|
||||||
OpDefBuilder& Deprecated(int version, string explanation);
|
OpDefBuilder& Deprecated(int version, std::string explanation);
|
||||||
|
|
||||||
// Adds docs to this OpDefBuilder (and returns *this).
|
// Adds docs to this OpDefBuilder (and returns *this).
|
||||||
// Docs have the format:
|
// Docs have the format:
|
||||||
@ -130,7 +130,7 @@ class OpDefBuilder {
|
|||||||
// to suppress the automatically-generated type documentation in
|
// to suppress the automatically-generated type documentation in
|
||||||
// generated output.
|
// generated output.
|
||||||
#ifndef TF_LEAN_BINARY
|
#ifndef TF_LEAN_BINARY
|
||||||
OpDefBuilder& Doc(string text);
|
OpDefBuilder& Doc(std::string text);
|
||||||
#else
|
#else
|
||||||
OpDefBuilder& Doc(string text) { return *this; }
|
OpDefBuilder& Doc(string text) { return *this; }
|
||||||
#endif
|
#endif
|
||||||
@ -157,7 +157,7 @@ class OpDefBuilder {
|
|||||||
// Adds control output to this OpDefBuilder (and returns *this).
|
// Adds control output to this OpDefBuilder (and returns *this).
|
||||||
// The <name> must be a valid node name (matches regexp
|
// The <name> must be a valid node name (matches regexp
|
||||||
// [a-zA-Z][a-zA-Z0-9_]*). Named control output can only exist for functions.
|
// [a-zA-Z][a-zA-Z0-9_]*). Named control output can only exist for functions.
|
||||||
OpDefBuilder& ControlOutput(string name);
|
OpDefBuilder& ControlOutput(std::string name);
|
||||||
|
|
||||||
OpDef* op_def() { return &op_reg_data_.op_def; }
|
OpDef* op_def() { return &op_reg_data_.op_def; }
|
||||||
|
|
||||||
@ -166,7 +166,7 @@ class OpDefBuilder {
|
|||||||
std::vector<string> inputs_;
|
std::vector<string> inputs_;
|
||||||
std::vector<string> outputs_;
|
std::vector<string> outputs_;
|
||||||
std::vector<string> control_outputs_;
|
std::vector<string> control_outputs_;
|
||||||
string doc_;
|
std::string doc_;
|
||||||
std::vector<string> errors_;
|
std::vector<string> errors_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -54,7 +54,7 @@ const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def);
|
|||||||
|
|
||||||
// Produce a human-readable version of an op_def that is more concise
|
// Produce a human-readable version of an op_def that is more concise
|
||||||
// than a text-format proto. Excludes descriptions.
|
// than a text-format proto. Excludes descriptions.
|
||||||
string SummarizeOpDef(const OpDef& op_def);
|
std::string SummarizeOpDef(const OpDef& op_def);
|
||||||
|
|
||||||
// Returns an error if new_op is not backwards-compatible with (more
|
// Returns an error if new_op is not backwards-compatible with (more
|
||||||
// accepting than) old_op.
|
// accepting than) old_op.
|
||||||
|
@ -145,14 +145,16 @@ class OpKernel {
|
|||||||
|
|
||||||
// Accessors.
|
// Accessors.
|
||||||
const NodeDef& def() const { return props_->node_def; }
|
const NodeDef& def() const { return props_->node_def; }
|
||||||
const string& name() const { return props_->node_def.name(); }
|
const std::string& name() const { return props_->node_def.name(); }
|
||||||
absl::string_view name_view() const { return name_view_; }
|
absl::string_view name_view() const { return name_view_; }
|
||||||
const string& type_string() const { return props_->node_def.op(); }
|
const std::string& type_string() const { return props_->node_def.op(); }
|
||||||
absl::string_view type_string_view() const { return type_string_view_; }
|
absl::string_view type_string_view() const { return type_string_view_; }
|
||||||
const string& requested_input(int i) const {
|
const std::string& requested_input(int i) const {
|
||||||
return props_->node_def.input(i);
|
return props_->node_def.input(i);
|
||||||
}
|
}
|
||||||
const string& requested_device() const { return props_->node_def.device(); }
|
const std::string& requested_device() const {
|
||||||
|
return props_->node_def.device();
|
||||||
|
}
|
||||||
|
|
||||||
int num_inputs() const { return props_->input_types.size(); }
|
int num_inputs() const { return props_->input_types.size(); }
|
||||||
DataType input_type(int i) const { return props_->input_types[i]; }
|
DataType input_type(int i) const { return props_->input_types[i]; }
|
||||||
@ -177,10 +179,11 @@ class OpKernel {
|
|||||||
// Returns a trace string for current computation, op name/type and input
|
// Returns a trace string for current computation, op name/type and input
|
||||||
// tensor shape/dtype are encoded for profiler cost analysis. Most OpKernel
|
// tensor shape/dtype are encoded for profiler cost analysis. Most OpKernel
|
||||||
// should use the default implementation.
|
// should use the default implementation.
|
||||||
virtual string TraceString(const OpKernelContext& ctx, bool verbose) const;
|
virtual std::string TraceString(const OpKernelContext& ctx,
|
||||||
|
bool verbose) const;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
string ShapeTraceString(const OpKernelContext& ctx) const;
|
std::string ShapeTraceString(const OpKernelContext& ctx) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const std::shared_ptr<const NodeProperties> props_;
|
const std::shared_ptr<const NodeProperties> props_;
|
||||||
@ -652,7 +655,7 @@ class OpKernelContext {
|
|||||||
SessionState* session_state = nullptr;
|
SessionState* session_state = nullptr;
|
||||||
|
|
||||||
// Unique session identifier. Can be empty.
|
// Unique session identifier. Can be empty.
|
||||||
string session_handle;
|
std::string session_handle;
|
||||||
|
|
||||||
// Metadata about the session. Can be nullptr.
|
// Metadata about the session. Can be nullptr.
|
||||||
const SessionMetadata* session_metadata = nullptr;
|
const SessionMetadata* session_metadata = nullptr;
|
||||||
@ -684,7 +687,7 @@ class OpKernelContext {
|
|||||||
StepStatsCollectorInterface* stats_collector = nullptr;
|
StepStatsCollectorInterface* stats_collector = nullptr;
|
||||||
GraphCollector* graph_collector = nullptr;
|
GraphCollector* graph_collector = nullptr;
|
||||||
bool run_all_kernels_inline = false;
|
bool run_all_kernels_inline = false;
|
||||||
const string* executor_type = nullptr;
|
const std::string* executor_type = nullptr;
|
||||||
|
|
||||||
// TensorSliceReaderCache support.
|
// TensorSliceReaderCache support.
|
||||||
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
|
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
|
||||||
@ -826,7 +829,7 @@ class OpKernelContext {
|
|||||||
|
|
||||||
// Returns the registered name for the executor type that is executing the
|
// Returns the registered name for the executor type that is executing the
|
||||||
// current kernel. If empty, the default executor is used.
|
// current kernel. If empty, the default executor is used.
|
||||||
const string& executor_type() const;
|
const std::string& executor_type() const;
|
||||||
|
|
||||||
// Input to output forwarding.
|
// Input to output forwarding.
|
||||||
|
|
||||||
@ -1100,7 +1103,7 @@ class OpKernelContext {
|
|||||||
SessionState* session_state() const { return params_->session_state; }
|
SessionState* session_state() const { return params_->session_state; }
|
||||||
|
|
||||||
// Unique identifier of the session it belongs to. Can be empty.
|
// Unique identifier of the session it belongs to. Can be empty.
|
||||||
string session_handle() const { return params_->session_handle; }
|
std::string session_handle() const { return params_->session_handle; }
|
||||||
|
|
||||||
// Metadata about the session. Can be nullptr.
|
// Metadata about the session. Can be nullptr.
|
||||||
const SessionMetadata* session_metadata() const {
|
const SessionMetadata* session_metadata() const {
|
||||||
@ -1405,7 +1408,7 @@ Status SupportedDeviceTypesForNode(
|
|||||||
|
|
||||||
// Returns a message with a description of the kernels registered for op
|
// Returns a message with a description of the kernels registered for op
|
||||||
// `op_name`.
|
// `op_name`.
|
||||||
string KernelsRegisteredForOp(StringPiece op_name);
|
std::string KernelsRegisteredForOp(StringPiece op_name);
|
||||||
|
|
||||||
// Call once after Op registration has completed.
|
// Call once after Op registration has completed.
|
||||||
Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry);
|
Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry);
|
||||||
@ -1497,13 +1500,13 @@ Status FindKernelDef(
|
|||||||
bool has_experimental_debug_info,
|
bool has_experimental_debug_info,
|
||||||
const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
|
const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
|
||||||
StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
|
StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
|
||||||
const KernelDef** def, string* kernel_class_name);
|
const KernelDef** def, std::string* kernel_class_name);
|
||||||
|
|
||||||
// If node_def has a corresponding kernel registered on device_type,
|
// If node_def has a corresponding kernel registered on device_type,
|
||||||
// returns OK and fill in the kernel def and kernel_class_name. <def> and
|
// returns OK and fill in the kernel def and kernel_class_name. <def> and
|
||||||
// <kernel_class_name> may be null.
|
// <kernel_class_name> may be null.
|
||||||
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
|
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
|
||||||
const KernelDef** def, string* kernel_class_name);
|
const KernelDef** def, std::string* kernel_class_name);
|
||||||
|
|
||||||
// Writes a list of all registered kernels to LOG(INFO), to help users debug
|
// Writes a list of all registered kernels to LOG(INFO), to help users debug
|
||||||
// missing kernel errors.
|
// missing kernel errors.
|
||||||
|
@ -46,8 +46,8 @@ class OpSegment {
|
|||||||
|
|
||||||
// A hold can be placed on a session, preventing all its kernels
|
// A hold can be placed on a session, preventing all its kernels
|
||||||
// from being deleted.
|
// from being deleted.
|
||||||
void AddHold(const string& session_handle);
|
void AddHold(const std::string& session_handle);
|
||||||
void RemoveHold(const string& session_handle);
|
void RemoveHold(const std::string& session_handle);
|
||||||
|
|
||||||
// If the kernel for "node_name" has been created in the
|
// If the kernel for "node_name" has been created in the
|
||||||
// "session_handle", returns the existing op kernel in "*kernel".
|
// "session_handle", returns the existing op kernel in "*kernel".
|
||||||
@ -57,12 +57,13 @@ class OpSegment {
|
|||||||
//
|
//
|
||||||
// OpSegment keeps the ownership of the returned "*kernel".
|
// OpSegment keeps the ownership of the returned "*kernel".
|
||||||
typedef std::function<Status(OpKernel**)> CreateKernelFn;
|
typedef std::function<Status(OpKernel**)> CreateKernelFn;
|
||||||
Status FindOrCreate(const string& session_handle, const string& node_name,
|
Status FindOrCreate(const std::string& session_handle,
|
||||||
OpKernel** kernel, CreateKernelFn create_fn);
|
const std::string& node_name, OpKernel** kernel,
|
||||||
|
CreateKernelFn create_fn);
|
||||||
|
|
||||||
// Returns true if OpSegment should own the kernel.
|
// Returns true if OpSegment should own the kernel.
|
||||||
static bool ShouldOwnKernel(FunctionLibraryRuntime* lib,
|
static bool ShouldOwnKernel(FunctionLibraryRuntime* lib,
|
||||||
const string& node_op);
|
const std::string& node_op);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// op name -> OpKernel
|
// op name -> OpKernel
|
||||||
|
@ -81,7 +81,7 @@ bool IsDim0SliceAligned(const TensorShape& s, int64 start, int64 end_or_size) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Returns <suffix> sanitized to have only [a-zA-Z0-9-_].
|
// Returns <suffix> sanitized to have only [a-zA-Z0-9-_].
|
||||||
string SanitizeThreadSuffix(string suffix);
|
std::string SanitizeThreadSuffix(std::string suffix);
|
||||||
|
|
||||||
// Helper to compute 'strides' given a tensor 'shape'. I.e.,
|
// Helper to compute 'strides' given a tensor 'shape'. I.e.,
|
||||||
// strides[i] = prod(shape.dim_size[(i+1):])
|
// strides[i] = prod(shape.dim_size[(i+1):])
|
||||||
|
@ -74,7 +74,7 @@ class RendezvousInterface {
|
|||||||
friend class Rendezvous;
|
friend class Rendezvous;
|
||||||
friend class SendOp;
|
friend class SendOp;
|
||||||
friend class RecvOp;
|
friend class RecvOp;
|
||||||
string buf_;
|
std::string buf_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// The caller is a tensor producer and it sends a message (a tensor
|
// The caller is a tensor producer and it sends a message (a tensor
|
||||||
@ -169,9 +169,11 @@ class Rendezvous : public RendezvousInterface, public core::RefCounted {
|
|||||||
// Constructs a rendezvous key for the tensor of "name" sent from
|
// Constructs a rendezvous key for the tensor of "name" sent from
|
||||||
// "src_device" to "dst_device". The tensor is generated in the frame
|
// "src_device" to "dst_device". The tensor is generated in the frame
|
||||||
// and iteration specified by "frame_iter".
|
// and iteration specified by "frame_iter".
|
||||||
static string CreateKey(const string& src_device, uint64 src_incarnation,
|
static std::string CreateKey(const std::string& src_device,
|
||||||
const string& dst_device, const string& name,
|
uint64 src_incarnation,
|
||||||
const FrameAndIter& frame_iter);
|
const std::string& dst_device,
|
||||||
|
const std::string& name,
|
||||||
|
const FrameAndIter& frame_iter);
|
||||||
|
|
||||||
static Status ParseKey(StringPiece key, ParsedKey* out);
|
static Status ParseKey(StringPiece key, ParsedKey* out);
|
||||||
};
|
};
|
||||||
|
@ -79,7 +79,7 @@ namespace tensorflow {
|
|||||||
class ResourceBase : public core::RefCounted {
|
class ResourceBase : public core::RefCounted {
|
||||||
public:
|
public:
|
||||||
// Returns a debug string for *this.
|
// Returns a debug string for *this.
|
||||||
virtual string DebugString() const = 0;
|
virtual std::string DebugString() const = 0;
|
||||||
|
|
||||||
// Returns memory used by this resource.
|
// Returns memory used by this resource.
|
||||||
virtual int64 MemoryUsed() const { return 0; }
|
virtual int64 MemoryUsed() const { return 0; }
|
||||||
@ -100,7 +100,7 @@ class ScopedStepContainer {
|
|||||||
|
|
||||||
ScopedStepContainer(const int64 step_id,
|
ScopedStepContainer(const int64 step_id,
|
||||||
std::function<void(const string&)> cleanup,
|
std::function<void(const string&)> cleanup,
|
||||||
const string& prefix)
|
const std::string& prefix)
|
||||||
: container_(strings::StrCat("__", prefix, "_per_step_", step_id)),
|
: container_(strings::StrCat("__", prefix, "_per_step_", step_id)),
|
||||||
cleanup_(cleanup),
|
cleanup_(cleanup),
|
||||||
dirty_(false) {}
|
dirty_(false) {}
|
||||||
@ -125,25 +125,25 @@ class ScopedStepContainer {
|
|||||||
// Pass through to MakeResourceHandle with the container name
|
// Pass through to MakeResourceHandle with the container name
|
||||||
template <typename T>
|
template <typename T>
|
||||||
ResourceHandle MakeResourceHandle(
|
ResourceHandle MakeResourceHandle(
|
||||||
const string& name, const DeviceBase& device) TF_MUST_USE_RESULT;
|
const std::string& name, const DeviceBase& device) TF_MUST_USE_RESULT;
|
||||||
// Pass through to ResourceMgr::Create with the container name
|
// Pass through to ResourceMgr::Create with the container name
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status Create(ResourceMgr* rm, const string& name,
|
Status Create(ResourceMgr* rm, const std::string& name,
|
||||||
T* resource) TF_MUST_USE_RESULT;
|
T* resource) TF_MUST_USE_RESULT;
|
||||||
// Pass through to ResourceMgr::Delete with the container name
|
// Pass through to ResourceMgr::Delete with the container name
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status Delete(ResourceMgr* rm, const string& name) TF_MUST_USE_RESULT;
|
Status Delete(ResourceMgr* rm, const std::string& name) TF_MUST_USE_RESULT;
|
||||||
// Pass through to ResourceMgr::Lookup with the container name
|
// Pass through to ResourceMgr::Lookup with the container name
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status Lookup(ResourceMgr* rm, const string& name,
|
Status Lookup(ResourceMgr* rm, const std::string& name,
|
||||||
T** resource) const TF_MUST_USE_RESULT;
|
T** resource) const TF_MUST_USE_RESULT;
|
||||||
// Pass through to ResourceMgr::LookupOrCreate with the container name
|
// Pass through to ResourceMgr::LookupOrCreate with the container name
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status LookupOrCreate(ResourceMgr* rm, const string& name, T** resource,
|
Status LookupOrCreate(ResourceMgr* rm, const std::string& name, T** resource,
|
||||||
std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
|
std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const string container_;
|
const std::string container_;
|
||||||
const std::function<void(const string&)> cleanup_;
|
const std::function<void(const string&)> cleanup_;
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
mutable std::atomic<bool> dirty_ TF_GUARDED_BY(mu_);
|
mutable std::atomic<bool> dirty_ TF_GUARDED_BY(mu_);
|
||||||
@ -152,11 +152,11 @@ class ScopedStepContainer {
|
|||||||
class ResourceMgr {
|
class ResourceMgr {
|
||||||
public:
|
public:
|
||||||
ResourceMgr();
|
ResourceMgr();
|
||||||
explicit ResourceMgr(const string& default_container);
|
explicit ResourceMgr(const std::string& default_container);
|
||||||
~ResourceMgr();
|
~ResourceMgr();
|
||||||
|
|
||||||
// Returns the default container name for *this.
|
// Returns the default container name for *this.
|
||||||
const string& default_container() const { return default_container_; }
|
const std::string& default_container() const { return default_container_; }
|
||||||
|
|
||||||
// Creates a resource "name" in the "container". The caller transfers
|
// Creates a resource "name" in the "container". The caller transfers
|
||||||
// the ownership of one ref on "resource" to *this, regardless of whether this
|
// the ownership of one ref on "resource" to *this, regardless of whether this
|
||||||
@ -165,7 +165,7 @@ class ResourceMgr {
|
|||||||
// REQUIRES: std::is_base_of<ResourceBase, T>
|
// REQUIRES: std::is_base_of<ResourceBase, T>
|
||||||
// REQUIRES: resource != nullptr.
|
// REQUIRES: resource != nullptr.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status Create(const string& container, const string& name,
|
Status Create(const std::string& container, const std::string& name,
|
||||||
T* resource) TF_MUST_USE_RESULT;
|
T* resource) TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
// If "container" has a resource "name", returns it in "*resource" and
|
// If "container" has a resource "name", returns it in "*resource" and
|
||||||
@ -174,7 +174,7 @@ class ResourceMgr {
|
|||||||
// REQUIRES: std::is_base_of<ResourceBase, T>
|
// REQUIRES: std::is_base_of<ResourceBase, T>
|
||||||
// REQUIRES: resource != nullptr
|
// REQUIRES: resource != nullptr
|
||||||
template <typename T, bool use_dynamic_cast = false>
|
template <typename T, bool use_dynamic_cast = false>
|
||||||
Status Lookup(const string& container, const string& name,
|
Status Lookup(const std::string& container, const std::string& name,
|
||||||
T** resource) const TF_MUST_USE_RESULT;
|
T** resource) const TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
// Similar to Lookup, but looks up multiple resources at once, with only a
|
// Similar to Lookup, but looks up multiple resources at once, with only a
|
||||||
@ -197,7 +197,7 @@ class ResourceMgr {
|
|||||||
// REQUIRES: std::is_base_of<ResourceBase, T>
|
// REQUIRES: std::is_base_of<ResourceBase, T>
|
||||||
// REQUIRES: resource != nullptr
|
// REQUIRES: resource != nullptr
|
||||||
template <typename T, bool use_dynamic_cast = false>
|
template <typename T, bool use_dynamic_cast = false>
|
||||||
Status LookupOrCreate(const string& container, const string& name,
|
Status LookupOrCreate(const std::string& container, const std::string& name,
|
||||||
T** resource,
|
T** resource,
|
||||||
std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
|
std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
@ -205,19 +205,20 @@ class ResourceMgr {
|
|||||||
//
|
//
|
||||||
// REQUIRES: std::is_base_of<ResourceBase, T>
|
// REQUIRES: std::is_base_of<ResourceBase, T>
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status Delete(const string& container, const string& name) TF_MUST_USE_RESULT;
|
Status Delete(const std::string& container,
|
||||||
|
const std::string& name) TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
// Deletes the resource pointed by "handle".
|
// Deletes the resource pointed by "handle".
|
||||||
Status Delete(const ResourceHandle& handle) TF_MUST_USE_RESULT;
|
Status Delete(const ResourceHandle& handle) TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
// Deletes all resources from the "container" and removes the container.
|
// Deletes all resources from the "container" and removes the container.
|
||||||
Status Cleanup(const string& container) TF_MUST_USE_RESULT;
|
Status Cleanup(const std::string& container) TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
// Deletes all resources in all containers.
|
// Deletes all resources in all containers.
|
||||||
void Clear();
|
void Clear();
|
||||||
|
|
||||||
// Returns a text description for all resources.
|
// Returns a text description for all resources.
|
||||||
string DebugString() const;
|
std::string DebugString() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
typedef std::pair<uint64, StringPiece> Key;
|
typedef std::pair<uint64, StringPiece> Key;
|
||||||
@ -236,7 +237,7 @@ class ResourceMgr {
|
|||||||
std::unique_ptr<string> name;
|
std::unique_ptr<string> name;
|
||||||
|
|
||||||
ResourceAndName();
|
ResourceAndName();
|
||||||
ResourceAndName(ResourceBase* resource, string name);
|
ResourceAndName(ResourceBase* resource, std::string name);
|
||||||
ResourceAndName(ResourceAndName&& other) noexcept;
|
ResourceAndName(ResourceAndName&& other) noexcept;
|
||||||
~ResourceAndName();
|
~ResourceAndName();
|
||||||
|
|
||||||
@ -247,31 +248,31 @@ class ResourceMgr {
|
|||||||
};
|
};
|
||||||
typedef std::unordered_map<Key, ResourceAndName, KeyHash, KeyEqual> Container;
|
typedef std::unordered_map<Key, ResourceAndName, KeyHash, KeyEqual> Container;
|
||||||
|
|
||||||
const string default_container_;
|
const std::string default_container_;
|
||||||
mutable mutex mu_;
|
mutable mutex mu_;
|
||||||
std::unordered_map<string, Container*> containers_ TF_GUARDED_BY(mu_);
|
std::unordered_map<string, Container*> containers_ TF_GUARDED_BY(mu_);
|
||||||
|
|
||||||
template <typename T, bool use_dynamic_cast = false>
|
template <typename T, bool use_dynamic_cast = false>
|
||||||
Status LookupInternal(const string& container, const string& name,
|
Status LookupInternal(const std::string& container, const std::string& name,
|
||||||
T** resource) const
|
T** resource) const
|
||||||
TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
|
TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
Status DoCreate(const string& container, TypeIndex type, const string& name,
|
Status DoCreate(const std::string& container, TypeIndex type,
|
||||||
ResourceBase* resource)
|
const std::string& name, ResourceBase* resource)
|
||||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
|
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
Status DoLookup(const string& container, TypeIndex type, const string& name,
|
Status DoLookup(const std::string& container, TypeIndex type,
|
||||||
ResourceBase** resource) const
|
const std::string& name, ResourceBase** resource) const
|
||||||
TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
|
TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
Status DoDelete(const string& container, uint64 type_hash_code,
|
Status DoDelete(const std::string& container, uint64 type_hash_code,
|
||||||
const string& resource_name,
|
const std::string& resource_name,
|
||||||
const string& type_name) TF_MUST_USE_RESULT;
|
const std::string& type_name) TF_MUST_USE_RESULT;
|
||||||
Status DoDelete(const string& container, TypeIndex type,
|
Status DoDelete(const std::string& container, TypeIndex type,
|
||||||
const string& resource_name) TF_MUST_USE_RESULT;
|
const std::string& resource_name) TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
// Inserts the type name for 'hash_code' into the hash_code to type name map.
|
// Inserts the type name for 'hash_code' into the hash_code to type name map.
|
||||||
Status InsertDebugTypeName(uint64 hash_code, const string& type_name)
|
Status InsertDebugTypeName(uint64 hash_code, const std::string& type_name)
|
||||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
|
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
// Returns the type name for the 'hash_code'.
|
// Returns the type name for the 'hash_code'.
|
||||||
@ -289,14 +290,14 @@ class ResourceMgr {
|
|||||||
// Makes a resource handle with the specified type for a given container /
|
// Makes a resource handle with the specified type for a given container /
|
||||||
// name.
|
// name.
|
||||||
ResourceHandle MakeResourceHandle(
|
ResourceHandle MakeResourceHandle(
|
||||||
const string& container, const string& name, const DeviceBase& device,
|
const std::string& container, const std::string& name,
|
||||||
const TypeIndex& type_index,
|
const DeviceBase& device, const TypeIndex& type_index,
|
||||||
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {})
|
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {})
|
||||||
TF_MUST_USE_RESULT;
|
TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
ResourceHandle MakeResourceHandle(
|
ResourceHandle MakeResourceHandle(
|
||||||
OpKernelContext* ctx, const string& container, const string& name,
|
OpKernelContext* ctx, const std::string& container, const std::string& name,
|
||||||
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}) {
|
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}) {
|
||||||
return MakeResourceHandle(
|
return MakeResourceHandle(
|
||||||
container.empty() ? ctx->resource_manager()->default_container()
|
container.empty() ? ctx->resource_manager()->default_container()
|
||||||
@ -306,7 +307,8 @@ ResourceHandle MakeResourceHandle(
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
ResourceHandle MakeResourceHandle(
|
ResourceHandle MakeResourceHandle(
|
||||||
OpKernelConstruction* ctx, const string& container, const string& name,
|
OpKernelConstruction* ctx, const std::string& container,
|
||||||
|
const std::string& name,
|
||||||
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}) {
|
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}) {
|
||||||
return MakeResourceHandle(
|
return MakeResourceHandle(
|
||||||
container.empty() ? ctx->resource_manager()->default_container()
|
container.empty() ? ctx->resource_manager()->default_container()
|
||||||
@ -315,7 +317,8 @@ ResourceHandle MakeResourceHandle(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
|
Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
|
||||||
const string& container, const string& name,
|
const std::string& container,
|
||||||
|
const std::string& name,
|
||||||
const TypeIndex& type_index);
|
const TypeIndex& type_index);
|
||||||
|
|
||||||
// Returns a resource handle from a numbered op input.
|
// Returns a resource handle from a numbered op input.
|
||||||
@ -409,19 +412,19 @@ class ContainerInfo {
|
|||||||
// name is name(). If resource_is_private_to_kernel() is true, the
|
// name is name(). If resource_is_private_to_kernel() is true, the
|
||||||
// kernel should delete the resource when the kernel is deleted.
|
// kernel should delete the resource when the kernel is deleted.
|
||||||
ResourceMgr* resource_manager() const { return rmgr_; }
|
ResourceMgr* resource_manager() const { return rmgr_; }
|
||||||
const string& container() const { return container_; }
|
const std::string& container() const { return container_; }
|
||||||
const string& name() const { return name_; }
|
const std::string& name() const { return name_; }
|
||||||
bool resource_is_private_to_kernel() const {
|
bool resource_is_private_to_kernel() const {
|
||||||
return resource_is_private_to_kernel_;
|
return resource_is_private_to_kernel_;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns a readable string for *this.
|
// Returns a readable string for *this.
|
||||||
string DebugString() const;
|
std::string DebugString() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ResourceMgr* rmgr_ = nullptr;
|
ResourceMgr* rmgr_ = nullptr;
|
||||||
string container_;
|
std::string container_;
|
||||||
string name_;
|
std::string name_;
|
||||||
bool resource_is_private_to_kernel_ = false;
|
bool resource_is_private_to_kernel_ = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -435,8 +438,8 @@ class ContainerInfo {
|
|||||||
// Returns OK if the resource is found and transfers one ref of
|
// Returns OK if the resource is found and transfers one ref of
|
||||||
// *resource to the caller. Otherwise, returns an error.
|
// *resource to the caller. Otherwise, returns an error.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
|
Status GetResourceFromContext(OpKernelContext* ctx,
|
||||||
T** resource);
|
const std::string& input_name, T** resource);
|
||||||
|
|
||||||
// Utility op kernel to check if a handle to resource type T is initialized.
|
// Utility op kernel to check if a handle to resource type T is initialized.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -470,8 +473,8 @@ class ResourceHandleOp : public OpKernel {
|
|||||||
bool IsExpensive() override { return false; }
|
bool IsExpensive() override { return false; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
string container_;
|
std::string container_;
|
||||||
string name_;
|
std::string name_;
|
||||||
mutex mutex_;
|
mutex mutex_;
|
||||||
Tensor resource_;
|
Tensor resource_;
|
||||||
std::atomic<bool> initialized_{false};
|
std::atomic<bool> initialized_{false};
|
||||||
@ -584,8 +587,8 @@ void CheckDeriveFromResourceBase() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status ResourceMgr::Create(const string& container, const string& name,
|
Status ResourceMgr::Create(const std::string& container,
|
||||||
T* resource) {
|
const std::string& name, T* resource) {
|
||||||
CheckDeriveFromResourceBase<T>();
|
CheckDeriveFromResourceBase<T>();
|
||||||
CHECK(resource != nullptr);
|
CHECK(resource != nullptr);
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
@ -593,8 +596,8 @@ Status ResourceMgr::Create(const string& container, const string& name,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, bool use_dynamic_cast>
|
template <typename T, bool use_dynamic_cast>
|
||||||
Status ResourceMgr::Lookup(const string& container, const string& name,
|
Status ResourceMgr::Lookup(const std::string& container,
|
||||||
T** resource) const {
|
const std::string& name, T** resource) const {
|
||||||
CheckDeriveFromResourceBase<T>();
|
CheckDeriveFromResourceBase<T>();
|
||||||
tf_shared_lock l(mu_);
|
tf_shared_lock l(mu_);
|
||||||
return LookupInternal<T, use_dynamic_cast>(container, name, resource);
|
return LookupInternal<T, use_dynamic_cast>(container, name, resource);
|
||||||
@ -632,7 +635,8 @@ struct TypeCastFunctor<T, true> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, bool use_dynamic_cast>
|
template <typename T, bool use_dynamic_cast>
|
||||||
Status ResourceMgr::LookupInternal(const string& container, const string& name,
|
Status ResourceMgr::LookupInternal(const std::string& container,
|
||||||
|
const std::string& name,
|
||||||
T** resource) const {
|
T** resource) const {
|
||||||
ResourceBase* found = nullptr;
|
ResourceBase* found = nullptr;
|
||||||
Status s = DoLookup(container, TypeIndex::Make<T>(), name, &found);
|
Status s = DoLookup(container, TypeIndex::Make<T>(), name, &found);
|
||||||
@ -645,8 +649,8 @@ Status ResourceMgr::LookupInternal(const string& container, const string& name,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, bool use_dynamic_cast>
|
template <typename T, bool use_dynamic_cast>
|
||||||
Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
|
Status ResourceMgr::LookupOrCreate(const std::string& container,
|
||||||
T** resource,
|
const std::string& name, T** resource,
|
||||||
std::function<Status(T**)> creator) {
|
std::function<Status(T**)> creator) {
|
||||||
CheckDeriveFromResourceBase<T>();
|
CheckDeriveFromResourceBase<T>();
|
||||||
*resource = nullptr;
|
*resource = nullptr;
|
||||||
@ -669,14 +673,15 @@ Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status ResourceMgr::Delete(const string& container, const string& name) {
|
Status ResourceMgr::Delete(const std::string& container,
|
||||||
|
const std::string& name) {
|
||||||
CheckDeriveFromResourceBase<T>();
|
CheckDeriveFromResourceBase<T>();
|
||||||
return DoDelete(container, TypeIndex::Make<T>(), name);
|
return DoDelete(container, TypeIndex::Make<T>(), name);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
|
Status GetResourceFromContext(OpKernelContext* ctx,
|
||||||
T** resource) {
|
const std::string& input_name, T** resource) {
|
||||||
DataType dtype;
|
DataType dtype;
|
||||||
TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &dtype));
|
TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &dtype));
|
||||||
if (dtype == DT_RESOURCE) {
|
if (dtype == DT_RESOURCE) {
|
||||||
@ -684,8 +689,8 @@ Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
|
|||||||
TF_RETURN_IF_ERROR(ctx->input(input_name, &handle));
|
TF_RETURN_IF_ERROR(ctx->input(input_name, &handle));
|
||||||
return LookupResource(ctx, handle->scalar<ResourceHandle>()(), resource);
|
return LookupResource(ctx, handle->scalar<ResourceHandle>()(), resource);
|
||||||
}
|
}
|
||||||
string container;
|
std::string container;
|
||||||
string shared_name;
|
std::string shared_name;
|
||||||
{
|
{
|
||||||
mutex* mu;
|
mutex* mu;
|
||||||
TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
|
TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
|
||||||
@ -879,7 +884,7 @@ void ResourceHandlesOp<T>::Compute(OpKernelContext* ctx) {
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
ResourceHandle ScopedStepContainer::MakeResourceHandle(
|
ResourceHandle ScopedStepContainer::MakeResourceHandle(
|
||||||
const string& name, const DeviceBase& device) {
|
const std::string& name, const DeviceBase& device) {
|
||||||
mutex_lock ml(mu_);
|
mutex_lock ml(mu_);
|
||||||
dirty_ = true;
|
dirty_ = true;
|
||||||
return tensorflow::MakeResourceHandle(container_, name, device,
|
return tensorflow::MakeResourceHandle(container_, name, device,
|
||||||
@ -887,13 +892,14 @@ ResourceHandle ScopedStepContainer::MakeResourceHandle(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status ScopedStepContainer::Lookup(ResourceMgr* rm, const string& name,
|
Status ScopedStepContainer::Lookup(ResourceMgr* rm, const std::string& name,
|
||||||
T** resource) const {
|
T** resource) const {
|
||||||
return rm->Lookup<T>(container_, name, resource);
|
return rm->Lookup<T>(container_, name, resource);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status ScopedStepContainer::LookupOrCreate(ResourceMgr* rm, const string& name,
|
Status ScopedStepContainer::LookupOrCreate(ResourceMgr* rm,
|
||||||
|
const std::string& name,
|
||||||
T** resource,
|
T** resource,
|
||||||
std::function<Status(T**)> creator) {
|
std::function<Status(T**)> creator) {
|
||||||
mutex_lock ml(mu_);
|
mutex_lock ml(mu_);
|
||||||
@ -902,7 +908,7 @@ Status ScopedStepContainer::LookupOrCreate(ResourceMgr* rm, const string& name,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status ScopedStepContainer::Create(ResourceMgr* rm, const string& name,
|
Status ScopedStepContainer::Create(ResourceMgr* rm, const std::string& name,
|
||||||
T* resource) {
|
T* resource) {
|
||||||
mutex_lock ml(mu_);
|
mutex_lock ml(mu_);
|
||||||
dirty_ = true;
|
dirty_ = true;
|
||||||
@ -910,7 +916,7 @@ Status ScopedStepContainer::Create(ResourceMgr* rm, const string& name,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status ScopedStepContainer::Delete(ResourceMgr* rm, const string& name) {
|
Status ScopedStepContainer::Delete(ResourceMgr* rm, const std::string& name) {
|
||||||
return rm->Delete<T>(container_, name);
|
return rm->Delete<T>(container_, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,7 +67,7 @@ class Var : public ResourceBase {
|
|||||||
mutex* mu() { return &mu_; }
|
mutex* mu() { return &mu_; }
|
||||||
Tensor* tensor() { return &tensor_; }
|
Tensor* tensor() { return &tensor_; }
|
||||||
|
|
||||||
string DebugString() const override {
|
std::string DebugString() const override {
|
||||||
return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
|
return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
|
||||||
tensor_.shape().DebugString());
|
tensor_.shape().DebugString());
|
||||||
}
|
}
|
||||||
|
@ -31,13 +31,13 @@ namespace tensorflow {
|
|||||||
class SessionState {
|
class SessionState {
|
||||||
public:
|
public:
|
||||||
// Get a tensor from the session state.
|
// Get a tensor from the session state.
|
||||||
Status GetTensor(const string& handle, Tensor* tensor);
|
Status GetTensor(const std::string& handle, Tensor* tensor);
|
||||||
|
|
||||||
// Store a tensor in the session state.
|
// Store a tensor in the session state.
|
||||||
Status AddTensor(const string& handle, const Tensor& tensor);
|
Status AddTensor(const std::string& handle, const Tensor& tensor);
|
||||||
|
|
||||||
// Delete a tensor from the session state.
|
// Delete a tensor from the session state.
|
||||||
Status DeleteTensor(const string& handle);
|
Status DeleteTensor(const std::string& handle);
|
||||||
|
|
||||||
int64 GetNewId();
|
int64 GetNewId();
|
||||||
|
|
||||||
@ -60,15 +60,15 @@ class TensorStore {
|
|||||||
struct TensorAndKey {
|
struct TensorAndKey {
|
||||||
Tensor tensor;
|
Tensor tensor;
|
||||||
int64 id;
|
int64 id;
|
||||||
string device_name;
|
std::string device_name;
|
||||||
|
|
||||||
string GetHandle(const string& tensor_name) {
|
std::string GetHandle(const std::string& tensor_name) {
|
||||||
return strings::StrCat(tensor_name, ";", id, ";", device_name);
|
return strings::StrCat(tensor_name, ";", id, ";", device_name);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Add the named tensor to the tensor store for this run.
|
// Add the named tensor to the tensor store for this run.
|
||||||
Status AddTensor(const string& name, const TensorAndKey& tk);
|
Status AddTensor(const std::string& name, const TensorAndKey& tk);
|
||||||
|
|
||||||
// Save the tensors in the tensor store of this run to the session.
|
// Save the tensors in the tensor store of this run to the session.
|
||||||
Status SaveTensors(const std::vector<string>& output_names,
|
Status SaveTensors(const std::vector<string>& output_names,
|
||||||
|
@ -344,13 +344,13 @@ class InferenceContext {
|
|||||||
// incomplete shape.
|
// incomplete shape.
|
||||||
DimensionHandle NumElements(ShapeHandle s);
|
DimensionHandle NumElements(ShapeHandle s);
|
||||||
|
|
||||||
string DebugString(ShapeHandle s);
|
std::string DebugString(ShapeHandle s);
|
||||||
string DebugString(DimensionHandle d);
|
std::string DebugString(DimensionHandle d);
|
||||||
string DebugString(const ShapeAndType& shape_and_type);
|
std::string DebugString(const ShapeAndType& shape_and_type);
|
||||||
string DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types);
|
std::string DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types);
|
||||||
|
|
||||||
// Describes the whole context, for debugging purposes.
|
// Describes the whole context, for debugging purposes.
|
||||||
string DebugString() const;
|
std::string DebugString() const;
|
||||||
|
|
||||||
// If <shape> has rank <rank>, or its rank is unknown, return OK and return
|
// If <shape> has rank <rank>, or its rank is unknown, return OK and return
|
||||||
// the shape with asserted rank in <*out>. Otherwise return an error.
|
// the shape with asserted rank in <*out>. Otherwise return an error.
|
||||||
|
@ -54,7 +54,7 @@ struct AllocRecord {
|
|||||||
class TrackingAllocator : public Allocator {
|
class TrackingAllocator : public Allocator {
|
||||||
public:
|
public:
|
||||||
explicit TrackingAllocator(Allocator* allocator, bool track_ids);
|
explicit TrackingAllocator(Allocator* allocator, bool track_ids);
|
||||||
string Name() override { return allocator_->Name(); }
|
std::string Name() override { return allocator_->Name(); }
|
||||||
void* AllocateRaw(size_t alignment, size_t num_bytes) override {
|
void* AllocateRaw(size_t alignment, size_t num_bytes) override {
|
||||||
return AllocateRaw(alignment, num_bytes, AllocationAttributes());
|
return AllocateRaw(alignment, num_bytes, AllocationAttributes());
|
||||||
}
|
}
|
||||||
|
@ -32,10 +32,10 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
string TypeNameVariant(const T& value);
|
std::string TypeNameVariant(const T& value);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
string DebugStringVariant(const T& value);
|
std::string DebugStringVariant(const T& value);
|
||||||
|
|
||||||
// Allows for specializations of Variant Decoding. `data` may be modified in
|
// Allows for specializations of Variant Decoding. `data` may be modified in
|
||||||
// the process of decoding to `value`.
|
// the process of decoding to `value`.
|
||||||
@ -43,13 +43,13 @@ template <typename T>
|
|||||||
bool DecodeVariant(VariantTensorData* data, T* value);
|
bool DecodeVariant(VariantTensorData* data, T* value);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool DecodeVariant(string* buf, T* value);
|
bool DecodeVariant(std::string* buf, T* value);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void EncodeVariant(const T& value, VariantTensorData* data);
|
void EncodeVariant(const T& value, VariantTensorData* data);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void EncodeVariant(const T& value, string* buf);
|
void EncodeVariant(const T& value, std::string* buf);
|
||||||
|
|
||||||
// This is an implementation of a type-erased container that can store an
|
// This is an implementation of a type-erased container that can store an
|
||||||
// object of any type. The implementation is very similar to std::any, but has
|
// object of any type. The implementation is very similar to std::any, but has
|
||||||
@ -234,7 +234,7 @@ class Variant {
|
|||||||
return GetValue()->TypeId();
|
return GetValue()->TypeId();
|
||||||
}
|
}
|
||||||
|
|
||||||
string DebugString() const {
|
std::string DebugString() const {
|
||||||
return strings::StrCat(
|
return strings::StrCat(
|
||||||
"Variant<type: ", TypeName(),
|
"Variant<type: ", TypeName(),
|
||||||
" value: ", is_empty() ? "[empty]" : GetValue()->DebugString(), ">");
|
" value: ", is_empty() ? "[empty]" : GetValue()->DebugString(), ">");
|
||||||
@ -264,7 +264,7 @@ class Variant {
|
|||||||
// In the special case that a serialized Variant is stored (value
|
// In the special case that a serialized Variant is stored (value
|
||||||
// is a VariantTensorDataProto), returns value.TypeName(), the
|
// is a VariantTensorDataProto), returns value.TypeName(), the
|
||||||
// TypeName field stored in the VariantTensorDataProto buffer.
|
// TypeName field stored in the VariantTensorDataProto buffer.
|
||||||
string TypeName() const {
|
std::string TypeName() const {
|
||||||
if (is_empty()) {
|
if (is_empty()) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
@ -282,12 +282,12 @@ class Variant {
|
|||||||
bool Decode(VariantTensorData data);
|
bool Decode(VariantTensorData data);
|
||||||
|
|
||||||
// Helper methods to directly serialize/deserialize from strings.
|
// Helper methods to directly serialize/deserialize from strings.
|
||||||
void Encode(string* buf) const {
|
void Encode(std::string* buf) const {
|
||||||
if (!is_empty()) {
|
if (!is_empty()) {
|
||||||
GetValue()->Encode(buf);
|
GetValue()->Encode(buf);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool Decode(string buf) {
|
bool Decode(std::string buf) {
|
||||||
if (!is_empty()) {
|
if (!is_empty()) {
|
||||||
return GetValue()->Decode(std::move(buf));
|
return GetValue()->Decode(std::move(buf));
|
||||||
}
|
}
|
||||||
@ -313,12 +313,12 @@ class Variant {
|
|||||||
virtual void CloneInto(ValueInterface* memory) const = 0;
|
virtual void CloneInto(ValueInterface* memory) const = 0;
|
||||||
virtual void MoveAssign(ValueInterface* memory) = 0;
|
virtual void MoveAssign(ValueInterface* memory) = 0;
|
||||||
virtual void MoveInto(ValueInterface* memory) = 0;
|
virtual void MoveInto(ValueInterface* memory) = 0;
|
||||||
virtual string TypeName() const = 0;
|
virtual std::string TypeName() const = 0;
|
||||||
virtual string DebugString() const = 0;
|
virtual std::string DebugString() const = 0;
|
||||||
virtual void Encode(VariantTensorData* data) const = 0;
|
virtual void Encode(VariantTensorData* data) const = 0;
|
||||||
virtual bool Decode(VariantTensorData data) = 0;
|
virtual bool Decode(VariantTensorData data) = 0;
|
||||||
virtual void Encode(string* buf) const = 0;
|
virtual void Encode(std::string* buf) const = 0;
|
||||||
virtual bool Decode(string data) = 0;
|
virtual bool Decode(std::string data) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -359,9 +359,9 @@ class Variant {
|
|||||||
new (memory) Value(InPlace(), std::move(value));
|
new (memory) Value(InPlace(), std::move(value));
|
||||||
}
|
}
|
||||||
|
|
||||||
string TypeName() const final { return TypeNameVariant(value); }
|
std::string TypeName() const final { return TypeNameVariant(value); }
|
||||||
|
|
||||||
string DebugString() const final { return DebugStringVariant(value); }
|
std::string DebugString() const final { return DebugStringVariant(value); }
|
||||||
|
|
||||||
void Encode(VariantTensorData* data) const final {
|
void Encode(VariantTensorData* data) const final {
|
||||||
EncodeVariant(value, data);
|
EncodeVariant(value, data);
|
||||||
@ -371,9 +371,9 @@ class Variant {
|
|||||||
return DecodeVariant(&data, &value);
|
return DecodeVariant(&data, &value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Encode(string* buf) const final { EncodeVariant(value, buf); }
|
void Encode(std::string* buf) const final { EncodeVariant(value, buf); }
|
||||||
|
|
||||||
bool Decode(string buf) final { return DecodeVariant(&buf, &value); }
|
bool Decode(std::string buf) final { return DecodeVariant(&buf, &value); }
|
||||||
|
|
||||||
T value;
|
T value;
|
||||||
};
|
};
|
||||||
|
@ -105,7 +105,7 @@ bool DecodeVariantImpl(VariantTensorData data,
|
|||||||
TypeResolver<T, false /* is_pod */, false /* Tensor */,
|
TypeResolver<T, false /* is_pod */, false /* Tensor */,
|
||||||
true /* protobuf */>,
|
true /* protobuf */>,
|
||||||
T* value) {
|
T* value) {
|
||||||
string metadata;
|
std::string metadata;
|
||||||
data.get_metadata(&metadata);
|
data.get_metadata(&metadata);
|
||||||
return value->ParseFromString(std::move(metadata));
|
return value->ParseFromString(std::move(metadata));
|
||||||
}
|
}
|
||||||
@ -136,27 +136,27 @@ template <typename T, bool = has_type_name<typename std::decay<T>::type>::value,
|
|||||||
struct TypeNameResolver {};
|
struct TypeNameResolver {};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
string TypeNameVariantImpl(const T& value,
|
std::string TypeNameVariantImpl(const T& value,
|
||||||
TypeNameResolver<T, true /* has_type_name */>) {
|
TypeNameResolver<T, true /* has_type_name */>) {
|
||||||
return value.TypeName();
|
return value.TypeName();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
string TypeNameVariantImpl(
|
std::string TypeNameVariantImpl(
|
||||||
const T& value,
|
const T& value,
|
||||||
TypeNameResolver<T, false /* has_type_name */, true /* Tensor */>) {
|
TypeNameResolver<T, false /* has_type_name */, true /* Tensor */>) {
|
||||||
return "tensorflow::Tensor";
|
return "tensorflow::Tensor";
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
string TypeNameVariantImpl(
|
std::string TypeNameVariantImpl(
|
||||||
const T& value, TypeNameResolver<T, false /* has_type_name */,
|
const T& value, TypeNameResolver<T, false /* has_type_name */,
|
||||||
false /* Tensor */, true /* protobuf */>) {
|
false /* Tensor */, true /* protobuf */>) {
|
||||||
return value.GetTypeName();
|
return value.GetTypeName();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
string TypeNameVariantImpl(
|
std::string TypeNameVariantImpl(
|
||||||
const T& value,
|
const T& value,
|
||||||
TypeNameResolver<T, false /* has_type_name */, false /* Tensor */,
|
TypeNameResolver<T, false /* has_type_name */, false /* Tensor */,
|
||||||
false /* protobuf */>) {
|
false /* protobuf */>) {
|
||||||
@ -164,7 +164,7 @@ string TypeNameVariantImpl(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
string TypeNameVariant(const T& value) {
|
std::string TypeNameVariant(const T& value) {
|
||||||
return TypeNameVariantImpl(value, TypeNameResolver<T>());
|
return TypeNameVariantImpl(value, TypeNameResolver<T>());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -194,27 +194,27 @@ struct DebugStringResolver {};
|
|||||||
// TODO(ebrevdo): Expand DebugStringResolver to return TypeString if
|
// TODO(ebrevdo): Expand DebugStringResolver to return TypeString if
|
||||||
// there is no StrCat<T>() constructor.
|
// there is no StrCat<T>() constructor.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
string DebugStringVariantImpl(
|
std::string DebugStringVariantImpl(
|
||||||
const T& value, DebugStringResolver<T, true /* has_debug_string */>) {
|
const T& value, DebugStringResolver<T, true /* has_debug_string */>) {
|
||||||
return value.DebugString();
|
return value.DebugString();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
string DebugStringVariantImpl(
|
std::string DebugStringVariantImpl(
|
||||||
const T& value, DebugStringResolver<T, false /* has_debug_string */,
|
const T& value, DebugStringResolver<T, false /* has_debug_string */,
|
||||||
true /* can_strcat */>) {
|
true /* can_strcat */>) {
|
||||||
return strings::StrCat(value);
|
return strings::StrCat(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
string DebugStringVariantImpl(
|
std::string DebugStringVariantImpl(
|
||||||
const T& value, DebugStringResolver<T, false /* has_debug_string */,
|
const T& value, DebugStringResolver<T, false /* has_debug_string */,
|
||||||
false /* can_strcat */>) {
|
false /* can_strcat */>) {
|
||||||
return "?";
|
return "?";
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
string DebugStringVariant(const T& value) {
|
std::string DebugStringVariant(const T& value) {
|
||||||
return DebugStringVariantImpl(value, DebugStringResolver<T>());
|
return DebugStringVariantImpl(value, DebugStringResolver<T>());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -230,7 +230,7 @@ bool DecodeVariant(VariantTensorData* data, T* value) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void EncodeVariant(const T& value, string* buf) {
|
void EncodeVariant(const T& value, std::string* buf) {
|
||||||
VariantTensorData data;
|
VariantTensorData data;
|
||||||
EncodeVariantImpl(value, TypeResolver<T>(), &data);
|
EncodeVariantImpl(value, TypeResolver<T>(), &data);
|
||||||
data.set_type_name(TypeNameVariant(value));
|
data.set_type_name(TypeNameVariant(value));
|
||||||
@ -239,7 +239,7 @@ void EncodeVariant(const T& value, string* buf) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool DecodeVariant(string* buf, T* value) {
|
bool DecodeVariant(std::string* buf, T* value) {
|
||||||
VariantTensorData data;
|
VariantTensorData data;
|
||||||
if (!data.ParseFromString(*buf)) return false;
|
if (!data.ParseFromString(*buf)) return false;
|
||||||
if (!DecodeVariantImpl(std::move(data), TypeResolver<T>(), value)) {
|
if (!DecodeVariantImpl(std::move(data), TypeResolver<T>(), value)) {
|
||||||
@ -250,7 +250,7 @@ bool DecodeVariant(string* buf, T* value) {
|
|||||||
|
|
||||||
// Specializations for VariantTensorDataProto
|
// Specializations for VariantTensorDataProto
|
||||||
template <>
|
template <>
|
||||||
string TypeNameVariant(const VariantTensorDataProto& value);
|
std::string TypeNameVariant(const VariantTensorDataProto& value);
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
void EncodeVariant(const VariantTensorDataProto& value,
|
void EncodeVariant(const VariantTensorDataProto& value,
|
||||||
@ -260,10 +260,10 @@ template <>
|
|||||||
bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value);
|
bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value);
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
void EncodeVariant(const VariantTensorDataProto& value, string* buf);
|
void EncodeVariant(const VariantTensorDataProto& value, std::string* buf);
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
bool DecodeVariant(string* buf, VariantTensorDataProto* value);
|
bool DecodeVariant(std::string* buf, VariantTensorDataProto* value);
|
||||||
|
|
||||||
// Encodes an array of Variant objects in to the given StringListEncoder.
|
// Encodes an array of Variant objects in to the given StringListEncoder.
|
||||||
// `variant_array` is assumed to point to an array of `n` Variant objects.
|
// `variant_array` is assumed to point to an array of `n` Variant objects.
|
||||||
|
@ -93,7 +93,7 @@ class UnaryVariantOpRegistry {
|
|||||||
AsyncVariantDeviceCopyFn;
|
AsyncVariantDeviceCopyFn;
|
||||||
|
|
||||||
// Add a decode function to the registry.
|
// Add a decode function to the registry.
|
||||||
void RegisterDecodeFn(const string& type_name,
|
void RegisterDecodeFn(const std::string& type_name,
|
||||||
const VariantDecodeFn& decode_fn);
|
const VariantDecodeFn& decode_fn);
|
||||||
|
|
||||||
// Returns nullptr if no decode function was found for the given TypeName.
|
// Returns nullptr if no decode function was found for the given TypeName.
|
||||||
@ -124,7 +124,7 @@ class UnaryVariantOpRegistry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add a unary op function to the registry.
|
// Add a unary op function to the registry.
|
||||||
void RegisterUnaryOpFn(VariantUnaryOp op, const string& device,
|
void RegisterUnaryOpFn(VariantUnaryOp op, const std::string& device,
|
||||||
const TypeIndex& type_index,
|
const TypeIndex& type_index,
|
||||||
const VariantUnaryOpFn& unary_op_fn) {
|
const VariantUnaryOpFn& unary_op_fn) {
|
||||||
VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index);
|
VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index);
|
||||||
@ -146,7 +146,7 @@ class UnaryVariantOpRegistry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add a binary op function to the registry.
|
// Add a binary op function to the registry.
|
||||||
void RegisterBinaryOpFn(VariantBinaryOp op, const string& device,
|
void RegisterBinaryOpFn(VariantBinaryOp op, const std::string& device,
|
||||||
const TypeIndex& type_index,
|
const TypeIndex& type_index,
|
||||||
const VariantBinaryOpFn& add_fn) {
|
const VariantBinaryOpFn& add_fn) {
|
||||||
VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index);
|
VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index);
|
||||||
@ -252,7 +252,7 @@ class UnaryVariantOpRegistry {
|
|||||||
// Find or insert a string into a persistent string storage
|
// Find or insert a string into a persistent string storage
|
||||||
// container; return the StringPiece pointing to the permanent string
|
// container; return the StringPiece pointing to the permanent string
|
||||||
// location.
|
// location.
|
||||||
static StringPiece GetPersistentStringPiece(const string& str) {
|
static StringPiece GetPersistentStringPiece(const std::string& str) {
|
||||||
const auto string_storage = PersistentStringStorage();
|
const auto string_storage = PersistentStringStorage();
|
||||||
auto found = string_storage->find(str);
|
auto found = string_storage->find(str);
|
||||||
if (found == string_storage->end()) {
|
if (found == string_storage->end()) {
|
||||||
@ -307,7 +307,7 @@ Status VariantDeviceCopy(
|
|||||||
template <typename Device>
|
template <typename Device>
|
||||||
Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
|
Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
|
||||||
Variant* v_out) {
|
Variant* v_out) {
|
||||||
const string& device = DeviceName<Device>::value;
|
const std::string& device = DeviceName<Device>::value;
|
||||||
UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
|
UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
|
||||||
UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId());
|
UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId());
|
||||||
if (unary_op_fn == nullptr) {
|
if (unary_op_fn == nullptr) {
|
||||||
@ -336,7 +336,7 @@ Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
|
|||||||
"type ids. Type names: '",
|
"type ids. Type names: '",
|
||||||
a.TypeName(), "' vs. '", b.TypeName(), "'");
|
a.TypeName(), "' vs. '", b.TypeName(), "'");
|
||||||
}
|
}
|
||||||
const string& device = DeviceName<Device>::value;
|
const std::string& device = DeviceName<Device>::value;
|
||||||
UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
|
UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
|
||||||
UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId());
|
UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId());
|
||||||
if (binary_op_fn == nullptr) {
|
if (binary_op_fn == nullptr) {
|
||||||
@ -354,7 +354,7 @@ namespace variant_op_registry_fn_registration {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
class UnaryVariantDecodeRegistration {
|
class UnaryVariantDecodeRegistration {
|
||||||
public:
|
public:
|
||||||
UnaryVariantDecodeRegistration(const string& type_name) {
|
UnaryVariantDecodeRegistration(const std::string& type_name) {
|
||||||
// The Variant is passed by pointer because it should be
|
// The Variant is passed by pointer because it should be
|
||||||
// mutable: get below may Decode the variant, which
|
// mutable: get below may Decode the variant, which
|
||||||
// is a self-mutating behavior. The variant is not modified in
|
// is a self-mutating behavior. The variant is not modified in
|
||||||
@ -386,7 +386,8 @@ class UnaryVariantDeviceCopyRegistration {
|
|||||||
UnaryVariantDeviceCopyRegistration(
|
UnaryVariantDeviceCopyRegistration(
|
||||||
const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
|
const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
|
||||||
const LocalVariantDeviceCopyFn& device_copy_fn) {
|
const LocalVariantDeviceCopyFn& device_copy_fn) {
|
||||||
const string type_index_name = port::MaybeAbiDemangle(type_index.name());
|
const std::string type_index_name =
|
||||||
|
port::MaybeAbiDemangle(type_index.name());
|
||||||
UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn(
|
UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn(
|
||||||
direction, type_index,
|
direction, type_index,
|
||||||
[type_index_name, device_copy_fn](
|
[type_index_name, device_copy_fn](
|
||||||
@ -413,10 +414,11 @@ class UnaryVariantUnaryOpRegistration {
|
|||||||
LocalVariantUnaryOpFn;
|
LocalVariantUnaryOpFn;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device,
|
UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const std::string& device,
|
||||||
const TypeIndex& type_index,
|
const TypeIndex& type_index,
|
||||||
const LocalVariantUnaryOpFn& unary_op_fn) {
|
const LocalVariantUnaryOpFn& unary_op_fn) {
|
||||||
const string type_index_name = port::MaybeAbiDemangle(type_index.name());
|
const std::string type_index_name =
|
||||||
|
port::MaybeAbiDemangle(type_index.name());
|
||||||
UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(
|
UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(
|
||||||
op, device, type_index,
|
op, device, type_index,
|
||||||
[type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
|
[type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
|
||||||
@ -442,10 +444,12 @@ class UnaryVariantBinaryOpRegistration {
|
|||||||
LocalVariantBinaryOpFn;
|
LocalVariantBinaryOpFn;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device,
|
UnaryVariantBinaryOpRegistration(VariantBinaryOp op,
|
||||||
|
const std::string& device,
|
||||||
const TypeIndex& type_index,
|
const TypeIndex& type_index,
|
||||||
const LocalVariantBinaryOpFn& binary_op_fn) {
|
const LocalVariantBinaryOpFn& binary_op_fn) {
|
||||||
const string type_index_name = port::MaybeAbiDemangle(type_index.name());
|
const std::string type_index_name =
|
||||||
|
port::MaybeAbiDemangle(type_index.name());
|
||||||
UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(
|
UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(
|
||||||
op, device, type_index,
|
op, device, type_index,
|
||||||
[type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
|
[type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
|
||||||
|
@ -44,8 +44,8 @@ class VariantTensorData {
|
|||||||
VariantTensorData(VariantTensorDataProto proto);
|
VariantTensorData(VariantTensorDataProto proto);
|
||||||
|
|
||||||
// Name of the type of objects being serialized.
|
// Name of the type of objects being serialized.
|
||||||
const string& type_name() const { return type_name_; }
|
const std::string& type_name() const { return type_name_; }
|
||||||
void set_type_name(const string& type_name) { type_name_ = type_name; }
|
void set_type_name(const std::string& type_name) { type_name_ = type_name; }
|
||||||
|
|
||||||
template <typename T, bool = std::is_pod<typename std::decay<T>::type>::value>
|
template <typename T, bool = std::is_pod<typename std::decay<T>::type>::value>
|
||||||
struct PODResolver {};
|
struct PODResolver {};
|
||||||
@ -62,9 +62,9 @@ class VariantTensorData {
|
|||||||
return GetMetadata<T>(value, PODResolver<T>());
|
return GetMetadata<T>(value, PODResolver<T>());
|
||||||
}
|
}
|
||||||
|
|
||||||
string& metadata_string() { return metadata_; }
|
std::string& metadata_string() { return metadata_; }
|
||||||
|
|
||||||
const string& metadata_string() const { return metadata_; }
|
const std::string& metadata_string() const { return metadata_; }
|
||||||
|
|
||||||
// Tensors contained within objects being serialized.
|
// Tensors contained within objects being serialized.
|
||||||
int tensors_size() const;
|
int tensors_size() const;
|
||||||
@ -84,25 +84,27 @@ class VariantTensorData {
|
|||||||
bool FromConstProto(const VariantTensorDataProto& proto);
|
bool FromConstProto(const VariantTensorDataProto& proto);
|
||||||
|
|
||||||
// Serialization via VariantTensorDataProto
|
// Serialization via VariantTensorDataProto
|
||||||
string SerializeAsString() const;
|
std::string SerializeAsString() const;
|
||||||
bool SerializeToString(string* buf);
|
bool SerializeToString(std::string* buf);
|
||||||
bool ParseFromString(string s);
|
bool ParseFromString(std::string s);
|
||||||
|
|
||||||
string DebugString() const;
|
std::string DebugString() const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
string type_name_;
|
std::string type_name_;
|
||||||
string metadata_;
|
std::string metadata_;
|
||||||
std::vector<Tensor> tensors_;
|
std::vector<Tensor> tensors_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void SetMetadata(const string& value, PODResolver<T, false /* is_pod */>) {
|
void SetMetadata(const std::string& value,
|
||||||
|
PODResolver<T, false /* is_pod */>) {
|
||||||
metadata_ = value;
|
metadata_ = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool GetMetadata(string* value, PODResolver<T, false /* is_pod */>) const {
|
bool GetMetadata(std::string* value,
|
||||||
|
PODResolver<T, false /* is_pod */>) const {
|
||||||
*value = metadata_;
|
*value = metadata_;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -121,7 +123,7 @@ class VariantTensorData {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// For backwards compatibility for when this was a proto
|
// For backwards compatibility for when this was a proto
|
||||||
string ProtoDebugString(const VariantTensorData& object);
|
std::string ProtoDebugString(const VariantTensorData& object);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user