Create an empty test for python_op_gen.
PiperOrigin-RevId: 317159004 Change-Id: Iabc4810f9c51c257d62dc3e7bd20d96131939d5d
This commit is contained in:
parent
629e6077d0
commit
ebe063eb74
@ -3,7 +3,7 @@
|
||||
# ":platform" - Low-level and platform-specific Python code.
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_strict_library")
|
||||
load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "if_xla_available", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py")
|
||||
load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "if_xla_available", "py_test", "py_tests", "tf_cc_shared_object", "tf_cc_test", "tf_cuda_library", "tf_gen_op_wrapper_py")
|
||||
|
||||
# buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "tf_monitoring_python_deps")
|
||||
@ -1236,6 +1236,19 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "python_op_gen_test",
|
||||
srcs = ["framework/python_op_gen_test.cc"],
|
||||
deps = [
|
||||
":python_op_gen",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:op_gen_lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "framework_for_generated_wrappers",
|
||||
srcs_version = "PY2AND3",
|
||||
|
@ -981,9 +981,9 @@ void GenEagerPythonOp::AddRawOpExport(const string& parameters) {
|
||||
function_name_, "))\n");
|
||||
}
|
||||
|
||||
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||
const std::vector<string>& hidden_ops,
|
||||
const string& source_file_name = "") {
|
||||
string GetPythonOpsImpl(const OpList& ops, const ApiDefMap& api_defs,
|
||||
const std::vector<string>& hidden_ops,
|
||||
const string& source_file_name = "") {
|
||||
string result;
|
||||
// Header
|
||||
// TODO(josh11b): Mention the library for which wrappers are being generated.
|
||||
@ -1069,11 +1069,17 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
} // namespace
|
||||
|
||||
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||
const std::vector<string>& hidden_ops,
|
||||
const string& source_file_name) {
|
||||
return GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name);
|
||||
}
|
||||
|
||||
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||
const std::vector<string>& hidden_ops,
|
||||
const string& source_file_name) {
|
||||
printf("%s",
|
||||
GetPythonOps(ops, api_defs, hidden_ops, source_file_name).c_str());
|
||||
GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name).c_str());
|
||||
}
|
||||
|
||||
string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
|
||||
@ -1081,7 +1087,7 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
|
||||
ops.ParseFromArray(op_list_buf, op_list_len);
|
||||
|
||||
ApiDefMap api_def_map(ops);
|
||||
return GetPythonOps(ops, api_def_map, {});
|
||||
return GetPythonOpsImpl(ops, api_def_map, {});
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -23,8 +23,20 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Returns a string containing the generated Python code for the given Ops.
|
||||
// ops is a protobuff, typically generated using OpRegistry::Global()->Export.
|
||||
// api_defs is typically constructed directly from ops.
|
||||
// hidden_ops should be a list of Op names that should get a leading _
|
||||
// in the output. Prints the output to stdout.
|
||||
// in the output.
|
||||
// source_file_name is optional and contains the name of the original C++ source
|
||||
// file where the ops' REGISTER_OP() calls reside.
|
||||
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||
const std::vector<string>& hidden_ops,
|
||||
const string& source_file_name);
|
||||
|
||||
// Prints the output of GetPrintOps to stdout.
|
||||
// hidden_ops should be a list of Op names that should get a leading _
|
||||
// in the output.
|
||||
// Optional fourth argument is the name of the original C++ source file
|
||||
// where the ops' REGISTER_OP() calls reside.
|
||||
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||
|
42
tensorflow/python/framework/python_op_gen_test.cc
Normal file
42
tensorflow/python/framework/python_op_gen_test.cc
Normal file
@ -0,0 +1,42 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/python/framework/python_op_gen.h"
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(PythonOpGen, Basic) {
|
||||
OpList ops;
|
||||
OpRegistry::Global()->Export(false, &ops);
|
||||
|
||||
ApiDefMap api_def_map(ops);
|
||||
|
||||
string code = GetPythonOps(ops, api_def_map, {}, "");
|
||||
|
||||
EXPECT_TRUE(absl::StrContains(code, "def case"));
|
||||
|
||||
// TODO(mdan): Add tests to verify type annotations are correctly added.
|
||||
}
|
||||
|
||||
// TODO(mdan): Include more tests with synhtetic ops and api defs.
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user