Use string_view type for FormatFromString and FilterFormatFromString arguments
PiperOrigin-RevId: 325391279 Change-Id: If834446e5eac71840b92e6efcd6eac5170579769
This commit is contained in:
parent
f10eb870f1
commit
5e0ed38eb7
tensorflow
@ -339,7 +339,8 @@ void BatchToSpaceOp::getCanonicalizationPatterns(
|
||||
// are not unknown.
|
||||
//
|
||||
static LogicalResult Verify(BiasAddOp op) {
|
||||
std::string data_format = op.data_format().str();
|
||||
absl::string_view data_format(op.data_format().data(),
|
||||
op.data_format().size());
|
||||
tensorflow::TensorFormat format;
|
||||
bool is_valid = FormatFromString(data_format, &format);
|
||||
DCHECK(is_valid) << data_format;
|
||||
@ -385,7 +386,8 @@ static LogicalResult Verify(BiasAddOp op) {
|
||||
// * the out_backprop operands have valid ranks or are unranked.
|
||||
//
|
||||
static LogicalResult Verify(BiasAddGradOp op) {
|
||||
std::string data_format = op.data_format().str();
|
||||
absl::string_view data_format(op.data_format().data(),
|
||||
op.data_format().size());
|
||||
tensorflow::TensorFormat format;
|
||||
bool is_valid = FormatFromString(data_format, &format);
|
||||
DCHECK(is_valid) << data_format;
|
||||
@ -995,7 +997,8 @@ static LogicalResult Verify(OpT op) {
|
||||
|
||||
int64_t input_channels = -1;
|
||||
if (auto ty = op.input().getType().template dyn_cast<RankedTensorType>()) {
|
||||
std::string data_format = op.data_format().str();
|
||||
absl::string_view data_format(op.data_format().data(),
|
||||
op.data_format().size());
|
||||
tensorflow::TensorFormat format;
|
||||
auto is_valid = FormatFromString(data_format, &format);
|
||||
DCHECK(is_valid) << data_format;
|
||||
|
@ -519,6 +519,7 @@ cc_library(
|
||||
"//tensorflow/core/lib/gtl:array_slice",
|
||||
"//tensorflow/core/lib/gtl:inlined_vector",
|
||||
"//tensorflow/core/platform:types",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -73,7 +73,7 @@ string ToString(FilterTensorFormat format) {
|
||||
}
|
||||
}
|
||||
|
||||
bool FormatFromString(const string& format_str, TensorFormat* format) {
|
||||
bool FormatFromString(absl::string_view format_str, TensorFormat* format) {
|
||||
if (format_str == "NHWC" || format_str == "NDHWC") {
|
||||
*format = FORMAT_NHWC;
|
||||
return true;
|
||||
@ -101,7 +101,7 @@ bool FormatFromString(const string& format_str, TensorFormat* format) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool FilterFormatFromString(const string& format_str,
|
||||
bool FilterFormatFromString(absl::string_view format_str,
|
||||
FilterTensorFormat* format) {
|
||||
if (format_str == "HWIO" || format_str == "DHWIO") {
|
||||
*format = FORMAT_HWIO;
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
@ -97,11 +98,11 @@ enum FilterTensorFormat {
|
||||
|
||||
// Parse tensor format from the given string.
|
||||
// Return true if the parsing succeeds, and false if it fails.
|
||||
bool FormatFromString(const std::string& format_str, TensorFormat* format);
|
||||
bool FormatFromString(absl::string_view format_str, TensorFormat* format);
|
||||
|
||||
// Parse tensor format from the given string.
|
||||
// Return true if the parsing succeeds, and false if it fails.
|
||||
bool FilterFormatFromString(const std::string& format_str,
|
||||
bool FilterFormatFromString(absl::string_view format_str,
|
||||
FilterTensorFormat* format);
|
||||
|
||||
// Convert a tensor format into string.
|
||||
|
Loading…
Reference in New Issue
Block a user