Create an empty test for python_op_gen.

PiperOrigin-RevId: 317159004
Change-Id: Iabc4810f9c51c257d62dc3e7bd20d96131939d5d
This commit is contained in:
Dan Moldovan 2020-06-18 12:38:59 -07:00 committed by TensorFlower Gardener
parent 629e6077d0
commit ebe063eb74
4 changed files with 80 additions and 7 deletions

View File

@ -3,7 +3,7 @@
# ":platform" - Low-level and platform-specific Python code.
load("//tensorflow:tensorflow.bzl", "py_strict_library")
load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "if_xla_available", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py")
load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "if_xla_available", "py_test", "py_tests", "tf_cc_shared_object", "tf_cc_test", "tf_cuda_library", "tf_gen_op_wrapper_py")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_monitoring_python_deps")
@ -1236,6 +1236,19 @@ cc_library(
alwayslink = 1,
)
tf_cc_test(
name = "python_op_gen_test",
srcs = ["framework/python_op_gen_test.cc"],
deps = [
":python_op_gen",
"//tensorflow/core:framework",
"//tensorflow/core:op_gen_lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
py_library(
name = "framework_for_generated_wrappers",
srcs_version = "PY2AND3",

View File

@ -981,9 +981,9 @@ void GenEagerPythonOp::AddRawOpExport(const string& parameters) {
function_name_, "))\n");
}
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
const std::vector<string>& hidden_ops,
const string& source_file_name = "") {
string GetPythonOpsImpl(const OpList& ops, const ApiDefMap& api_defs,
const std::vector<string>& hidden_ops,
const string& source_file_name = "") {
string result;
// Header
// TODO(josh11b): Mention the library for which wrappers are being generated.
@ -1069,11 +1069,17 @@ from tensorflow.python.util.tf_export import tf_export
} // namespace
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
const std::vector<string>& hidden_ops,
const string& source_file_name) {
return GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name);
}
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
const std::vector<string>& hidden_ops,
const string& source_file_name) {
printf("%s",
GetPythonOps(ops, api_defs, hidden_ops, source_file_name).c_str());
GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name).c_str());
}
string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
@ -1081,7 +1087,7 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
ops.ParseFromArray(op_list_buf, op_list_len);
ApiDefMap api_def_map(ops);
return GetPythonOps(ops, api_def_map, {});
return GetPythonOpsImpl(ops, api_def_map, {});
}
} // namespace tensorflow

View File

@ -23,8 +23,20 @@ limitations under the License.
namespace tensorflow {
// Returns a string containing the generated Python code for the given Ops.
// ops is a protobuff, typically generated using OpRegistry::Global()->Export.
// api_defs is typically constructed directly from ops.
// hidden_ops should be a list of Op names that should get a leading _
// in the output. Prints the output to stdout.
// in the output.
// source_file_name is optional and contains the name of the original C++ source
// file where the ops' REGISTER_OP() calls reside.
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
const std::vector<string>& hidden_ops,
const string& source_file_name);
// Prints the output of GetPrintOps to stdout.
// hidden_ops should be a list of Op names that should get a leading _
// in the output.
// Optional fourth argument is the name of the original C++ source file
// where the ops' REGISTER_OP() calls reside.
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,

View File

@ -0,0 +1,42 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/framework/python_op_gen.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
TEST(PythonOpGen, Basic) {
OpList ops;
OpRegistry::Global()->Export(false, &ops);
ApiDefMap api_def_map(ops);
string code = GetPythonOps(ops, api_def_map, {}, "");
EXPECT_TRUE(absl::StrContains(code, "def case"));
// TODO(mdan): Add tests to verify type annotations are correctly added.
}
// TODO(mdan): Include more tests with synhtetic ops and api defs.
} // namespace
} // namespace tensorflow