Merge pull request #15960 from frankchn/branch_181239691
Branch 181239691
commit c15d457842
@@ -529,6 +529,7 @@ filegroup(
        "//tensorflow/contrib/periodic_resample:all_files",
        "//tensorflow/contrib/predictor:all_files",
        "//tensorflow/contrib/py2tf:all_files",
        "//tensorflow/contrib/py2tf/convert:all_files",
        "//tensorflow/contrib/py2tf/pyct:all_files",
        "//tensorflow/contrib/py2tf/pyct/static_analysis:all_files",
        "//tensorflow/contrib/quantize:all_files",
@@ -481,6 +481,7 @@ cc_library(
        "//tensorflow/core:cuda_libdevice_path",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:regexp_internal",
        "//tensorflow/core:stream_executor_no_cuda",
        "@llvm//:core",
        "@llvm//:support",
@@ -18,6 +18,7 @@ limitations under the License.
#include <stdlib.h>
#include <atomic>
#include <functional>
#include <mutex>  // NOLINT(build/c++11): only using std::call_once, not mutex.
#include <utility>

#include "llvm/IR/DiagnosticInfo.h"
@@ -77,9 +78,11 @@ limitations under the License.
#include "tensorflow/core/platform/cuda_libdevice_path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/subprocess.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"

namespace se = ::perftools::gputools;

@@ -241,6 +244,93 @@ tensorflow::Status PrepareHloModuleForIrEmitting(
  return pipeline.Run(hlo_module).status();
}

// Prints a warning if the ptxas at ptxas_path has known bugs.
//
// Only prints a warning the first time it's called for a particular value of
// ptxas_path.
void WarnIfBadPtxasVersion(const string& ptxas_path) {
  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
  static std::unordered_set<string>* seen_ptxas_paths GUARDED_BY(mu) =
      new std::unordered_set<string>();

  tensorflow::mutex_lock lock(mu);
  if (!seen_ptxas_paths->insert(ptxas_path).second) {
    // Already checked this ptx binary, nothing to do.
    return;
  }

  tensorflow::SubProcess ptxas;
  ptxas.SetProgram(ptxas_path, {ptxas_path, "--version"});
  ptxas.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE);
  if (!ptxas.Start()) {
    LOG(WARNING) << "Couldn't invoke " << ptxas_path << " --version";
    return;
  }

  string out;
  int exit_code = ptxas.Communicate(/*stdin_input=*/nullptr, &out,
                                    /*stderr_output=*/nullptr);
  if (exit_code != 0) {
    LOG(WARNING) << "Running " << ptxas_path << " --version returned "
                 << exit_code;
    return;
  }

  int64 vmaj, vmin, vdot;
  string vmaj_str, vmin_str, vdot_str;
  using tensorflow::strings::safe_strto64;
  if (!RE2::PartialMatch(out, R"(\bV(\d+)\.(\d+)\.(\d+)\b)", &vmaj_str,
                         &vmin_str, &vdot_str) ||
      !safe_strto64(vmaj_str, &vmaj) || !safe_strto64(vmin_str, &vmin) ||
      !safe_strto64(vdot_str, &vdot)) {
    LOG(WARNING) << "Couldn't parse ptxas version in output of " << ptxas_path
                 << " --version:\n"
                 << out;
    return;
  }

  // ptxas 9.0 before 9.0.276 miscompiles some address calculations with large
  // offsets (e.g. "load ptr + large_constant"), b/70245379.
  if (vmaj == 9 && vmin == 0 && vdot < 276) {
    LOG(WARNING) << "*** WARNING *** You are using ptxas " << vmaj << "."
                 << vmin << "." << vdot
                 << ", which is in range [9.0.0, 9.0.276). These versions are "
                    "known to miscompile XLA code, leading to incorrect "
                    "results or invalid-address errors.";
  }
}

// Prints a warning if the ptx->sass JIT in the driver has known bugs.
//
// Using such a driver is only a problem if we fail to use ptxas to compile our
// ptx and have to use the driver instead, so you should only call this
// function if we're going to use the driver JIT.
//
// Only prints a warning the first time it's called.
void WarnIfBadDriverJITVersion() {
  static std::once_flag run_once;
  std::call_once(run_once, [] {
    auto version_or_status = se::cuda::Diagnostician::FindKernelDriverVersion();
    if (!version_or_status.ok()) {
      LOG(WARNING) << "Couldn't read CUDA driver version.";
      return;
    }
    se::cuda::DriverVersion version = version_or_status.ValueOrDie();

    // The driver JIT in 384 before 384.108 miscompiles some address
    // calculations with large offsets (e.g. "load ptr + large_constant"),
    // b/70245379.
    if (std::get<0>(version) == 384 && std::get<1>(version) < 108) {
      LOG(WARNING)
          << "*** WARNING *** Invoking the PTX->SASS JIT from driver version "
          << se::cuda::DriverVersionToString(version)
          << ", which is in range [384.0.0, 384.108.0). These versions are "
             "known to miscompile XLA code, leading to incorrect results or "
             "invalid-address errors.";
    }
  });
}

// Compiles the given PTX string using ptxas and returns the resulting machine
// code (i.e. a cubin) as a byte array.
StatusOr<std::vector<uint8>> CompilePtx(const string& ptx, int cc_major,
@@ -252,6 +342,8 @@ StatusOr<std::vector<uint8>> CompilePtx(const string& ptx, int cc_major,
  auto env = tensorflow::Env::Default();
  TF_RETURN_IF_ERROR(env->FileExists(ptxas_path));

  WarnIfBadPtxasVersion(ptxas_path);

  // Write ptx into a temporary file.
  string ptx_path;
  if (!env->LocalTempFilename(&ptx_path)) {
@@ -555,6 +647,10 @@ std::vector<uint8> GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx,
               "GPU driver compile the ptx. "
            << maybe_cubin.status();
      }

      // We're going to use the driver to JIT our PTX->SASS, so warn if
      // the JIT in the driver has known bugs.
      WarnIfBadDriverJITVersion();
    }
  }
  cache_value->compilation_done = true;
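For reference, here is a minimal standalone sketch of the version check that WarnIfBadPtxasVersion performs above. It is an illustration only, not code from this change: it uses std::regex instead of RE2, plain long instead of int64/safe_strto64, and a made-up sample of `ptxas --version` output.

#include <cstdio>
#include <regex>
#include <string>

int main() {
  // Hypothetical sample line from `ptxas --version`.
  std::string out = "Cuda compilation tools, release 9.0, V9.0.176";
  std::smatch m;
  std::regex re(R"(\bV(\d+)\.(\d+)\.(\d+)\b)");
  if (std::regex_search(out, m, re)) {
    long vmaj = std::stol(m[1].str());
    long vmin = std::stol(m[2].str());
    long vdot = std::stol(m[3].str());
    // Same predicate as the real check: 9.0.x with x < 276 is known bad.
    if (vmaj == 9 && vmin == 0 && vdot < 276) {
      std::printf("ptxas %ld.%ld.%ld is in [9.0.0, 9.0.276): known to miscompile XLA code\n",
                  vmaj, vmin, vdot);
    }
  }
  return 0;
}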
@ -23,7 +23,7 @@ namespace {
|
||||
|
||||
class ConditionalOpTest : public ClientLibraryTestBase {
|
||||
protected:
|
||||
Computation CreateR0F32ConstantComputation(float value) {
|
||||
Computation CreateR0ConstantComputation(float value) {
|
||||
ComputationBuilder builder(client_, "Constant");
|
||||
builder.Parameter(0, empty_tuple_, "tuple");
|
||||
builder.ConstantR0<float>(value);
|
||||
@ -32,7 +32,7 @@ class ConditionalOpTest : public ClientLibraryTestBase {
|
||||
return build_status.ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
Computation CreateR0F32IdentityComputation() {
|
||||
Computation CreateR0IdentityComputation() {
|
||||
ComputationBuilder builder(client_, "Identity");
|
||||
builder.Parameter(0, r0f32_, "x");
|
||||
auto build_status = builder.Build();
|
||||
@ -40,25 +40,85 @@ class ConditionalOpTest : public ClientLibraryTestBase {
|
||||
return build_status.ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
Computation CreateR0F32CeilComputation() {
|
||||
Computation CreateCeilComputation(const Shape& shape) {
|
||||
ComputationBuilder builder(client_, "Ceil");
|
||||
auto param = builder.Parameter(0, r0f32_, "param");
|
||||
auto param = builder.Parameter(0, shape, "param");
|
||||
builder.Ceil(param);
|
||||
auto build_status = builder.Build();
|
||||
EXPECT_IS_OK(build_status.status());
|
||||
return build_status.ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
Computation CreateR0F32FloorComputation() {
|
||||
ComputationBuilder builder(client_, "Ceil");
|
||||
auto param = builder.Parameter(0, r0f32_, "param");
|
||||
Computation CreateR0CeilComputation() {
|
||||
return CreateCeilComputation(r0f32_);
|
||||
}
|
||||
|
||||
Computation CreateR1CeilComputation() {
|
||||
return CreateCeilComputation(r1s2f32_);
|
||||
}
|
||||
|
||||
Computation CreateFloorComputation(const Shape& shape) {
|
||||
ComputationBuilder builder(client_, "Floor");
|
||||
auto param = builder.Parameter(0, shape, "param");
|
||||
builder.Floor(param);
|
||||
auto build_status = builder.Build();
|
||||
EXPECT_IS_OK(build_status.status());
|
||||
return build_status.ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
Computation CreateAddTupleComputation(const string& computation_name,
|
||||
Computation CreateR0FloorComputation() {
|
||||
return CreateFloorComputation(r0f32_);
|
||||
}
|
||||
|
||||
Computation CreateR1FloorComputation() {
|
||||
return CreateFloorComputation(r1s2f32_);
|
||||
}
|
||||
|
||||
Computation CreateTupleCeilComputation(const string& computation_name,
|
||||
const Shape& tuple_shape) {
|
||||
ComputationBuilder builder(client_, computation_name);
|
||||
auto tuple = builder.Parameter(0, tuple_shape, "tuple");
|
||||
auto x = builder.GetTupleElement(tuple, 0);
|
||||
auto y = builder.GetTupleElement(tuple, 1);
|
||||
auto x_ceil = builder.Ceil(x);
|
||||
auto y_ceil = builder.Ceil(y);
|
||||
builder.Tuple({x_ceil, y_ceil});
|
||||
auto build_status = builder.Build();
|
||||
EXPECT_IS_OK(build_status.status());
|
||||
return build_status.ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
Computation CreateR0TupleCeilComputation() {
|
||||
return CreateTupleCeilComputation("CeilR0", tuple_2_r0f32_);
|
||||
}
|
||||
|
||||
Computation CreateR1TupleCeilComputation() {
|
||||
return CreateTupleCeilComputation("CeilR1", tuple_2_r1s2f32_);
|
||||
}
|
||||
|
||||
Computation CreateTupleFloorComputation(const string& computation_name,
|
||||
const Shape& tuple_shape) {
|
||||
ComputationBuilder builder(client_, computation_name);
|
||||
auto tuple = builder.Parameter(0, tuple_shape, "tuple");
|
||||
auto x = builder.GetTupleElement(tuple, 0);
|
||||
auto y = builder.GetTupleElement(tuple, 1);
|
||||
auto x_floor = builder.Floor(x);
|
||||
auto y_floor = builder.Floor(y);
|
||||
builder.Tuple({x_floor, y_floor});
|
||||
auto build_status = builder.Build();
|
||||
EXPECT_IS_OK(build_status.status());
|
||||
return build_status.ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
Computation CreateR0TupleFloorComputation() {
|
||||
return CreateTupleFloorComputation("FloorR0", tuple_2_r0f32_);
|
||||
}
|
||||
|
||||
Computation CreateR1TupleFloorComputation() {
|
||||
return CreateTupleFloorComputation("FloorR1", tuple_2_r1s2f32_);
|
||||
}
|
||||
|
||||
Computation CreateTupleAddComputation(const string& computation_name,
|
||||
const Shape& tuple_shape) {
|
||||
ComputationBuilder builder(client_, computation_name);
|
||||
auto tuple = builder.Parameter(0, tuple_shape, "tuple");
|
||||
@ -70,15 +130,15 @@ class ConditionalOpTest : public ClientLibraryTestBase {
|
||||
return build_status.ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
Computation CreateAddR0Computation() {
|
||||
return CreateAddTupleComputation("AddR0", tuple_2_r0f32_);
|
||||
Computation CreateR0TupleAddComputation() {
|
||||
return CreateTupleAddComputation("AddR0", tuple_2_r0f32_);
|
||||
}
|
||||
|
||||
Computation CreateAddR1Computation() {
|
||||
return CreateAddTupleComputation("AddR1", tuple_2_r1s2f32_);
|
||||
Computation CreateR1TupleAddComputation() {
|
||||
return CreateTupleAddComputation("AddR1", tuple_2_r1s2f32_);
|
||||
}
|
||||
|
||||
Computation CreateSubTupleComputation(const string& computation_name,
|
||||
Computation CreateTupleSubComputation(const string& computation_name,
|
||||
const Shape& tuple_shape) {
|
||||
ComputationBuilder builder(client_, computation_name);
|
||||
auto tuple = builder.Parameter(0, tuple_shape, "tuple");
|
||||
@ -90,15 +150,16 @@ class ConditionalOpTest : public ClientLibraryTestBase {
|
||||
return build_status.ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
Computation CreateSubR0Computation() {
|
||||
return CreateSubTupleComputation("SubR0", tuple_2_r0f32_);
|
||||
Computation CreateR0TupleSubComputation() {
|
||||
return CreateTupleSubComputation("SubR0", tuple_2_r0f32_);
|
||||
}
|
||||
|
||||
Computation CreateSubR1Computation() {
|
||||
return CreateSubTupleComputation("SubR1", tuple_2_r1s2f32_);
|
||||
Computation CreateR1TupleSubComputation() {
|
||||
return CreateTupleSubComputation("SubR1", tuple_2_r1s2f32_);
|
||||
}
|
||||
|
||||
Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
|
||||
Shape r1s2f32_ = ShapeUtil::MakeShape(F32, {2});
|
||||
Shape tuple_2_r0f32_ = ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})});
|
||||
Shape tuple_2_r1s2f32_ = ShapeUtil::MakeTupleShape(
|
||||
@ -112,8 +173,8 @@ XLA_TEST_F(ConditionalOpTest, Parameters0) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto pred = builder.ConstantR0<bool>(true);
|
||||
auto operands = builder.Tuple({});
|
||||
auto true_computation = CreateR0F32ConstantComputation(56.0f);
|
||||
auto false_computation = CreateR0F32ConstantComputation(12.0f);
|
||||
auto true_computation = CreateR0ConstantComputation(56.0f);
|
||||
auto false_computation = CreateR0ConstantComputation(12.0f);
|
||||
auto result = builder.Conditional(pred, operands, true_computation, operands,
|
||||
false_computation);
|
||||
|
||||
@ -126,7 +187,7 @@ XLA_TEST_F(ConditionalOpTest, Parameters1) {
|
||||
auto pred = builder.ConstantR0<bool>(false);
|
||||
auto operand1 = builder.ConstantR0<float>(56.0f);
|
||||
auto operand2 = builder.ConstantR0<float>(12.0f);
|
||||
auto identity = CreateR0F32IdentityComputation();
|
||||
auto identity = CreateR0IdentityComputation();
|
||||
auto result =
|
||||
builder.Conditional(pred, operand1, identity, operand2, identity);
|
||||
|
||||
@ -140,9 +201,8 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) {
|
||||
auto pred = builder.ConstantR0<bool>(false);
|
||||
auto operand1 = builder.ConstantR0<float>(56.4f);
|
||||
auto operand2 = builder.ConstantR0<float>(12.6f);
|
||||
auto result =
|
||||
builder.Conditional(pred, operand1, CreateR0F32CeilComputation(),
|
||||
operand2, CreateR0F32FloorComputation());
|
||||
auto result = builder.Conditional(pred, operand1, CreateR0CeilComputation(),
|
||||
operand2, CreateR0FloorComputation());
|
||||
|
||||
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
|
||||
}
|
||||
@ -153,8 +213,8 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto pred = builder.ConstantR0<bool>(false);
|
||||
auto operand = builder.ConstantR0<float>(12.6f);
|
||||
auto result = builder.Conditional(pred, operand, CreateR0F32CeilComputation(),
|
||||
operand, CreateR0F32FloorComputation());
|
||||
auto result = builder.Conditional(pred, operand, CreateR0CeilComputation(),
|
||||
operand, CreateR0FloorComputation());
|
||||
|
||||
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
|
||||
}
|
||||
@ -166,7 +226,7 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) {
|
||||
auto pred = builder.ConstantR0<bool>(false);
|
||||
auto operand1 = builder.ConstantR0<float>(56.4f);
|
||||
auto operand2 = builder.ConstantR0<float>(12.6f);
|
||||
auto floor = CreateR0F32FloorComputation();
|
||||
auto floor = CreateR0FloorComputation();
|
||||
auto result = builder.Conditional(pred, operand1, floor, operand2, floor);
|
||||
|
||||
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
|
||||
@ -178,7 +238,7 @@ XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto pred = builder.ConstantR0<bool>(false);
|
||||
auto operand = builder.ConstantR0<float>(12.6f);
|
||||
auto floor = CreateR0F32FloorComputation();
|
||||
auto floor = CreateR0FloorComputation();
|
||||
auto result = builder.Conditional(pred, operand, floor, operand, floor);
|
||||
|
||||
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
|
||||
@ -191,9 +251,8 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) {
|
||||
auto pred = builder.ConstantR0<bool>(false);
|
||||
auto operand1 = builder.ConstantR0<float>(56.4f);
|
||||
auto operand2 = builder.ConstantR0<float>(12.6f);
|
||||
auto result =
|
||||
builder.Conditional(pred, operand1, CreateR0F32FloorComputation(),
|
||||
operand2, CreateR0F32FloorComputation());
|
||||
auto result = builder.Conditional(pred, operand1, CreateR0FloorComputation(),
|
||||
operand2, CreateR0FloorComputation());
|
||||
|
||||
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
|
||||
}
|
||||
@ -205,9 +264,8 @@ XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) {
|
||||
auto pred_cond = inner_builder.Parameter(0, r0bool, "param0");
|
||||
auto true_operand = inner_builder.Parameter(1, r0f32_, "param1");
|
||||
auto false_operand = inner_builder.Parameter(2, r0f32_, "param2");
|
||||
inner_builder.Conditional(pred_cond, true_operand,
|
||||
CreateR0F32CeilComputation(), false_operand,
|
||||
CreateR0F32FloorComputation());
|
||||
inner_builder.Conditional(pred_cond, true_operand, CreateR0CeilComputation(),
|
||||
false_operand, CreateR0FloorComputation());
|
||||
auto inner_builder_result = inner_builder.Build();
|
||||
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
@ -228,8 +286,9 @@ XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) {
|
||||
auto operand1 = builder.ConstantR0<float>(56.0f);
|
||||
auto operand2 = builder.ConstantR0<float>(12.0f);
|
||||
auto operands = builder.Tuple({operand1, operand2});
|
||||
auto result = builder.Conditional(pred, operands, CreateAddR0Computation(),
|
||||
operands, CreateSubR0Computation());
|
||||
auto result =
|
||||
builder.Conditional(pred, operands, CreateR0TupleAddComputation(),
|
||||
operands, CreateR0TupleSubComputation());
|
||||
|
||||
ComputeAndCompareR0<float>(&builder, 68.0f, {}, error_spec_);
|
||||
}
|
||||
@ -242,8 +301,9 @@ XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) {
|
||||
auto operand1 = builder.ConstantR0<float>(56.0f);
|
||||
auto operand2 = builder.ConstantR0<float>(12.0f);
|
||||
auto operands = builder.Tuple({operand1, operand2});
|
||||
auto result = builder.Conditional(pred, operands, CreateAddR0Computation(),
|
||||
operands, CreateSubR0Computation());
|
||||
auto result =
|
||||
builder.Conditional(pred, operands, CreateR0TupleAddComputation(),
|
||||
operands, CreateR0TupleSubComputation());
|
||||
|
||||
ComputeAndCompareR0<float>(&builder, 44.0f, {}, error_spec_);
|
||||
}
|
||||
@ -256,8 +316,9 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) {
|
||||
auto operand1 = builder.ConstantR1<float>({24.0f, 56.0f});
|
||||
auto operand2 = builder.ConstantR1<float>({10.0f, 11.0f});
|
||||
auto operands = builder.Tuple({operand1, operand2});
|
||||
auto result = builder.Conditional(pred, operands, CreateAddR1Computation(),
|
||||
operands, CreateSubR1Computation());
|
||||
auto result =
|
||||
builder.Conditional(pred, operands, CreateR1TupleAddComputation(),
|
||||
operands, CreateR1TupleSubComputation());
|
||||
|
||||
ComputeAndCompareR1<float>(&builder, {34.0f, 67.0f}, {}, error_spec_);
|
||||
}
|
||||
@ -270,25 +331,192 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) {
|
||||
auto operand1 = builder.ConstantR1<float>({24.0f, 56.0f});
|
||||
auto operand2 = builder.ConstantR1<float>({10.0f, 11.0f});
|
||||
auto operands = builder.Tuple({operand1, operand2});
|
||||
auto result = builder.Conditional(pred, operands, CreateAddR1Computation(),
|
||||
operands, CreateSubR1Computation());
|
||||
auto result =
|
||||
builder.Conditional(pred, operands, CreateR1TupleAddComputation(),
|
||||
operands, CreateR1TupleSubComputation());
|
||||
|
||||
ComputeAndCompareR1<float>(&builder, {14.0f, 45.0f}, {}, error_spec_);
|
||||
}
|
||||
|
||||
// Test true and false computations that return a tuple of scalars.
|
||||
XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto pred = builder.ConstantR0<bool>(false);
|
||||
auto operands = builder.Tuple(
|
||||
{builder.ConstantR0<float>(12.2f), builder.ConstantR0<float>(25.6f)});
|
||||
builder.Conditional(pred, operands, CreateR0TupleCeilComputation(), operands,
|
||||
CreateR0TupleFloorComputation());
|
||||
|
||||
ComputeAndCompareTuple(
|
||||
&builder,
|
||||
*Literal::MakeTuple({Literal::CreateR0<float>(12.0f).get(),
|
||||
Literal::CreateR0<float>(25.0f).get()}),
|
||||
{}, error_spec_);
|
||||
}
|
||||
|
||||
// Test true and false computations that return a tuple of arrays.
|
||||
// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend.
|
||||
XLA_TEST_F(ConditionalOpTest, DISABLED_ON_GPU(ReturnTupleOfArrays)) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto pred = builder.ConstantR0<bool>(true);
|
||||
auto operands = builder.Tuple({builder.ConstantR1<float>({12.2f, 15.8f}),
|
||||
builder.ConstantR1<float>({25.6f, 29.2f})});
|
||||
builder.Conditional(pred, operands, CreateR1TupleCeilComputation(), operands,
|
||||
CreateR1TupleFloorComputation());
|
||||
|
||||
ComputeAndCompareTuple(
|
||||
&builder,
|
||||
*Literal::MakeTuple({Literal::CreateR1<float>({13.0f, 16.0f}).get(),
|
||||
Literal::CreateR1<float>({26.0f, 30.0f}).get()}),
|
||||
{}, error_spec_);
|
||||
}
|
||||
|
||||
// Test true and false computations that return a tuple of a predicate, a
|
||||
// scalar, and an array.
|
||||
// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend.
|
||||
XLA_TEST_F(ConditionalOpTest,
|
||||
DISABLED_ON_GPU(ReturnTupleofPredicateScalarArray)) {
|
||||
ComputationBuilder true_builder(client_, TestName() + ".true");
|
||||
{
|
||||
true_builder.Parameter(0, empty_tuple_, "tuple");
|
||||
auto true_pred = true_builder.ConstantR0<bool>(true);
|
||||
auto true_scalar = true_builder.ConstantR0<float>(12.2f);
|
||||
auto true_array = true_builder.ConstantR1<float>({12.8f, 14.6f});
|
||||
true_builder.Tuple({true_pred, true_scalar, true_array});
|
||||
}
|
||||
auto true_builder_result = true_builder.Build();
|
||||
EXPECT_IS_OK(true_builder_result.status());
|
||||
|
||||
ComputationBuilder false_builder(client_, TestName() + ".false");
|
||||
{
|
||||
false_builder.Parameter(0, empty_tuple_, "tuple");
|
||||
auto false_pred = false_builder.ConstantR0<bool>(false);
|
||||
auto false_scalar = false_builder.ConstantR0<float>(25.6f);
|
||||
auto false_array = false_builder.ConstantR1<float>({26.4f, 32.6f});
|
||||
false_builder.Tuple({false_pred, false_scalar, false_array});
|
||||
}
|
||||
auto false_builder_result = false_builder.Build();
|
||||
EXPECT_IS_OK(false_builder_result.status());
|
||||
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto pred = builder.ConstantR0<bool>(true);
|
||||
auto operands = builder.Tuple({});
|
||||
builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(),
|
||||
operands, false_builder_result.ConsumeValueOrDie());
|
||||
|
||||
ComputeAndCompareTuple(
|
||||
&builder,
|
||||
*Literal::MakeTuple({Literal::CreateR0<bool>(true).get(),
|
||||
Literal::CreateR0<float>(12.2f).get(),
|
||||
Literal::CreateR1<float>({12.8f, 14.6f}).get()}),
|
||||
{}, error_spec_);
|
||||
}
|
||||
|
||||
// Test true and false computations that return a nested tuple.
|
||||
// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend.
|
||||
XLA_TEST_F(ConditionalOpTest, DISABLED_ON_GPU(ReturnNestedTuple)) {
|
||||
ComputationBuilder true_builder(client_, TestName() + ".true");
|
||||
{
|
||||
true_builder.Parameter(0, empty_tuple_, "tuple");
|
||||
auto true_constant1 = true_builder.ConstantR0<float>(12.2f);
|
||||
auto true_constant2 = true_builder.ConstantR1<float>({12.8f, 14.6f});
|
||||
auto true_constant3 = true_builder.ConstantR1<float>({25.4f, 29.8f});
|
||||
auto true_constant4 = true_builder.ConstantR0<float>(35.6f);
|
||||
true_builder.Tuple({true_builder.Tuple({true_constant1, true_constant2}),
|
||||
true_builder.Tuple({true_constant3, true_constant4})});
|
||||
}
|
||||
auto true_builder_result = true_builder.Build();
|
||||
EXPECT_IS_OK(true_builder_result.status());
|
||||
|
||||
ComputationBuilder false_builder(client_, TestName() + ".false");
|
||||
{
|
||||
false_builder.Parameter(0, empty_tuple_, "tuple");
|
||||
auto false_constant1 = false_builder.ConstantR0<float>(46.6f);
|
||||
auto false_constant2 = false_builder.ConstantR1<float>({54.4f, 58.4f});
|
||||
auto false_constant3 = false_builder.ConstantR1<float>({62.1f, 67.4f});
|
||||
auto false_constant4 = false_builder.ConstantR0<float>(9.3f);
|
||||
false_builder.Tuple(
|
||||
{false_builder.Tuple({false_constant1, false_constant2}),
|
||||
false_builder.Tuple({false_constant3, false_constant4})});
|
||||
}
|
||||
auto false_builder_result = false_builder.Build();
|
||||
EXPECT_IS_OK(false_builder_result.status());
|
||||
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto pred = builder.ConstantR0<bool>(false);
|
||||
auto operands = builder.Tuple({});
|
||||
builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(),
|
||||
operands, false_builder_result.ConsumeValueOrDie());
|
||||
|
||||
ComputeAndCompareTuple(
|
||||
&builder,
|
||||
*Literal::MakeTuple(
|
||||
{Literal::MakeTuple({Literal::CreateR0<float>(46.6f).get(),
|
||||
Literal::CreateR1<float>({54.4f, 58.4f}).get()})
|
||||
.get(),
|
||||
Literal::MakeTuple({Literal::CreateR1<float>({62.1f, 67.4f}).get(),
|
||||
Literal::CreateR0<float>(9.3f).get()})
|
||||
.get()}),
|
||||
{}, error_spec_);
|
||||
}
|
||||
|
||||
// Test conditional that takes in scalar operands in the form of external
|
||||
// params.
|
||||
XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) {
|
||||
Shape r0bool = ShapeUtil::MakeShape(PRED, {});
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
|
||||
ComputationDataHandle pred, operand1, operand2;
|
||||
auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
|
||||
auto operand1_param =
|
||||
CreateR0Parameter<float>(56.3f, 1, "operand1", &builder, &operand1);
|
||||
auto operand2_param =
|
||||
CreateR0Parameter<float>(12.7f, 2, "operand2", &builder, &operand2);
|
||||
auto result = builder.Conditional(pred, operand1, CreateR0CeilComputation(),
|
||||
operand2, CreateR0FloorComputation());
|
||||
|
||||
ComputeAndCompareR0<float>(
|
||||
&builder, 57.0f,
|
||||
{pred_arg.get(), operand1_param.get(), operand2_param.get()},
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
// Test conditional that takes in array operands in the form of external params.
|
||||
XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) {
|
||||
Shape r0bool = ShapeUtil::MakeShape(PRED, {});
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
|
||||
ComputationDataHandle pred, operand1, operand2;
|
||||
auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
|
||||
auto operand1_param = CreateR1Parameter<float>({24.3f, 56.7f}, 1, "operand1",
|
||||
&builder, &operand1);
|
||||
auto operand2_param = CreateR1Parameter<float>({10.2f, 11.6f}, 2, "operand2",
|
||||
&builder, &operand2);
|
||||
auto result = builder.Conditional(pred, operand1, CreateR1CeilComputation(),
|
||||
operand2, CreateR1FloorComputation());
|
||||
|
||||
ComputeAndCompareR1<float>(
|
||||
&builder, {10.0f, 11.0f},
|
||||
{pred_arg.get(), operand1_param.get(), operand2_param.get()},
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
// Test the case where one conditional is nested within another.
|
||||
XLA_TEST_F(ConditionalOpTest, NestedConditionals) {
|
||||
Shape r0bool = ShapeUtil::MakeShape(PRED, {});
|
||||
Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});
|
||||
ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional");
|
||||
auto param0 = inner_builder.Parameter(0, tuple_shape, "param0");
|
||||
auto pred_cond = inner_builder.GetTupleElement(param0, 0);
|
||||
auto true_operand = inner_builder.GetTupleElement(param0, 1);
|
||||
auto false_operand = inner_builder.GetTupleElement(param0, 2);
|
||||
inner_builder.Conditional(pred_cond, true_operand,
|
||||
CreateR0F32CeilComputation(), false_operand,
|
||||
CreateR0F32FloorComputation());
|
||||
{
|
||||
Shape r0bool = ShapeUtil::MakeShape(PRED, {});
|
||||
Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});
|
||||
auto param0 = inner_builder.Parameter(0, tuple_shape, "param0");
|
||||
auto pred_cond = inner_builder.GetTupleElement(param0, 0);
|
||||
auto true_operand = inner_builder.GetTupleElement(param0, 1);
|
||||
auto false_operand = inner_builder.GetTupleElement(param0, 2);
|
||||
inner_builder.Conditional(pred_cond, true_operand,
|
||||
CreateR0CeilComputation(), false_operand,
|
||||
CreateR0FloorComputation());
|
||||
}
|
||||
auto inner_builder_result = inner_builder.Build();
|
||||
EXPECT_IS_OK(inner_builder_result.status());
|
||||
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto pred1 = builder.ConstantR0<bool>(true);
|
||||
@ -299,7 +527,7 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) {
|
||||
auto tuple_operand = builder.Tuple({pred2, operand1, operand2});
|
||||
builder.Conditional(pred1, tuple_operand,
|
||||
inner_builder_result.ConsumeValueOrDie(), operand3,
|
||||
CreateR0F32IdentityComputation());
|
||||
CreateR0IdentityComputation());
|
||||
|
||||
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
|
||||
}
|
||||
@ -311,8 +539,8 @@ XLA_TEST_F(ConditionalOpTest, ShapeMismatch) {
|
||||
auto operand1 = builder.ConstantR0<float>(56.0f);
|
||||
auto operand2 = builder.ConstantR0<float>(12.0f);
|
||||
auto operands = builder.Tuple({operand1, operand2});
|
||||
builder.Conditional(pred, operands, CreateAddR1Computation(), operands,
|
||||
CreateSubR0Computation());
|
||||
builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
|
||||
CreateR0TupleSubComputation());
|
||||
|
||||
auto result = builder.Build();
|
||||
EXPECT_FALSE(result.ok());
|
||||
|
@@ -258,7 +258,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
      return tfe.Iterator(ds)

    self._benchmark_eager_train(
        'eager_train_dataset', make_iterator, defun=True)
        'eager_train_dataset_with_defun', make_iterator, defun=True)


if __name__ == '__main__':
@@ -56,7 +56,7 @@ LIBS := \
-lz

# If we're on Linux, also link in the dl library.
ifeq ($(OS),LINUX)
ifeq ($(HOST_OS),LINUX)
	LIBS += -ldl -lpthread
endif

@@ -191,6 +191,13 @@ typedef struct {
  int axis;
} TfLiteGatherParams;

typedef struct {
  // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
  // For now we will fix the maximum possible number of dimensions.
  int perm[8];
  int num_dimensions;
} TfLiteTransposeParams;

#ifdef __cplusplus
}  // extern "C"
#endif  // __cplusplus
@@ -101,6 +101,7 @@ cc_library(
        "space_to_batch_nd.cc",
        "space_to_depth.cc",
        "svdf.cc",
        "transpose.cc",
        "unidirectional_sequence_rnn.cc",
    ],
    hdrs = [
@@ -2485,8 +2485,8 @@ void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
}

template <typename T>
void Transpose(const T* input, Dims<4>& input_dims, T* output,
               Dims<4>& output_dims, int* permuted_axes) {
void Transpose(const T* input, const Dims<4>& input_dims, T* output,
               const Dims<4>& output_dims, int* permuted_axes) {
  int out_sizes[4];
  // Compute the inverse permutation array so we can do an output centered
  // transpose. Also, check to make sure output_dims is matching input_dims.
@@ -52,6 +52,7 @@ TfLiteRegistration* Register_RESIZE_BILINEAR();
TfLiteRegistration* Register_SKIP_GRAM();
TfLiteRegistration* Register_SPACE_TO_DEPTH();
TfLiteRegistration* Register_GATHER();
TfLiteRegistration* Register_TRANSPOSE();

BuiltinOpResolver::BuiltinOpResolver() {
  AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -90,6 +91,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
  AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM());
  AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH());
  AddBuiltin(BuiltinOperator_GATHER, Register_GATHER());
  AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE());
}

TfLiteRegistration* BuiltinOpResolver::FindOp(
tensorflow/contrib/lite/kernels/transpose.cc (new file, 142 lines)
@@ -0,0 +1,142 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace transpose {

// This file has a reference implementation of Transpose.
enum KernelType {
  kReference,
};

// TODO(nupurgarg): Permutation arrays represented as a tensor are ignored. Only
// use the `perm` specified in `params`.
struct TransposeContext {
  TransposeContext(TfLiteContext* context, TfLiteNode* node) {
    params = reinterpret_cast<TfLiteTransposeParams*>(node->builtin_data);
    input = GetInput(context, node, 0);
    output = GetOutput(context, node, 0);
  }
  TfLiteTransposeParams* params;
  TfLiteTensor* input;
  TfLiteTensor* output;
};

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  TransposeContext op_context(context, node);
  int dims = NumDimensions(op_context.input);

  // Ensure validity of input tensor and permutation array.
  TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
  TF_LITE_ENSURE_EQ(context, dims, op_context.params->num_dimensions);
  TF_LITE_ENSURE_MSG(context, dims <= 4,
                     "Transpose op only supports 1D-4D input arrays.");
  for (int idx = 0; idx < dims; ++idx) {
    TF_LITE_ENSURE_MSG(context,
                       op_context.params->perm[idx] >= 0 &&
                           op_context.params->perm[idx] < dims,
                       "Transpose op permutations array is out of bounds.");
  }

  // Determine size of output tensor.
  const TfLiteIntArray* input_size = op_context.input->dims;
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(dims);
  for (int idx = 0; idx < dims; ++idx) {
    output_size->data[idx] = input_size->data[op_context.params->perm[idx]];
  }

  return context->ResizeTensor(context, op_context.output, output_size);
}

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TransposeContext op_context(context, node);

  // Reverse the permuted axes and convert to 4D due to the way Dims are
  // constructed in GetTensorDims.
  const int kOutputDimensionNum = 4;
  int reversed_perm[kOutputDimensionNum];
  int size = op_context.params->num_dimensions;
  for (int output_k = 0, input_k = size - 1; output_k < size;
       ++output_k, --input_k) {
    reversed_perm[output_k] = size - op_context.params->perm[input_k] - 1;
  }
  for (int k = size; k < kOutputDimensionNum; ++k) {
    reversed_perm[k] = k;
  }

#define TF_LITE_TRANSPOSE(type, scalar)                     \
  type::Transpose(GetTensorData<scalar>(op_context.input),  \
                  GetTensorDims(op_context.input),          \
                  GetTensorData<scalar>(op_context.output), \
                  GetTensorDims(op_context.output), reversed_perm)

  switch (op_context.input->type) {
    case kTfLiteFloat32:
      if (kernel_type == kReference) {
        TF_LITE_TRANSPOSE(reference_ops, float);
      }
      break;
    case kTfLiteUInt8:
      if (kernel_type == kReference) {
        TF_LITE_TRANSPOSE(reference_ops, uint8_t);
      }
      break;
    case kTfLiteInt32:
      if (kernel_type == kReference) {
        TF_LITE_TRANSPOSE(reference_ops, int32_t);
      }
      break;
    case kTfLiteInt64:
      if (kernel_type == kReference) {
        TF_LITE_TRANSPOSE(reference_ops, int64_t);
      }
      break;
    default:
      context->ReportError(context,
                           "Type is currently not supported by Transpose.");
      return kTfLiteError;
  }
#undef TF_LITE_TRANSPOSE

  return kTfLiteOk;
}

}  // namespace transpose

TfLiteRegistration* Register_TRANSPOSE_REF() {
  static TfLiteRegistration r = {nullptr, nullptr, transpose::Prepare,
                                 transpose::Eval<transpose::kReference>};
  return &r;
}

TfLiteRegistration* Register_TRANSPOSE() { return Register_TRANSPOSE_REF(); }

}  // namespace builtin
}  // namespace ops
}  // namespace tflite
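To make the reversed_perm computation in transpose::Eval above easier to follow, here is a small self-contained sketch (illustration only, not part of the kernel) that applies the same index arithmetic to an example 3-D permutation and pads it out to the 4-D Dims layout used by GetTensorDims:

#include <cstdio>

int main() {
  const int kOutputDimensionNum = 4;
  int perm[] = {2, 0, 1};  // example permutation for a 3-D tensor
  int size = 3;
  int reversed_perm[kOutputDimensionNum];
  // Axis i of an N-D tensor lands at slot (N - i - 1) when counted from the
  // innermost Dims dimension, hence the size - perm[...] - 1 reversal.
  for (int output_k = 0, input_k = size - 1; output_k < size;
       ++output_k, --input_k) {
    reversed_perm[output_k] = size - perm[input_k] - 1;
  }
  for (int k = size; k < kOutputDimensionNum; ++k) {
    reversed_perm[k] = k;  // the unused leading dimension stays in place
  }
  // For perm = {2, 0, 1} this prints: 1 2 0 3
  for (int k = 0; k < kOutputDimensionNum; ++k) std::printf("%d ", reversed_perm[k]);
  std::printf("\n");
  return 0;
}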
@ -16,11 +16,15 @@ limitations under the License.
|
||||
#include "tensorflow/contrib/lite/interpreter.h"
|
||||
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
|
||||
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/contrib/lite/kernels/register.h"
|
||||
#include "tensorflow/contrib/lite/kernels/test_util.h"
|
||||
#include "tensorflow/contrib/lite/model.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
void RunTestPermutation(const std::vector<int>& shape,
|
||||
const std::vector<int>& perms,
|
||||
std::vector<float>* input_transposed) {
|
||||
@ -64,14 +68,14 @@ void RunTestPermutation(const std::vector<int>& shape,
|
||||
reversed_perms);
|
||||
}
|
||||
|
||||
TEST(TransposeTest, Test1D) {
|
||||
TEST(TransposeTest, TestRefOps1D) {
|
||||
// Basic 1D identity.
|
||||
std::vector<float> out;
|
||||
RunTestPermutation({3}, {0}, &out);
|
||||
ASSERT_EQ(out, std::vector<float>({0, 1, 2}));
|
||||
}
|
||||
|
||||
TEST(TransposeTest, Test2D) {
|
||||
TEST(TransposeTest, TestRefOps2D) {
|
||||
std::vector<float> out;
|
||||
// Basic 2D.
|
||||
RunTestPermutation({3, 2}, {1, 0}, &out);
|
||||
@ -81,7 +85,7 @@ TEST(TransposeTest, Test2D) {
|
||||
ASSERT_EQ(out, std::vector<float>({0, 1, 2, 3, 4, 5}));
|
||||
}
|
||||
|
||||
TEST(TransposeTest, Test3D) {
|
||||
TEST(TransposeTest, TestRefOps3D) {
|
||||
std::vector<float> out;
|
||||
// Test 3 dimensional
|
||||
{
|
||||
@ -99,7 +103,7 @@ TEST(TransposeTest, Test3D) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TransposeTest, Test4D) {
|
||||
TEST(TransposeTest, TestRefOps4D) {
|
||||
std::vector<float> out;
|
||||
// Basic 4d.
|
||||
RunTestPermutation({2, 3, 4, 5}, {2, 0, 1, 3}, &out);
|
||||
@ -121,6 +125,118 @@ TEST(TransposeTest, Test4D) {
|
||||
ASSERT_EQ(out, ref);
|
||||
}
|
||||
|
||||
class TransposeOpModel : public SingleOpModel {
|
||||
public:
|
||||
TransposeOpModel(std::initializer_list<int> input_shape,
|
||||
std::initializer_list<int> perm) {
|
||||
input_ = AddInput(TensorType_FLOAT32);
|
||||
output_ = AddOutput(TensorType_FLOAT32);
|
||||
SetBuiltinOp(
|
||||
BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions,
|
||||
CreateTransposeOptions(builder_, builder_.CreateVector<int>(perm))
|
||||
.Union());
|
||||
BuildInterpreter({input_shape});
|
||||
}
|
||||
|
||||
void SetInput(std::initializer_list<float> data) {
|
||||
PopulateTensor<float>(input_, data);
|
||||
}
|
||||
|
||||
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
|
||||
private:
|
||||
int input_;
|
||||
int output_;
|
||||
};
|
||||
|
||||
TEST(TransposeTest, TestUnequalPermSize) {
|
||||
EXPECT_DEATH(TransposeOpModel({1, 3, 3, 1}, {2, 2}),
|
||||
"dims != op_context.params->num_dimensions");
|
||||
}
|
||||
|
||||
TEST(TransposeTest, TestPermOutOfBounds) {
|
||||
EXPECT_DEATH(TransposeOpModel({1, 3, 3, 1}, {0, -1, -2, -3}),
|
||||
"Transpose op permutations array is out of bounds.");
|
||||
EXPECT_DEATH(TransposeOpModel({1, 3, 3, 1}, {0, 1, 2, 4}),
|
||||
"Transpose op permutations array is out of bounds.");
|
||||
}
|
||||
|
||||
TEST(TransposeTest, Test1DInputTensor) {
|
||||
TransposeOpModel m({3}, {0});
|
||||
m.SetInput({1, 2, 3});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3}));
|
||||
}
|
||||
|
||||
TEST(TransposeTest, Test2DInputTensor) {
|
||||
TransposeOpModel m({3, 2}, {1, 0});
|
||||
m.SetInput({0, 1, 2, 3, 4, 5});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 4, 1, 3, 5}));
|
||||
}
|
||||
|
||||
TEST(TransposeTest, Test3DInputTensor) {
|
||||
TransposeOpModel m({2, 3, 4}, {2, 0, 1});
|
||||
m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3}));
|
||||
EXPECT_THAT(m.GetOutput(),
|
||||
ElementsAreArray({0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21,
|
||||
2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}));
|
||||
}
|
||||
|
||||
TEST(TransposeTest, Test5DInputTensor) {
|
||||
EXPECT_DEATH(TransposeOpModel({1, 2, 3, 4, 5}, {0, 1, 2, 3, 4}),
|
||||
"Transpose op only supports 1D-4D input arrays.");
|
||||
}
|
||||
|
||||
TEST(TransposeTest, SimpleTestNoReorder) {
|
||||
TransposeOpModel m({1, 2, 3, 1}, {0, 1, 2, 3});
|
||||
m.SetInput({1, 2, 3, 4, 5, 6});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 3, 1}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||
}
|
||||
|
||||
TEST(TransposeTest, SimpleTestWithReorder) {
|
||||
TransposeOpModel m({1, 2, 3, 1}, {2, 1, 3, 0});
|
||||
m.SetInput({1, 2, 3, 4, 5, 6});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2, 1, 1}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6}));
|
||||
}
|
||||
|
||||
TEST(TransposeTest, ComplexTestWithReorder) {
|
||||
TransposeOpModel m({2, 3, 4, 5}, {2, 0, 1, 3});
|
||||
m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
|
||||
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
|
||||
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
|
||||
60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
|
||||
72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
|
||||
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
|
||||
96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
|
||||
108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3, 5}));
|
||||
auto result = ElementsAreArray(
|
||||
{0, 1, 2, 3, 4, 20, 21, 22, 23, 24, 40, 41, 42, 43, 44,
|
||||
60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104,
|
||||
5, 6, 7, 8, 9, 25, 26, 27, 28, 29, 45, 46, 47, 48, 49,
|
||||
65, 66, 67, 68, 69, 85, 86, 87, 88, 89, 105, 106, 107, 108, 109,
|
||||
10, 11, 12, 13, 14, 30, 31, 32, 33, 34, 50, 51, 52, 53, 54,
|
||||
70, 71, 72, 73, 74, 90, 91, 92, 93, 94, 110, 111, 112, 113, 114,
|
||||
15, 16, 17, 18, 19, 35, 36, 37, 38, 39, 55, 56, 57, 58, 59,
|
||||
75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119});
|
||||
EXPECT_THAT(m.GetOutput(), result);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
|
@@ -555,6 +555,17 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_TRANSPOSE: {
      auto* params = MallocPOD<TfLiteTransposeParams>();
      if (auto* schema_params = op->builtin_options_as_TransposeOptions()) {
        const auto& perm = schema_params->perm();
        FlatBufferIntVectorToArray(sizeof(params->perm), perm, params->perm,
                                   error_reporter);
        params->num_dimensions = perm->Length();
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
  }
  return builtin_data;
}
@@ -309,6 +309,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
    case tflite::BuiltinOperator_GATHER:
    case tflite::BuiltinOperator_SPACE_TO_BATCH_ND:
    case tflite::BuiltinOperator_BATCH_TO_SPACE_ND:
    case tflite::BuiltinOperator_TRANSPOSE:
      FATAL("Op code %d is currently not delegated to NNAPI", builtin);
      nn_op_type = -1;  // set to invalid
      break;
@@ -109,6 +109,7 @@ enum BuiltinOperator : byte {
  GATHER = 36,
  BATCH_TO_SPACE_ND = 37,
  SPACE_TO_BATCH_ND = 38,
  TRANSPOSE = 39,
}

// Options for the builtin operators.
@@ -138,6 +139,7 @@ union BuiltinOptions {
  GatherOptions,
  BatchToSpaceNDOptions,
  SpaceToBatchNDOptions,
  TransposeOptions,
}

enum Padding : byte { SAME, VALID }
@@ -298,6 +300,10 @@ table GatherOptions {
  axis: int;
}

table TransposeOptions {
  perm:[int];
}

// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
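For context, a converter or test that wants to attach these options to an operator would go through the generated FlatBuffers helpers that appear later in this change (CreateTransposeOptionsDirect and friends). A minimal sketch, with the header path assumed and perm = {0, 2, 1} chosen purely as an example:

#include <cstdint>
#include <vector>
#include "flatbuffers/flatbuffers.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"  // assumed path

// Serializes a TransposeOptions table whose perm swaps the last two axes of a
// 3-D tensor; the result would be stored as an Operator's builtin_options with
// type BuiltinOptions_TransposeOptions.
flatbuffers::Offset<tflite::TransposeOptions> MakeExampleTransposeOptions(
    flatbuffers::FlatBufferBuilder& fbb) {
  std::vector<int32_t> perm = {0, 2, 1};
  return tflite::CreateTransposeOptionsDirect(fbb, &perm);
}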
@ -1,4 +1,4 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
#ifndef FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_
|
||||
@ -103,6 +102,9 @@ struct EmbeddingLookupSparseOptionsT;
|
||||
struct GatherOptions;
|
||||
struct GatherOptionsT;
|
||||
|
||||
struct TransposeOptions;
|
||||
struct TransposeOptionsT;
|
||||
|
||||
struct OperatorCode;
|
||||
struct OperatorCodeT;
|
||||
|
||||
@ -184,11 +186,12 @@ enum BuiltinOperator {
|
||||
BuiltinOperator_GATHER = 36,
|
||||
BuiltinOperator_BATCH_TO_SPACE_ND = 37,
|
||||
BuiltinOperator_SPACE_TO_BATCH_ND = 38,
|
||||
BuiltinOperator_TRANSPOSE = 39,
|
||||
BuiltinOperator_MIN = BuiltinOperator_ADD,
|
||||
BuiltinOperator_MAX = BuiltinOperator_SPACE_TO_BATCH_ND
|
||||
BuiltinOperator_MAX = BuiltinOperator_TRANSPOSE
|
||||
};
|
||||
|
||||
inline BuiltinOperator (&EnumValuesBuiltinOperator())[36] {
|
||||
inline BuiltinOperator (&EnumValuesBuiltinOperator())[37] {
|
||||
static BuiltinOperator values[] = {
|
||||
BuiltinOperator_ADD,
|
||||
BuiltinOperator_AVERAGE_POOL_2D,
|
||||
@ -225,7 +228,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[36] {
|
||||
BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
|
||||
BuiltinOperator_GATHER,
|
||||
BuiltinOperator_BATCH_TO_SPACE_ND,
|
||||
BuiltinOperator_SPACE_TO_BATCH_ND};
|
||||
BuiltinOperator_SPACE_TO_BATCH_ND,
|
||||
BuiltinOperator_TRANSPOSE};
|
||||
return values;
|
||||
}
|
||||
|
||||
@ -269,6 +273,7 @@ inline const char **EnumNamesBuiltinOperator() {
|
||||
"GATHER",
|
||||
"BATCH_TO_SPACE_ND",
|
||||
"SPACE_TO_BATCH_ND",
|
||||
"TRANSPOSE",
|
||||
nullptr};
|
||||
return names;
|
||||
}
|
||||
@ -305,11 +310,12 @@ enum BuiltinOptions {
|
||||
BuiltinOptions_GatherOptions = 23,
|
||||
BuiltinOptions_BatchToSpaceNDOptions = 24,
|
||||
BuiltinOptions_SpaceToBatchNDOptions = 25,
|
||||
BuiltinOptions_TransposeOptions = 26,
|
||||
BuiltinOptions_MIN = BuiltinOptions_NONE,
|
||||
BuiltinOptions_MAX = BuiltinOptions_SpaceToBatchNDOptions
|
||||
BuiltinOptions_MAX = BuiltinOptions_TransposeOptions
|
||||
};
|
||||
|
||||
inline BuiltinOptions (&EnumValuesBuiltinOptions())[26] {
|
||||
inline BuiltinOptions (&EnumValuesBuiltinOptions())[27] {
|
||||
static BuiltinOptions values[] = {
|
||||
BuiltinOptions_NONE,
|
||||
BuiltinOptions_Conv2DOptions,
|
||||
@ -336,7 +342,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[26] {
|
||||
BuiltinOptions_PadOptions,
|
||||
BuiltinOptions_GatherOptions,
|
||||
BuiltinOptions_BatchToSpaceNDOptions,
|
||||
BuiltinOptions_SpaceToBatchNDOptions};
|
||||
BuiltinOptions_SpaceToBatchNDOptions,
|
||||
BuiltinOptions_TransposeOptions};
|
||||
return values;
|
||||
}
|
||||
|
||||
@ -367,6 +374,7 @@ inline const char **EnumNamesBuiltinOptions() {
|
||||
"GatherOptions",
|
||||
"BatchToSpaceNDOptions",
|
||||
"SpaceToBatchNDOptions",
|
||||
"TransposeOptions",
|
||||
nullptr};
|
||||
return names;
|
||||
}
|
||||
@ -510,6 +518,11 @@ struct BuiltinOptionsTraits<SpaceToBatchNDOptions> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_SpaceToBatchNDOptions;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct BuiltinOptionsTraits<TransposeOptions> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_TransposeOptions;
|
||||
};
|
||||
|
||||
struct BuiltinOptionsUnion {
|
||||
BuiltinOptions type;
|
||||
void *value;
|
||||
@ -807,6 +820,16 @@ struct BuiltinOptionsUnion {
|
||||
? reinterpret_cast<const SpaceToBatchNDOptionsT *>(value)
|
||||
: nullptr;
|
||||
}
|
||||
TransposeOptionsT *AsTransposeOptions() {
|
||||
return type == BuiltinOptions_TransposeOptions
|
||||
? reinterpret_cast<TransposeOptionsT *>(value)
|
||||
: nullptr;
|
||||
}
|
||||
const TransposeOptionsT *AsTransposeOptions() const {
|
||||
return type == BuiltinOptions_TransposeOptions
|
||||
? reinterpret_cast<const TransposeOptionsT *>(value)
|
||||
: nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj,
|
||||
@ -2996,6 +3019,69 @@ flatbuffers::Offset<GatherOptions> CreateGatherOptions(
|
||||
flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT *_o,
|
||||
const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct TransposeOptionsT : public flatbuffers::NativeTable {
|
||||
typedef TransposeOptions TableType;
|
||||
std::vector<int32_t> perm;
|
||||
TransposeOptionsT() {}
|
||||
};
|
||||
|
||||
struct TransposeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
typedef TransposeOptionsT NativeTableType;
|
||||
enum { VT_PERM = 4 };
|
||||
const flatbuffers::Vector<int32_t> *perm() const {
|
||||
return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_PERM);
|
||||
}
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_PERM) &&
|
||||
verifier.Verify(perm()) && verifier.EndTable();
|
||||
}
|
||||
TransposeOptionsT *UnPack(
|
||||
const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
void UnPackTo(
|
||||
TransposeOptionsT *_o,
|
||||
const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
static flatbuffers::Offset<TransposeOptions> Pack(
|
||||
flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o,
|
||||
const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
};
|
||||
|
||||
struct TransposeOptionsBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
void add_perm(flatbuffers::Offset<flatbuffers::Vector<int32_t>> perm) {
|
||||
fbb_.AddOffset(TransposeOptions::VT_PERM, perm);
|
||||
}
|
||||
explicit TransposeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
}
|
||||
TransposeOptionsBuilder &operator=(const TransposeOptionsBuilder &);
|
||||
flatbuffers::Offset<TransposeOptions> Finish() {
|
||||
const auto end = fbb_.EndTable(start_);
|
||||
auto o = flatbuffers::Offset<TransposeOptions>(end);
|
||||
return o;
|
||||
}
|
||||
};
|
||||
|
||||
inline flatbuffers::Offset<TransposeOptions> CreateTransposeOptions(
|
||||
flatbuffers::FlatBufferBuilder &_fbb,
|
||||
flatbuffers::Offset<flatbuffers::Vector<int32_t>> perm = 0) {
|
||||
TransposeOptionsBuilder builder_(_fbb);
|
||||
builder_.add_perm(perm);
|
||||
return builder_.Finish();
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<TransposeOptions> CreateTransposeOptionsDirect(
|
||||
flatbuffers::FlatBufferBuilder &_fbb,
|
||||
const std::vector<int32_t> *perm = nullptr) {
|
||||
return tflite::CreateTransposeOptions(
|
||||
_fbb, perm ? _fbb.CreateVector<int32_t>(*perm) : 0);
|
||||
}
|
||||
|
||||
flatbuffers::Offset<TransposeOptions> CreateTransposeOptions(
|
||||
flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o,
|
||||
const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct OperatorCodeT : public flatbuffers::NativeTable {
|
||||
typedef OperatorCode TableType;
|
||||
BuiltinOperator builtin_code;
|
||||
@ -3250,6 +3336,11 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
? static_cast<const SpaceToBatchNDOptions *>(builtin_options())
|
||||
: nullptr;
|
||||
}
|
||||
const TransposeOptions *builtin_options_as_TransposeOptions() const {
|
||||
return builtin_options_type() == BuiltinOptions_TransposeOptions
|
||||
? static_cast<const TransposeOptions *>(builtin_options())
|
||||
: nullptr;
|
||||
}
|
||||
const flatbuffers::Vector<uint8_t> *custom_options() const {
|
||||
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
|
||||
}
|
||||
@ -3424,6 +3515,12 @@ Operator::builtin_options_as<SpaceToBatchNDOptions>() const {
|
||||
return builtin_options_as_SpaceToBatchNDOptions();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const TransposeOptions *Operator::builtin_options_as<TransposeOptions>()
|
||||
const {
|
||||
return builtin_options_as_TransposeOptions();
|
||||
}
|
||||
|
||||
struct OperatorBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
@ -5183,6 +5280,50 @@ inline flatbuffers::Offset<GatherOptions> CreateGatherOptions(
|
||||
return tflite::CreateGatherOptions(_fbb, _axis);
|
||||
}
|
||||
|
||||
inline TransposeOptionsT *TransposeOptions::UnPack(
|
||||
const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new TransposeOptionsT();
|
||||
UnPackTo(_o, _resolver);
|
||||
return _o;
|
||||
}
|
||||
|
||||
inline void TransposeOptions::UnPackTo(
|
||||
TransposeOptionsT *_o,
|
||||
const flatbuffers::resolver_function_t *_resolver) const {
|
||||
(void)_o;
|
||||
(void)_resolver;
|
||||
{
|
||||
auto _e = perm();
|
||||
if (_e) {
|
||||
_o->perm.resize(_e->size());
|
||||
for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) {
|
||||
_o->perm[_i] = _e->Get(_i);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<TransposeOptions> TransposeOptions::Pack(
|
||||
flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o,
|
||||
const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
return CreateTransposeOptions(_fbb, _o, _rehasher);
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<TransposeOptions> CreateTransposeOptions(
|
||||
flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o,
|
||||
const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
(void)_rehasher;
|
||||
(void)_o;
|
||||
struct _VectorArgs {
|
||||
flatbuffers::FlatBufferBuilder *__fbb;
|
||||
const TransposeOptionsT *__o;
|
||||
const flatbuffers::rehasher_function_t *__rehasher;
|
||||
} _va = {&_fbb, _o, _rehasher};
|
||||
(void)_va;
|
||||
auto _perm = _o->perm.size() ? _fbb.CreateVector(_o->perm) : 0;
|
||||
return tflite::CreateTransposeOptions(_fbb, _perm);
|
||||
}
|
||||
|
||||
inline OperatorCodeT *OperatorCode::UnPack(
|
||||
const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new OperatorCodeT();
|
||||
@ -5671,6 +5812,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier,
|
||||
auto ptr = reinterpret_cast<const SpaceToBatchNDOptions *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
case BuiltinOptions_TransposeOptions: {
|
||||
auto ptr = reinterpret_cast<const TransposeOptions *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@ -5795,6 +5940,10 @@ inline void *BuiltinOptionsUnion::UnPack(
|
||||
auto ptr = reinterpret_cast<const SpaceToBatchNDOptions *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
case BuiltinOptions_TransposeOptions: {
|
||||
auto ptr = reinterpret_cast<const TransposeOptions *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
@ -5906,6 +6055,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(
|
||||
auto ptr = reinterpret_cast<const SpaceToBatchNDOptionsT *>(value);
|
||||
return CreateSpaceToBatchNDOptions(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
case BuiltinOptions_TransposeOptions: {
|
||||
auto ptr = reinterpret_cast<const TransposeOptionsT *>(value);
|
||||
return CreateTransposeOptions(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
@ -6029,6 +6182,11 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u)
|
||||
*reinterpret_cast<SpaceToBatchNDOptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_TransposeOptions: {
|
||||
value = new TransposeOptionsT(
|
||||
*reinterpret_cast<TransposeOptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -6161,6 +6319,11 @@ inline void BuiltinOptionsUnion::Reset() {
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_TransposeOptions: {
|
||||
auto ptr = reinterpret_cast<TransposeOptionsT *>(value);
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -43,6 +43,7 @@ gen_zipped_test_files(
        "softmax.zip",
        "space_to_batch_nd.zip",
        "space_to_depth.zip",
        "transpose.zip",
    ],
)
|
||||
|
||||
|
@ -1283,6 +1283,41 @@ def make_batch_to_space_nd_tests(zip_path):
  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_transpose_tests(zip_path):
  """Make a set of tests to do transpose."""

  # TODO(nupurgarg): Add test for uint8.
  test_parameters = [{
      "dtype": [tf.int32, tf.int64, tf.float32],
      "input_shape": [[2, 2, 3]],
      "perm": [[0, 1, 2], [0, 2, 1]],
  }, {
      "dtype": [tf.float32],
      "input_shape": [[1, 2, 3, 4]],
      "perm": [[0, 1, 2, 3], [3, 0, 1, 2]],
  }, {
      "dtype": [tf.float32],
      "input_shape": [[1, 2, 3, 4, 5]],
      "perm": [[0, 1, 2, 3, 4]],
  }]

  def build_graph(parameters):
    input_tensor = tf.placeholder(
        dtype=parameters["dtype"],
        name="input",
        shape=parameters["input_shape"])
    out = tf.transpose(input_tensor, perm=parameters["perm"])
    return [input_tensor], [out]

  def build_inputs(parameters, sess, inputs, outputs):
    input_values = create_tensor_data(parameters["dtype"],
                                      parameters["input_shape"])
    return [input_values], sess.run(
        outputs, feed_dict=dict(zip(inputs, [input_values])))

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)

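For reference, a minimal standalone sketch of the kind of graph one generated test case exercises (hand-written here, not part of the change; assumes a TF 1.x environment with NumPy, and uses the float32 / shape [2, 2, 3] / perm=[0, 2, 1] combination from the parameters above):

import numpy as np
import tensorflow as tf

# Hypothetical repro of a single generated transpose test case.
input_tensor = tf.placeholder(dtype=tf.float32, name="input", shape=[2, 2, 3])
out = tf.transpose(input_tensor, perm=[0, 2, 1])

with tf.Session() as sess:
  input_values = np.random.rand(2, 2, 3).astype(np.float32)
  result = sess.run(out, feed_dict={input_tensor: input_values})
  print(result.shape)  # (2, 3, 2): axes 1 and 2 are swapped.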
def make_l2_pool(input_tensor, ksize, strides, padding, data_format):
|
||||
"""Given an input perform a sequence of TensorFlow ops to produce l2pool."""
|
||||
return tf.sqrt(tf.nn.avg_pool(
|
||||
@ -1336,6 +1371,7 @@ def main(unused_args):
      "sigmoid.zip": make_sigmoid_tests,
      "softmax.zip": make_softmax_tests,
      "space_to_depth.zip": make_space_to_depth_tests,
      "transpose.zip": make_transpose_tests,
  }
  out = FLAGS.zip_to_output
  bin_path = FLAGS.toco
|
||||
|
@ -85,6 +85,9 @@ std::map<string, string> kBrokenTests = {
|
||||
|
||||
// ResizeBilinear looks completely incompatible with Tensorflow
|
||||
{R"(resize_bilinear)", "67964336"},
|
||||
|
||||
// Transpose only supports 1D-4D input tensors.
|
||||
{R"(transposedtype=.*,input_shape=\[.,.,.,.,.\],perm=.*)", "71545879"},
|
||||
};
|
||||
|
||||
// Allows test data to be unzipped into a temporary directory and makes
|
||||
@ -270,6 +273,7 @@ INSTANTIATE_TESTS(resize_bilinear)
|
||||
INSTANTIATE_TESTS(sigmoid)
|
||||
INSTANTIATE_TESTS(softmax)
|
||||
INSTANTIATE_TESTS(space_to_depth)
|
||||
INSTANTIATE_TESTS(transpose)
|
||||
|
||||
} // namespace testing
|
||||
} // namespace tflite
|
||||
|
@ -222,6 +222,7 @@ cc_library(
|
||||
"graph_transformations/resolve_tensorflow_squeeze.cc",
|
||||
"graph_transformations/resolve_tensorflow_switch.cc",
|
||||
"graph_transformations/resolve_tensorflow_tile.cc",
|
||||
"graph_transformations/resolve_transpose_attributes.cc",
|
||||
"graph_transformations/unfuse_activation_functions.cc",
|
||||
],
|
||||
hdrs = [
|
||||
|
@ -158,6 +158,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes)
|
||||
DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes)
|
||||
DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes)
|
||||
DECLARE_GRAPH_TRANSFORMATION(ResolveMeanAttributes)
|
||||
DECLARE_GRAPH_TRANSFORMATION(ResolveTransposeAttributes)
|
||||
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTensorFlowShape)
|
||||
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill)
|
||||
DECLARE_GRAPH_TRANSFORMATION(Dequantize)
|
||||
|
@ -631,6 +631,9 @@ void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
|
||||
const auto& output_size_shape = output_size_array.shape();
|
||||
CHECK_EQ(output_size_shape.dimensions_count(), 1);
|
||||
CHECK_EQ(output_size_shape.dims(0), 2);
|
||||
if (!output_size_array.buffer) {
|
||||
return;
|
||||
}
|
||||
std::vector<int32> output_shape =
|
||||
output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
|
||||
model->arrays[op->outputs[0]]->copy_shape(
|
||||
|
@ -0,0 +1,53 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
|
||||
#include "tensorflow/contrib/lite/toco/model.h"
|
||||
#include "tensorflow/contrib/lite/toco/tooling_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace toco {
|
||||
|
||||
bool ResolveTransposeAttributes::Run(Model* model, std::size_t op_index) {
|
||||
const auto op_it = model->operators.begin() + op_index;
|
||||
if (op_it->get()->type != OperatorType::kTranspose) return false;
|
||||
|
||||
auto* op = static_cast<TransposeOperator*>(op_it->get());
|
||||
if (!op->perm.empty()) return false;
|
||||
|
||||
CHECK_EQ(op->inputs.size(), 2);
|
||||
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
|
||||
|
||||
// Handling perm.
|
||||
const auto& perm_array = *model->arrays[op->inputs[1]];
|
||||
if (!perm_array.has_shape()) return false;
|
||||
|
||||
const std::vector<int>& perm_dims = perm_array.shape().dims();
|
||||
CHECK_EQ(perm_dims.size(), 1);
|
||||
|
||||
std::vector<int> perm_buffer =
|
||||
perm_array.GetBuffer<ArrayDataType::kInt32>().data;
|
||||
for (int i = 0; i < perm_dims[0]; ++i) {
|
||||
op->perm.push_back(perm_buffer[i]);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace toco
|
@ -957,6 +957,7 @@ struct TensorFlowSquareOperator : Operator {
// TensorFlow equivalent: Transpose
struct TransposeOperator : Operator {
  TransposeOperator() : Operator(OperatorType::kTranspose) {}
  std::vector<int> perm;
};

// Element-wise subtraction operator.
|
||||
|
@ -506,6 +506,25 @@ class SpaceToDepth
  }
};

class Transpose
    : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions,
                             ::tflite::BuiltinOptions_TransposeOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    return ::tflite::CreateTransposeOptions(*builder,
                                            builder->CreateVector(op.perm));
  }

  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->perm.insert(op->perm.end(), options.perm()->begin(),
                    options.perm()->end());
  }
};
|
||||
|
||||
class Split : public CustomOperator<TensorFlowSplitOperator> {
|
||||
public:
|
||||
using CustomOperator::CustomOperator;
|
||||
@ -670,6 +689,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
|
||||
OperatorType::kSpaceToDepth));
|
||||
ops.emplace_back(
|
||||
new Svdf(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
|
||||
ops.emplace_back(new Transpose(::tflite::BuiltinOperator_TRANSPOSE,
|
||||
OperatorType::kTranspose));
|
||||
|
||||
// Custom Operators.
|
||||
ops.emplace_back(new Cast("CAST", OperatorType::kCast));
|
||||
|
@ -369,6 +369,15 @@ TEST_F(OperatorTest, Svdf) {
|
||||
EXPECT_EQ(op.rank, output_toco_op->rank);
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, Transpose) {
|
||||
TransposeOperator op;
|
||||
op.perm = {0, 1, 2, 3};
|
||||
|
||||
auto output_toco_op = SerializeAndDeserialize(
|
||||
GetOperator("TRANSPOSE", OperatorType::kTranspose), op);
|
||||
EXPECT_EQ(op.perm, output_toco_op->perm);
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, TensorFlowUnsupported) {
|
||||
TensorFlowUnsupportedOperator op;
|
||||
op.tensorflow_op = "MyCustomUnsupportedOp";
|
||||
|
@ -85,6 +85,7 @@ void MakeGeneralGraphTransformationsSet(
  transformations->Add(new ResolveStridedSliceAttributes);
  transformations->Add(new ResolveSliceAttributes);
  transformations->Add(new ResolveMeanAttributes);
  transformations->Add(new ResolveTransposeAttributes);
  transformations->Add(new ResolveConstantTensorFlowShape);
  transformations->Add(new MakeInitialDequantizeOperator);
}

tensorflow/contrib/py2tf/convert/BUILD (new file, 100 lines)
@ -0,0 +1,100 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "convert",
|
||||
srcs = [
|
||||
"call_trees.py",
|
||||
"control_flow.py",
|
||||
"gradients_function.py",
|
||||
"logical_expressions.py",
|
||||
"print_functions.py",
|
||||
"side_effect_guards.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
deps = [
|
||||
"@gast_archive//:gast",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "call_trees_test",
|
||||
srcs = ["call_trees_test.py"],
|
||||
deps = [
|
||||
":convert",
|
||||
"//tensorflow/contrib/py2tf/pyct",
|
||||
"//tensorflow/contrib/py2tf/pyct/static_analysis",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "control_flow_test",
|
||||
srcs = ["control_flow_test.py"],
|
||||
deps = [
|
||||
":convert",
|
||||
"//tensorflow/contrib/py2tf/pyct",
|
||||
"//tensorflow/contrib/py2tf/pyct/static_analysis",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "gradients_function_test",
|
||||
srcs = ["gradients_function_test.py"],
|
||||
deps = [
|
||||
":convert",
|
||||
"//tensorflow/contrib/eager/python:tfe",
|
||||
"//tensorflow/contrib/py2tf/pyct",
|
||||
"//tensorflow/contrib/py2tf/pyct/static_analysis",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "logical_expressions_test",
|
||||
srcs = ["logical_expressions_test.py"],
|
||||
deps = [
|
||||
":convert",
|
||||
"//tensorflow/contrib/py2tf/pyct",
|
||||
"//tensorflow/contrib/py2tf/pyct/static_analysis",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "print_functions_test",
|
||||
srcs = ["print_functions_test.py"],
|
||||
deps = [
|
||||
":convert",
|
||||
"//tensorflow/contrib/py2tf/pyct",
|
||||
"//tensorflow/contrib/py2tf/pyct/static_analysis",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"@gast_archive//:gast",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "side_effect_guards_test",
|
||||
srcs = ["side_effect_guards_test.py"],
|
||||
deps = [
|
||||
":convert",
|
||||
"//tensorflow/contrib/py2tf/pyct",
|
||||
"//tensorflow/contrib/py2tf/pyct/static_analysis",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
tensorflow/contrib/py2tf/convert/__init__.py (new file, 21 lines)
@ -0,0 +1,21 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Code converters used by Py2TF."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# TODO(mdan): Define a base transformer class that can recognize skip_processing
|
tensorflow/contrib/py2tf/convert/call_trees.py (new file, 158 lines)
@ -0,0 +1,158 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Handles function calls, by generating compiled function names and calls."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import types
|
||||
|
||||
import gast
|
||||
|
||||
from tensorflow.contrib.py2tf.pyct import anno
|
||||
from tensorflow.contrib.py2tf.pyct import templates
|
||||
|
||||
|
||||
class FunctionNamer(object):
|
||||
"""Describes the interface for CallTreeTransformer's namer."""
|
||||
|
||||
def compiled_function_name(self, original_name, live_object=None):
|
||||
"""Generate the name corresponding to the compiled version of a function.
|
||||
|
||||
Args:
|
||||
original_name: String
|
||||
live_object: Callable, the actual target function, if known.
|
||||
Returns:
|
||||
String.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class CallTreeTransformer(gast.NodeTransformer):
|
||||
"""Transforms the call tree by renaming transformed symbols."""
|
||||
|
||||
def __init__(self, namer, uncompiled_modules):
|
||||
self.namer = namer
|
||||
self.uncompiled_modules = uncompiled_modules
|
||||
|
||||
# pylint:disable=invalid-name
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
self.generic_visit(node)
|
||||
node.name = self.namer.compiled_function_name(node.name)
|
||||
return node
|
||||
|
||||
def _rename_compilable_function(self, node):
|
||||
assert anno.hasanno(node.func, 'live_val')
|
||||
assert anno.hasanno(node.func, 'fqn')
|
||||
target_obj = anno.getanno(node.func, 'live_val')
|
||||
target_fqn = anno.getanno(node.func, 'fqn')
|
||||
|
||||
fqn = ''
|
||||
for s in target_fqn:
|
||||
if fqn:
|
||||
fqn += '.'
|
||||
fqn += s
|
||||
if fqn in self.uncompiled_modules:
|
||||
return node
|
||||
|
||||
new_name = self.namer.compiled_function_name(fqn, live_object=target_obj)
|
||||
node.func = gast.Name(id=new_name, ctx=gast.Load(), annotation=None)
|
||||
return node
|
||||
|
||||
def _rename_member_function_of_known_type(self, node):
|
||||
target_fqn = anno.getanno(node.func, 'type_fqn')
|
||||
|
||||
fqn = ''
|
||||
for s in target_fqn:
|
||||
if fqn:
|
||||
fqn += '.'
|
||||
fqn += s
|
||||
if fqn in self.uncompiled_modules:
|
||||
return node
|
||||
|
||||
raise NotImplementedError('Member function call (of known type).')
|
||||
|
||||
def _wrap_to_py_func_no_return(self, node):
|
||||
args_scope = anno.getanno(node, 'args_scope')
|
||||
# TODO(mdan): Properly handle varargs, kwargs, etc.
|
||||
args = tuple(gast.Name(n, gast.Load(), None) for n in args_scope.used)
|
||||
|
||||
# pylint:disable=undefined-variable,unused-argument,function-redefined
|
||||
|
||||
def template(call, wrapper, args):
|
||||
|
||||
def wrapper(args):
|
||||
call(args)
|
||||
return 1
|
||||
|
||||
tf.py_func(wrapper, [args], [tf.int64])
|
||||
|
||||
# pylint:enable=undefined-variable,unused-argument,function-redefined
|
||||
|
||||
wrapper_name = self.namer.compiled_function_name(node.func.id)
|
||||
wrapper_def, call_expr = templates.replace(
|
||||
template,
|
||||
call=node.func,
|
||||
wrapper=gast.Name(wrapper_name, gast.Load(), None),
|
||||
args=args)
|
||||
anno.setanno(call_expr.value, 'args_scope', args_scope)
|
||||
anno.setanno(wrapper_def, 'skip_processing', True)
|
||||
|
||||
return (wrapper_def, call_expr)
|
||||
|
||||
def visit_Expr(self, node):
|
||||
if isinstance(node.value, gast.Call):
|
||||
node = self._wrap_to_py_func_no_return(node.value)
|
||||
else:
|
||||
self.generic_visit(node)
|
||||
return node
|
||||
|
||||
def visit_Call(self, node):
|
||||
self.generic_visit(node)
|
||||
if anno.hasanno(node.func, 'live_val'):
|
||||
target_obj = anno.getanno(node.func, 'live_val')
|
||||
if isinstance(target_obj, types.BuiltinFunctionType):
|
||||
raise NotImplementedError('py_func with return values')
|
||||
else:
|
||||
node = self._rename_compilable_function(node)
|
||||
elif anno.hasanno(node.func, 'type_fqn'):
|
||||
node = self._rename_member_function_of_known_type(node)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'Member function call (of unknown type): %s.' % node.func.id)
|
||||
return node
|
||||
|
||||
# pylint:enable=invalid-name
|
||||
|
||||
|
||||
def transform(node, namer, uncompiled_modules):
|
||||
"""Transform function call to the compiled counterparts.
|
||||
|
||||
Args:
|
||||
node: AST to transform.
|
||||
namer: FunctionNamer-like.
|
||||
uncompiled_modules: set of string tuples, each tuple represents the fully
|
||||
qualified name of a package containing functions that will not be
|
||||
compiled.
|
||||
Returns:
|
||||
A tuple (node, new_names):
|
||||
node: The transformed AST
|
||||
new_names: set(string), containing any newly-generated names
|
||||
"""
|
||||
transformer = CallTreeTransformer(namer, uncompiled_modules)
|
||||
node = transformer.visit(node)
|
||||
return node
|
tensorflow/contrib/py2tf/convert/call_trees_test.py (new file, 91 lines)
@ -0,0 +1,91 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for call_trees module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.py2tf.convert import call_trees
|
||||
from tensorflow.contrib.py2tf.pyct import compiler
|
||||
from tensorflow.contrib.py2tf.pyct import parser
|
||||
from tensorflow.contrib.py2tf.pyct.static_analysis import access
|
||||
from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
|
||||
from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class TestNamer(call_trees.FunctionNamer):
|
||||
|
||||
def compiled_function_name(self, original_name, live_object=None):
|
||||
return 'renamed_%s' % original_name
|
||||
|
||||
|
||||
class CallTreesTest(test.TestCase):
|
||||
|
||||
def _parse_and_analyze(self, test_fn, namespace):
|
||||
node = parser.parse_object(test_fn)
|
||||
node = access.resolve(node)
|
||||
node = live_values.resolve(node, namespace, {})
|
||||
node = type_info.resolve(node, None)
|
||||
return node
|
||||
|
||||
def test_basic(self):
|
||||
|
||||
def test_fn_1(_):
|
||||
raise ValueError('This should not be called in the compiled version.')
|
||||
|
||||
def renamed_test_fn_1(a):
|
||||
return a + 1
|
||||
|
||||
def test_fn_2(a):
|
||||
return test_fn_1(a) + 1
|
||||
|
||||
node = self._parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1})
|
||||
node = call_trees.transform(node, TestNamer(), set())
|
||||
result = compiler.ast_to_object(node)
|
||||
# Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1 manually.
|
||||
setattr(result, 'renamed_test_fn_1', renamed_test_fn_1)
|
||||
|
||||
self.assertEquals(3, result.renamed_test_fn_2(1))
|
||||
|
||||
def test_uncompiled_modules(self):
|
||||
|
||||
def test_fn(a):
|
||||
a = math_ops.multiply(a, constant_op.constant(2))
|
||||
a = math_ops.add(a, constant_op.constant(1))
|
||||
return a
|
||||
|
||||
node = self._parse_and_analyze(test_fn, {
|
||||
'math_ops': math_ops,
|
||||
'constant_op': constant_op
|
||||
})
|
||||
node = call_trees.transform(node, TestNamer(),
|
||||
set((math_ops.__name__, constant_op.__name__)))
|
||||
result = compiler.ast_to_object(node)
|
||||
setattr(result, 'math_ops', math_ops)
|
||||
setattr(result, 'constant_op', constant_op)
|
||||
|
||||
with self.test_session() as sess:
|
||||
result_tensor = result.renamed_test_fn(constant_op.constant(1))
|
||||
result_val = sess.run(result_tensor)
|
||||
|
||||
self.assertEquals(3, result_val)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
tensorflow/contrib/py2tf/convert/control_flow.py (new file, 118 lines)
@ -0,0 +1,118 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Identity converter. Useful for testing and diagnostic."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gast
|
||||
|
||||
from tensorflow.contrib.py2tf.pyct import anno
|
||||
from tensorflow.contrib.py2tf.pyct import templates
|
||||
|
||||
|
||||
class SymbolNamer(object):
|
||||
"""Describes the interface for ControlFlowTransformer's namer."""
|
||||
|
||||
def new_symbol(self, name_root, reserved_locals):
|
||||
"""Generate a new unique symbol.
|
||||
|
||||
Args:
|
||||
name_root: String, used as stem in the new name.
|
||||
reserved_locals: Set(string), additional local symbols that are reserved
|
||||
and which should not be used.
|
||||
Returns:
|
||||
String.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class ControlFlowTransformer(gast.NodeTransformer):
|
||||
"""Transforms control flow structures like loops an conditionals."""
|
||||
|
||||
def __init__(self, namer):
|
||||
self.namer = namer
|
||||
|
||||
# pylint:disable=invalid-name
|
||||
|
||||
def _tuple_or_item(self, elts):
|
||||
elts = tuple(elts)
|
||||
if len(elts) == 1:
|
||||
return elts[0]
|
||||
return elts
|
||||
|
||||
def _ast_tuple_or_item(self, elts, ctx):
|
||||
elts = list(elts)
|
||||
if len(elts) == 1:
|
||||
return elts[0]
|
||||
return gast.Tuple(elts, ctx)
|
||||
|
||||
def visit_If(self, node):
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_While(self, node):
|
||||
self.generic_visit(node)
|
||||
# Scrape out the data flow analysis
|
||||
body_scope = anno.getanno(node, 'body_scope')
|
||||
parent_scope_values = anno.getanno(node, 'parent_scope_values')
|
||||
body_closure = tuple(body_scope.modified - body_scope.created)
|
||||
|
||||
def template(
|
||||
state_args, # pylint:disable=unused-argument
|
||||
state_locals,
|
||||
state_results, # pylint:disable=unused-argument
|
||||
test_name,
|
||||
test, # pylint:disable=unused-argument
|
||||
body_name,
|
||||
body,
|
||||
state_init):
|
||||
|
||||
def test_name(state_args): # pylint:disable=function-redefined,unused-argument
|
||||
return test
|
||||
|
||||
def body_name(state_args): # pylint:disable=function-redefined,unused-argument
|
||||
body # pylint:disable=pointless-statement
|
||||
return state_locals
|
||||
|
||||
state_results = tf.while_loop(test_name, body_name, [state_init]) # pylint:disable=undefined-variable
|
||||
|
||||
test_name = self.namer.new_symbol('loop_test', body_scope.used)
|
||||
body_name = self.namer.new_symbol('loop_body', body_scope.used)
|
||||
node = templates.replace(
|
||||
template,
|
||||
state_args=self._tuple_or_item(
|
||||
gast.Name(n, gast.Param(), None) for n in body_closure),
|
||||
state_locals=self._ast_tuple_or_item(
|
||||
(gast.Name(n, gast.Load(), None) for n in body_closure),
|
||||
gast.Load()),
|
||||
state_results=self._ast_tuple_or_item(
|
||||
(gast.Name(n, gast.Store(), None) for n in body_closure),
|
||||
gast.Store()),
|
||||
test_name=gast.Name(test_name, gast.Load(), None),
|
||||
test=node.test,
|
||||
body_name=gast.Name(body_name, gast.Load(), None),
|
||||
body=node.body,
|
||||
state_init=[parent_scope_values.getval(n) for n in body_closure])
|
||||
|
||||
return node
|
||||
|
||||
# pylint:enable=invalid-name
|
||||
|
||||
|
||||
def transform(node, namer):
|
||||
transformer = ControlFlowTransformer(namer)
|
||||
node = transformer.visit(node)
|
||||
return node
|
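For intuition, a hand-written sketch of the rewriting that visit_While aims for (illustrative only, not the transformer's literal output; assumes TF 1.x and that `n` is a tensor or Python int):

import tensorflow as tf

def test_fn(n):
  # Original Python:
  #   i = 0; s = 0
  #   while i < n: s += i; i += 1
  def loop_test(s, i):
    return i < n
  def loop_body(s, i):
    return s + i, i + 1
  # The loop state (s, i) becomes the tf.while_loop loop variables.
  s, i = tf.while_loop(loop_test, loop_body, [0, 0])
  return s, i, n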
tensorflow/contrib/py2tf/convert/control_flow_test.py (new file, 83 lines)
@ -0,0 +1,83 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for control_flow module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.py2tf.convert import control_flow
|
||||
from tensorflow.contrib.py2tf.pyct import compiler
|
||||
from tensorflow.contrib.py2tf.pyct import parser
|
||||
from tensorflow.contrib.py2tf.pyct.static_analysis import access
|
||||
from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
|
||||
from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class TestNamer(control_flow.SymbolNamer):
|
||||
|
||||
def new_symbol(self, name_root, _):
|
||||
return name_root
|
||||
|
||||
|
||||
class ControlFlowTest(test.TestCase):
|
||||
|
||||
def _parse_and_analyze(self, test_fn, namespace):
|
||||
node = parser.parse_object(test_fn)
|
||||
node = access.resolve(node)
|
||||
node = live_values.resolve(node, namespace, {})
|
||||
node = type_info.resolve(node, None)
|
||||
return node
|
||||
|
||||
def test_simple_while(self):
|
||||
|
||||
def test_fn(n):
|
||||
i = 0
|
||||
s = 0
|
||||
while i < n:
|
||||
s += i
|
||||
i += 1
|
||||
return s, i, n
|
||||
|
||||
node = self._parse_and_analyze(test_fn, {})
|
||||
node = control_flow.transform(node, TestNamer())
|
||||
result = compiler.ast_to_object(node)
|
||||
setattr(result, 'tf', control_flow_ops)
|
||||
|
||||
with self.test_session() as sess:
|
||||
self.assertEqual((10, 5, 5),
|
||||
sess.run(result.test_fn(constant_op.constant(5))))
|
||||
|
||||
def test_while_single_var(self):
|
||||
|
||||
def test_fn(n):
|
||||
while n > 0:
|
||||
n -= 1
|
||||
return n
|
||||
|
||||
node = self._parse_and_analyze(test_fn, {})
|
||||
node = control_flow.transform(node, TestNamer())
|
||||
result = compiler.ast_to_object(node)
|
||||
setattr(result, 'tf', control_flow_ops)
|
||||
|
||||
with self.test_session() as sess:
|
||||
self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5))))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
tensorflow/contrib/py2tf/convert/gradients_function.py (new file, 80 lines)
@ -0,0 +1,80 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Allows converting Eager-style gradients to graph versions."""
|
||||
# TODO(mdan): This is not needed. Remove once the static analysis works.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gast
|
||||
|
||||
from tensorflow.contrib.py2tf.pyct import templates
|
||||
|
||||
|
||||
class GradientsFunctionTransformer(gast.NodeTransformer):
|
||||
"""Hack: transforms eager-style gradients to TF compatible calls.
|
||||
|
||||
Requires an expression of exactly this form:
|
||||
... = tfe.value_and_gradients_function(...)(...)
|
||||
"""
|
||||
|
||||
# pylint:disable=invalid-name
|
||||
|
||||
def visit_Assign(self, node):
|
||||
self.generic_visit(node)
|
||||
|
||||
val = node.value
|
||||
if isinstance(val, gast.Call):
|
||||
if isinstance(val.func, gast.Call):
|
||||
if isinstance(val.func.func, gast.Attribute):
|
||||
if isinstance(val.func.func.value, gast.Name):
|
||||
if (val.func.func.value.id == 'tfe' and
|
||||
val.func.func.attr == 'value_and_gradients_function'):
|
||||
|
||||
# pylint:disable=unused-argument,undefined-variable
|
||||
|
||||
def template(loss_var, loss_fn, args, d_vars, wrt_vars):
|
||||
loss_var = loss_fn(args)
|
||||
d_vars = tf.gradients(loss_var, [wrt_vars])
|
||||
|
||||
# pylint:enable=unused-argument,undefined-variable
|
||||
|
||||
# How to get these values? Print out the node.
|
||||
loss_var = gast.Name(node.targets[0].elts[0].id, gast.Store(),
|
||||
None)
|
||||
loss_fn = gast.Name(val.func.args[0].id, gast.Load(), None)
|
||||
args = tuple(
|
||||
gast.Name(a.id, gast.Param(), None) for a in val.args)
|
||||
d_vars = node.targets[0].elts[1]
|
||||
wrt_vars = [val.args[e.n] for e in val.func.args[1].elts]
|
||||
|
||||
node = templates.replace(
|
||||
template,
|
||||
loss_var=loss_var,
|
||||
loss_fn=loss_fn,
|
||||
args=args,
|
||||
d_vars=d_vars,
|
||||
wrt_vars=wrt_vars)
|
||||
|
||||
return node
|
||||
|
||||
# pylint:enable=invalid-name
|
||||
|
||||
|
||||
def transform(node):
|
||||
transformer = GradientsFunctionTransformer()
|
||||
node = transformer.visit(node)
|
||||
return node
|
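For intuition, a hand-written sketch of the rewrite this transformer performs on the single expression form it supports (illustrative only; assumes TF 1.x and graph mode):

import tensorflow as tf

def loss(x, w):
  return x * w

def test_fn(x, w):
  # Original eager-style code:
  #   l, (dw,) = tfe.value_and_gradients_function(loss, [1])(x, w)
  # Graph-mode equivalent produced by the template above:
  l = loss(x, w)
  (dw,) = tf.gradients(l, [w])
  return l, dw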
tensorflow/contrib/py2tf/convert/gradients_function_test.py (new file, 55 lines)
@ -0,0 +1,55 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for gradients_function module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.eager.python import tfe
|
||||
from tensorflow.contrib.py2tf.convert import gradients_function
|
||||
from tensorflow.contrib.py2tf.pyct import compiler
|
||||
from tensorflow.contrib.py2tf.pyct import parser
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class GradientsFunctionTest(test.TestCase):
|
||||
|
||||
def test_transform(self):
|
||||
|
||||
def loss(x, w):
|
||||
return x * w
|
||||
|
||||
def test_fn(x, w):
|
||||
l, (dw,) = tfe.value_and_gradients_function(loss, [1])(x, w) # pylint:disable=undefined-variable
|
||||
return l, dw
|
||||
|
||||
node = parser.parse_object(test_fn)
|
||||
node = gradients_function.transform(node)
|
||||
result = compiler.ast_to_object(node)
|
||||
setattr(result, 'tf', gradients_impl)
|
||||
setattr(result, 'loss', loss)
|
||||
|
||||
with self.test_session() as sess:
|
||||
self.assertEqual(
|
||||
(12, 3),
|
||||
sess.run(
|
||||
result.test_fn(constant_op.constant(3), constant_op.constant(4))))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
tensorflow/contrib/py2tf/convert/logical_expressions.py (new file, 55 lines)
@ -0,0 +1,55 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Converter for logical expressions.
|
||||
|
||||
e.g. `a and b -> tf.logical_and(a, b)`. This is not done automatically in TF.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gast
|
||||
|
||||
from tensorflow.contrib.py2tf.pyct import parser
|
||||
|
||||
|
||||
class LogicalExpressionTransformer(gast.NodeTransformer):
|
||||
"""Converts logical expressions to corresponding TF calls."""
|
||||
|
||||
def __init__(self):
|
||||
# TODO(mdan): Look into replacing with bitwise operators instead.
|
||||
self.op_mapping = {
|
||||
gast.And: 'tf.logical_and',
|
||||
gast.Or: 'tf.logical_or',
|
||||
}
|
||||
|
||||
def visit_UnaryOp(self, node):
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_BoolOp(self, node):
|
||||
# TODO(mdan): A normalizer may be useful here. Use ANF?
|
||||
tf_function = parser.parse_str(self.op_mapping[type(node.op)]).body[0].value
|
||||
left = node.values[0]
|
||||
for i in range(1, len(node.values)):
|
||||
left = gast.Call(
|
||||
func=tf_function, args=[left, node.values[i]], keywords=[])
|
||||
return left
|
||||
|
||||
|
||||
def transform(node):
|
||||
transformer = LogicalExpressionTransformer()
|
||||
node = transformer.visit(node)
|
||||
return node
|
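As a quick illustration of the docstring's example (hand-written; the transformer emits an equivalent AST rather than source text; assumes TF 1.x with boolean tensor inputs):

import tensorflow as tf

def converted(a, b, c):
  # Original Python: return (a or b) and (a or b or c)
  # visit_BoolOp folds the operands left to right into nested calls.
  return tf.logical_and(
      tf.logical_or(a, b),
      tf.logical_or(tf.logical_or(a, b), c))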
tensorflow/contrib/py2tf/convert/logical_expressions_test.py (new file, 45 lines)
@ -0,0 +1,45 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for logical_expressions module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.py2tf.convert import logical_expressions
|
||||
from tensorflow.contrib.py2tf.pyct import compiler
|
||||
from tensorflow.contrib.py2tf.pyct import parser
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class LogicalExpressionsTest(test.TestCase):
|
||||
|
||||
def test_transform(self):
|
||||
|
||||
def test_fn(a, b, c):
|
||||
return (a or b) and (a or b or c)
|
||||
|
||||
node = parser.parse_object(test_fn)
|
||||
node = logical_expressions.transform(node)
|
||||
result = compiler.ast_to_object(node)
|
||||
setattr(result, 'tf', math_ops)
|
||||
|
||||
with self.test_session() as sess:
|
||||
self.assertTrue(sess.run(result.test_fn(True, False, True)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
tensorflow/contrib/py2tf/convert/print_functions.py (new file, 51 lines)
@ -0,0 +1,51 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Compatibility support. Converts Print nodes to function calls."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gast
|
||||
|
||||
from tensorflow.contrib.py2tf.pyct import anno
|
||||
|
||||
|
||||
class PrintFunctionTransformer(gast.NodeTransformer):
|
||||
"""Transforms Print nodes to Call so they can be handled as functions."""
|
||||
|
||||
# pylint:disable=invalid-name
|
||||
|
||||
def visit_Print(self, node):
|
||||
self.generic_visit(node)
|
||||
for n in node.values:
|
||||
n.ctx = gast.Param()
|
||||
call_node = gast.Call(
|
||||
func=gast.Name('print', gast.Load(), None),
|
||||
args=node.values,
|
||||
keywords=[])
|
||||
anno.setanno(call_node.func, 'live_val', print)
|
||||
anno.setanno(call_node.func, 'fqn', 'print')
|
||||
anno.setanno(call_node, 'args_scope', anno.getanno(node, 'args_scope'))
|
||||
node = gast.Expr(call_node)
|
||||
return node
|
||||
|
||||
# pylint:enable=invalid-name
|
||||
|
||||
|
||||
def transform(node):
|
||||
transformer = PrintFunctionTransformer()
|
||||
node = transformer.visit(node)
|
||||
return node
|
tensorflow/contrib/py2tf/convert/print_functions_test.py (new file, 55 lines)
@ -0,0 +1,55 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for print_functions module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gast
|
||||
|
||||
from tensorflow.contrib.py2tf.convert import print_functions
|
||||
from tensorflow.contrib.py2tf.pyct import compiler
|
||||
from tensorflow.contrib.py2tf.pyct import parser
|
||||
from tensorflow.contrib.py2tf.pyct.static_analysis import access
|
||||
from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
|
||||
from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class PrintFunctionsTest(test.TestCase):
|
||||
|
||||
def _parse_and_analyze(self, test_fn, namespace):
|
||||
node = parser.parse_object(test_fn)
|
||||
node = access.resolve(node)
|
||||
node = live_values.resolve(node, namespace, {})
|
||||
node = type_info.resolve(node, None)
|
||||
return node
|
||||
|
||||
def test_transform(self):
|
||||
|
||||
def test_fn(a):
|
||||
print(a)
|
||||
|
||||
node = self._parse_and_analyze(test_fn, {'print': print})
|
||||
node = print_functions.transform(node)
|
||||
result = compiler.ast_to_object(node)
|
||||
|
||||
result.test_fn('a')
|
||||
self.assertTrue(isinstance(node.body[0].body[0].value, gast.Call))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
tensorflow/contrib/py2tf/convert/side_effect_guards.py (new file, 155 lines)
@ -0,0 +1,155 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Adds guards against function calls with side effects.
|
||||
|
||||
Only standalone calls are guarded.
|
||||
|
||||
WARNING: This mechanism is incomplete. Particularly, it only guards the
|
||||
arguments passed to functions, and does not account for indirectly modified
|
||||
state.
|
||||
|
||||
Example:
|
||||
y = tf.layers.dense(x) # Creates TF variable 'foo'
|
||||
loss = loss(y)
|
||||
opt.minimize(loss) # indirectly affects 'foo'
|
||||
z = tf.get_variable('foo') # Indirectly affects `loss` and 'foo'
|
||||
# Here, `loss` can be guarded. But `z` cannot.
|
||||
|
||||
# TODO(mdan): We should probably define a safe mode where we guard everything.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gast
|
||||
|
||||
from tensorflow.contrib.py2tf.pyct import anno
|
||||
from tensorflow.contrib.py2tf.pyct import templates
|
||||
|
||||
|
||||
class SymbolNamer(object):
|
||||
"""Describes the interface for SideEffectGuardTransformer's namer."""
|
||||
|
||||
def new_symbol(self, name_root, reserved_locals):
|
||||
"""Generate a new unique function_name.
|
||||
|
||||
Args:
|
||||
name_root: String, used as stem in the new name.
|
||||
reserved_locals: Set(string), additional local symbols that are reserved.
|
||||
Returns:
|
||||
String.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class SideEffectGuardTransformer(gast.NodeTransformer):
|
||||
"""Adds control dependencies to functions with side effects."""
|
||||
|
||||
def __init__(self, namer):
|
||||
self.namer = namer
|
||||
self.indent_next = False
|
||||
self.next_indent_owner = None
|
||||
|
||||
# pylint:disable=invalid-name
|
||||
|
||||
def _visit_and_reindent(self, nodes):
|
||||
new_nodes = []
|
||||
current_dest = new_nodes
|
||||
for n in nodes:
|
||||
n = self.visit(n)
|
||||
if isinstance(n, (list, tuple)):
|
||||
current_dest.extend(n)
|
||||
else:
|
||||
current_dest.append(n)
|
||||
if self.indent_next:
|
||||
assert self.next_indent_owner is not None
|
||||
current_dest.append(self.next_indent_owner)
|
||||
current_dest = self.next_indent_owner.body
|
||||
self.next_indent_owner = None
|
||||
self.indent_next = False
|
||||
if not current_dest:
|
||||
# TODO(mdan): There may still be something that could be done.
|
||||
raise ValueError('Unable to insert statement into the computation flow: '
|
||||
'it is not followed by any computation that we can '
|
||||
'condition on the statement.')
|
||||
return new_nodes
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
if anno.hasanno(node, 'skip_processing'):
|
||||
return node
|
||||
node.body = self._visit_and_reindent(node.body)
|
||||
return node
|
||||
|
||||
def _gate_symbols(self, guard_statement, guarded_args):
|
||||
|
||||
def template(dst_args, src_args): # pylint:disable=unused-argument
|
||||
(dst_args,) = (tf.identity(a) for a in (src_args,)) # pylint:disable=undefined-variable
|
||||
|
||||
guards = templates.replace(
|
||||
template,
|
||||
dst_args=tuple(gast.Name(a, gast.Store(), None) for a in guarded_args),
|
||||
src_args=tuple(gast.Name(a, gast.Load(), None) for a in guarded_args))
|
||||
guard_statement.body.extend(guards)
|
||||
return guard_statement
|
||||
|
||||
def visit_Expr(self, node):
|
||||
self.generic_visit(node)
|
||||
if isinstance(node.value, gast.Call):
|
||||
# Patterns of single function calls, like:
|
||||
# opt.minimize(loss)
|
||||
# or:
|
||||
# tf.py_func(...)
|
||||
|
||||
args_scope = anno.getanno(node.value, 'args_scope')
|
||||
temp_name = self.namer.new_symbol('temp', args_scope.parent.used)
|
||||
# TODO(mdan): Unsafe reference modification!
|
||||
args_scope.mark_write(temp_name)
|
||||
|
||||
def template(call, temp_result):
|
||||
temp_result = call
|
||||
if not isinstance(temp_result, (list, tuple)):
|
||||
temp_result = (temp_result,)
|
||||
with tf.control_dependencies(temp_result): # pylint:disable=undefined-variable
|
||||
# TODO(mdan): Also insert ops to re-fetch if variables are involved.
|
||||
pass # Will be removed below.
|
||||
|
||||
guard_var_assign, arg_checker, control_deps_guard = templates.replace(
|
||||
template,
|
||||
call=node.value,
|
||||
temp_result=gast.Name(temp_name, gast.Store(), None))
|
||||
control_deps_guard.body = []
|
||||
|
||||
# First, attempt to gate future evaluation of args. If that's not
|
||||
# possible, gate all remaining statements (and that may fail too, see
|
||||
# _visit_and_reindent).
|
||||
guarded_args = tuple(
|
||||
n for n in args_scope.used if n in args_scope.parent.modified)
|
||||
if guarded_args:
|
||||
node = (guard_var_assign, arg_checker,
|
||||
self._gate_symbols(control_deps_guard, guarded_args))
|
||||
else:
|
||||
node = (guard_var_assign, arg_checker)
|
||||
# The mechanism will insert the guard statement later.
|
||||
self.indent_next = True
|
||||
self.next_indent_owner = control_deps_guard
|
||||
return node
|
||||
|
||||
# pylint:enable=invalid-name
|
||||
|
||||
|
||||
def transform(node, namer):
|
||||
transformer = SideEffectGuardTransformer(namer)
|
||||
return transformer.visit(node)
|
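For intuition, a hand-written sketch of the guard pattern visit_Expr builds around a standalone stateful call (illustrative only; assumes TF 1.x and that `v` is a tf.Variable):

import tensorflow as tf

def guarded_increment(v):
  temp_result = tf.assign(v, v + 1)             # the original standalone call
  if not isinstance(temp_result, (list, tuple)):
    temp_result = (temp_result,)
  with tf.control_dependencies(temp_result):    # later reads are gated on the assign
    v = tf.identity(v)
  return v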
tensorflow/contrib/py2tf/convert/side_effect_guards_test.py (new file, 71 lines)
@ -0,0 +1,71 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for side_effect_guards module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.py2tf.convert import side_effect_guards
|
||||
from tensorflow.contrib.py2tf.pyct import compiler
|
||||
from tensorflow.contrib.py2tf.pyct import parser
|
||||
from tensorflow.contrib.py2tf.pyct.static_analysis import access
|
||||
from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
|
||||
from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class TestNamer(side_effect_guards.SymbolNamer):
|
||||
|
||||
def new_symbol(self, name_root, _):
|
||||
return name_root
|
||||
|
||||
|
||||
class SideEffectGuardsTest(test.TestCase):
|
||||
|
||||
def _parse_and_analyze(self, test_fn, namespace):
|
||||
node = parser.parse_object(test_fn)
|
||||
node = access.resolve(node)
|
||||
node = live_values.resolve(node, namespace, {})
|
||||
node = type_info.resolve(node, None)
|
||||
return node
|
||||
|
||||
def test_transform(self):
|
||||
|
||||
def test_fn(a):
|
||||
state_ops.assign(a, a + 1)
|
||||
return a
|
||||
|
||||
node = self._parse_and_analyze(test_fn, {'state_ops': state_ops})
|
||||
node = side_effect_guards.transform(node, TestNamer())
|
||||
result = compiler.ast_to_object(node)
|
||||
setattr(result, 'state_ops', state_ops)
|
||||
|
||||
# TODO(mdan): Configure the namespaces instead of doing these hacks.
|
||||
ops.identity = array_ops.identity
|
||||
setattr(result, 'tf', ops)
|
||||
|
||||
with self.test_session() as sess:
|
||||
v = variables.Variable(2)
|
||||
sess.run(v.initializer)
|
||||
self.assertEqual(3, sess.run(result.test_fn(v)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -62,7 +62,9 @@ def _get_computed_nodes(g, output, seen):
|
||||
for each in node_def.input:
|
||||
# Parses name of input node.
|
||||
if each.startswith('^'):
|
||||
each = each[1:]
|
||||
# The character '^' denotes a control dependency, so this input node can
|
||||
# be safely ignored.
|
||||
continue
|
||||
each = each.split(':')[0]
|
||||
# Recursively computes ordering.
|
||||
new_v = _get_computed_nodes(g, each, seen)
|
||||
|
@ -33,9 +33,10 @@ import numpy as np
|
||||
# White-listed layer operations, which do not affect the receptive field
|
||||
# computation.
|
||||
_UNCHANGED_RF_LAYER_OPS = [
|
||||
'Add', 'BiasAdd', 'Ceil', 'ConcatV2', 'Const', 'Floor', 'Identity', 'Log',
|
||||
'Mul', 'Pow', 'RealDiv', 'Relu', 'Round', 'Rsqrt', 'Softplus', 'Sub',
|
||||
'VariableV2']
|
||||
"Add", "BiasAdd", "Cast", "Ceil", "ConcatV2", "Const", "Floor", "Identity",
|
||||
"Log", "Mul", "Pow", "RealDiv", "Relu", "Relu6", "Round", "Rsqrt",
|
||||
"Softplus", "Sub", "VariableV2"
|
||||
]
|
||||
|
||||
# Different ways in which padding modes may be spelled.
|
||||
_VALID_PADDING = ["VALID", b"VALID"]
|
||||
@ -240,8 +241,8 @@ def _get_layer_params(node, name_to_order_node):
|
||||
padding_x = 0
|
||||
padding_y = 0
|
||||
else:
|
||||
raise ValueError("Unknown layer for operation '%s': %s" %
|
||||
(node.name, node.op))
|
||||
raise ValueError("Unknown layer for operation '%s': %s" % (node.name,
|
||||
node.op))
|
||||
return kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y
|
||||
|
||||
|
||||
@ -308,22 +309,21 @@ def _get_effective_padding_node_input(stride, padding,
|
||||
|
||||
|
||||
class ReceptiveField:
|
||||
"""
|
||||
Receptive field of a convolutional neural network.
|
||||
"""Receptive field of a convolutional neural network.
|
||||
|
||||
Args:
|
||||
size: Receptive field size.
|
||||
stride: Effective stride.
|
||||
padding: Effective padding.
|
||||
"""
|
||||
|
||||
def __init__(self, size, stride, padding):
|
||||
self.size = np.asarray(size)
|
||||
self.stride = np.asarray(stride)
|
||||
self.padding = np.asarray(padding)
|
||||
|
||||
def compute_input_center_coordinates(self, y, axis=None):
|
||||
"""
|
||||
Computes the center of the receptive field that generated a feature.
|
||||
"""Computes the center of the receptive field that generated a feature.
|
||||
|
||||
Args:
|
||||
y: An array of feature coordinates with shape `(..., d)`, where `d` is the
|
||||
@ -354,8 +354,7 @@ class ReceptiveField:
|
||||
(self.size[axis] - 1) / 2
|
||||
|
||||
def compute_feature_coordinates(self, x, axis=None):
|
||||
"""
|
||||
Computes the position of a feature given the center of a receptive field.
|
||||
"""Computes the position of a feature given the center of a receptive field.
|
||||
|
||||
Args:
|
||||
x: An array of input center coordinates with shape `(..., d)`, where `d`
|
||||
@ -388,7 +387,9 @@ class ReceptiveField:
|
||||
return iter(np.concatenate([self.size, self.stride, self.padding]))
|
||||
|
||||
|
||||
def compute_receptive_field_from_graph_def(graph_def, input_node, output_node,
|
||||
def compute_receptive_field_from_graph_def(graph_def,
|
||||
input_node,
|
||||
output_node,
|
||||
stop_propagation=None):
|
||||
"""Computes receptive field (RF) parameters from a Graph or GraphDef object.
|
||||
|
||||
@ -531,7 +532,13 @@ def compute_receptive_field_from_graph_def(graph_def, input_node, output_node,
|
||||
if any(inp_name.startswith(stop) for stop in stop_propagation):
|
||||
logging.vlog(3, "Skipping explicitly ignored node %s.", node.name)
|
||||
continue
|
||||
|
||||
logging.vlog(4, "inp_name = %s", inp_name)
|
||||
if inp_name.startswith("^"):
|
||||
# The character "^" denotes a control dependency, so this input node
|
||||
# can be safely ignored.
|
||||
continue
|
||||
|
||||
inp_node = name_to_order_node[inp_name].node
|
||||
logging.vlog(4, "inp_node = \n%s", inp_node)
|
||||
if inp_node.name in rf_sizes_x:
|
||||
@ -590,6 +597,6 @@ def compute_receptive_field_from_graph_def(graph_def, input_node, output_node,
|
||||
if input_node not in rf_sizes_x:
|
||||
raise ValueError("Input node was not found")
|
||||
return ReceptiveField(
|
||||
(rf_sizes_x[input_node], rf_sizes_y[input_node]),
|
||||
(effective_strides_x[input_node], effective_strides_y[input_node]),
|
||||
(effective_paddings_x[input_node], effective_paddings_y[input_node]))
|
||||
(rf_sizes_x[input_node], rf_sizes_y[input_node]),
|
||||
(effective_strides_x[input_node], effective_strides_y[input_node]),
|
||||
(effective_paddings_x[input_node], effective_paddings_y[input_node]))
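A minimal usage sketch of the updated API follows; the node names 'input_image' and 'output' are assumptions matching the tests below, and the new `stop_propagation` argument (a list of node-name prefixes whose inputs are not traversed) is optional.

```
import numpy as np

rf = receptive_field.compute_receptive_field_from_graph_def(
    graph_def, 'input_image', 'output')
print(rf.size, rf.stride, rf.padding)
# Map feature coordinates to the centers of their receptive fields in the
# input image, and back again.
centers = rf.compute_input_center_coordinates(np.array([[0, 0], [2, 3]]))
features = rf.compute_feature_coordinates(centers)
```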
|
||||
|
@ -23,6 +23,8 @@ from tensorflow.contrib.receptive_field.python.util import receptive_field
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.platform import test
|
||||
import numpy as np
|
||||
@ -176,6 +178,34 @@ def create_test_network_6():
|
||||
return g
|
||||
|
||||
|
||||
def create_test_network_7():
|
||||
"""Aligned network for test, with a control dependency.
|
||||
|
||||
The graph is similar to create_test_network_1(), except that it includes an
|
||||
assert operation on the left branch.
|
||||
|
||||
Returns:
|
||||
g: TensorFlow graph object (Graph proto).
|
||||
"""
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
# An 8x8 test image.
|
||||
x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
|
||||
# Left branch.
|
||||
l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID')
|
||||
l1_shape = array_ops.shape(l1)
|
||||
assert_op = control_flow_ops.Assert(
|
||||
gen_math_ops.equal(l1_shape[1], 2), [l1_shape], summarize=4)
|
||||
# Right branch.
|
||||
l2_pad = array_ops.pad(x, [[0, 0], [1, 0], [1, 0], [0, 0]])
|
||||
l2 = slim.conv2d(l2_pad, 1, [3, 3], stride=2, scope='L2', padding='VALID')
|
||||
l3 = slim.conv2d(l2, 1, [1, 1], stride=2, scope='L3', padding='VALID')
|
||||
# Addition.
|
||||
with ops.control_dependencies([assert_op]):
|
||||
nn.relu(l1 + l3, name='output')
|
||||
return g
|
||||
|
||||
|
||||
class RfUtilsTest(test.TestCase):
|
||||
|
||||
def testComputeRFFromGraphDefAligned(self):
|
||||
@ -269,7 +299,7 @@ class RfUtilsTest(test.TestCase):
|
||||
input_node = 'input_image'
|
||||
output_node = 'output'
|
||||
rf = receptive_field.compute_receptive_field_from_graph_def(
|
||||
graph_def, input_node, output_node)
|
||||
graph_def, input_node, output_node)
|
||||
|
||||
x = np.random.randint(0, 100, (50, 2))
|
||||
y = rf.compute_feature_coordinates(x)
|
||||
@ -277,5 +307,21 @@ class RfUtilsTest(test.TestCase):
|
||||
|
||||
self.assertAllEqual(x, x2)
|
||||
|
||||
def testComputeRFFromGraphDefAlignedWithControlDependencies(self):
|
||||
graph_def = create_test_network_7().as_graph_def()
|
||||
input_node = 'input_image'
|
||||
output_node = 'output'
|
||||
(receptive_field_x, receptive_field_y, effective_stride_x,
|
||||
effective_stride_y, effective_padding_x, effective_padding_y) = (
|
||||
receptive_field.compute_receptive_field_from_graph_def(
|
||||
graph_def, input_node, output_node))
|
||||
self.assertEqual(receptive_field_x, 3)
|
||||
self.assertEqual(receptive_field_y, 3)
|
||||
self.assertEqual(effective_stride_x, 4)
|
||||
self.assertEqual(effective_stride_y, 4)
|
||||
self.assertEqual(effective_padding_x, 1)
|
||||
self.assertEqual(effective_padding_y, 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -36,18 +36,17 @@ namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
const char kPrefix[] = "LayoutOptimizer";
|
||||
const char kPermNHWCToNCHW[] = "LayoutOptimizerPermConstNHWCToNCHW";
|
||||
const char kPermNCHWToNHWC[] = "LayoutOptimizerPermConstNCHWToNHWC";
|
||||
const char kTransposeNHWCToNCHW[] = "LayoutOptimizerTransposeNHWCToNCHW";
|
||||
const char kTransposeNCHWToNHWC[] = "LayoutOptimizerTransposeNCHWToNHWC";
|
||||
const char kDimMapNHWCToNCHW[] = "LayoutOptimizerDimMapNHWCToNCHW";
|
||||
const char kDimMapNCHWToNHWC[] = "LayoutOptimizerDimMapNCHWToNHWC";
|
||||
const char kVecPermuteNHWCToNCHW[] = "LayoutOptimizerVecPermuteNHWCToNCHW";
|
||||
const char kVecPermuteNCHWToNHWC[] = "LayoutOptimizerVecPermuteNCHWToNHWC";
|
||||
const char kReshapeNHWCToNCHW[] = "LayoutOptimizerReshapeNHWCToNCHW";
|
||||
const char kReshapeConst[] = "LayoutOptimizerReshapeConst";
|
||||
const char kReductionConst[] = "LayoutOptimizerReductionConst";
|
||||
const char kSuffix[] = "LayoutOptimizer";
|
||||
const char kPermNHWCToNCHW[] = "PermConstNHWCToNCHW";
|
||||
const char kPermNCHWToNHWC[] = "PermConstNCHWToNHWC";
|
||||
const char kTransposeNHWCToNCHW[] = "TransposeNHWCToNCHW";
|
||||
const char kTransposeNCHWToNHWC[] = "TransposeNCHWToNHWC";
|
||||
const char kDimMapNHWCToNCHW[] = "DimMapNHWCToNCHW";
|
||||
const char kDimMapNCHWToNHWC[] = "DimMapNCHWToNHWC";
|
||||
const char kVecPermuteNHWCToNCHW[] = "VecPermuteNHWCToNCHW";
|
||||
const char kVecPermuteNCHWToNHWC[] = "VecPermuteNCHWToNHWC";
|
||||
const char kReshapeNHWCToNCHW[] = "ReshapeNHWCToNCHW";
|
||||
const char kReshapeConst[] = "ReshapeConst";
|
||||
|
||||
std::set<string> GetOpsFormatSupported() {
|
||||
std::set<string> ops_format_supported = {
|
||||
@ -210,55 +209,45 @@ std::set<string> GetOpsFormatAgnostic() {
|
||||
return ops_format_agnostic;
|
||||
}
|
||||
|
||||
bool EndWith(const string& str, const string& ending) {
|
||||
if (str.size() < ending.size()) return false;
|
||||
if (str.substr(str.size() - ending.size(), ending.size()) == ending)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsNodeByLayoutOptimizer(const string& node_name) {
|
||||
const string prefix_pattern = kPrefix;
|
||||
string prefix = node_name.substr(0, prefix_pattern.length());
|
||||
if (prefix.compare(prefix_pattern) == 0) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
const string suffix = kSuffix;
|
||||
return EndWith(node_name, suffix);
|
||||
}
|
||||
|
||||
bool IsNodeNHWCToNCHW(const string& node_name, const string& prefix_const) {
|
||||
const string transform_prefix = prefix_const;
|
||||
string prefix = node_name.substr(0, transform_prefix.length());
|
||||
if (prefix.compare(transform_prefix) == 0) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsNodeNCHWToNHWC(const string& node_name, const string& prefix_const) {
|
||||
const string transform_prefix = prefix_const;
|
||||
string prefix = node_name.substr(0, transform_prefix.length());
|
||||
if (prefix.compare(transform_prefix) == 0) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
bool IsNodeType(const string& node_name, const string& type) {
|
||||
const string suffix = strings::StrCat(type, "-", kSuffix);
|
||||
return EndWith(node_name, suffix);
|
||||
}
|
||||
|
||||
bool IsTransposeNHWCToNCHW(const string& node_name) {
|
||||
return IsNodeNHWCToNCHW(node_name, kTransposeNHWCToNCHW);
|
||||
return IsNodeType(node_name, kTransposeNHWCToNCHW);
|
||||
}
|
||||
|
||||
bool IsTransposeNCHWToNHWC(const string& node_name) {
|
||||
return IsNodeNCHWToNHWC(node_name, kTransposeNCHWToNHWC);
|
||||
return IsNodeType(node_name, kTransposeNCHWToNHWC);
|
||||
}
|
||||
|
||||
bool IsDimMapNHWCToNCHW(const string& node_name) {
|
||||
return IsNodeNHWCToNCHW(node_name, kDimMapNHWCToNCHW);
|
||||
return IsNodeType(node_name, kDimMapNHWCToNCHW);
|
||||
}
|
||||
|
||||
bool IsDimMapNCHWToNHWC(const string& node_name) {
|
||||
return IsNodeNCHWToNHWC(node_name, kDimMapNCHWToNHWC);
|
||||
return IsNodeType(node_name, kDimMapNCHWToNHWC);
|
||||
}
|
||||
|
||||
bool IsVecPermuteNHWCToNCHW(const string& node_name) {
|
||||
return IsNodeNHWCToNCHW(node_name, kVecPermuteNHWCToNCHW);
|
||||
return IsNodeType(node_name, kVecPermuteNHWCToNCHW);
|
||||
}
|
||||
|
||||
bool IsVecPermuteNCHWToNHWC(const string& node_name) {
|
||||
return IsNodeNCHWToNHWC(node_name, kVecPermuteNCHWToNHWC);
|
||||
return IsNodeType(node_name, kVecPermuteNCHWToNHWC);
|
||||
}
|
||||
|
||||
bool IsConcat(const NodeDef& node) {
|
||||
@ -439,6 +428,10 @@ class GraphProcessor {
|
||||
return node;
|
||||
}
|
||||
|
||||
string LayoutOptimizerNode(const string& base_name) {
|
||||
return strings::StrCat(base_name, "-", kSuffix);
|
||||
}
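To make the new naming scheme concrete: helper nodes are now identified by a suffix rather than a prefix, for example `Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer`. A rough Python equivalent of the checks above (illustrative only, not part of the C++ change):

```
SUFFIX = 'LayoutOptimizer'

def is_layout_optimizer_node(node_name):
  return node_name.endswith(SUFFIX)

def is_node_type(node_name, node_type):
  return node_name.endswith(node_type + '-' + SUFFIX)

assert is_node_type('Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer',
                    'TransposeNHWCToNCHW')
```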
|
||||
|
||||
const VirtualPlacer& virtual_placer_;
|
||||
const std::unordered_set<string>& nodes_to_preserve_;
|
||||
GraphDef* graph_;
|
||||
@ -591,7 +584,7 @@ class NodeProcessor : public GraphProcessor {
|
||||
NodeDef* added_node = graph_->add_node();
|
||||
*added_node = *input_node;
|
||||
string base_name = strings::StrCat(node_->name(), "-", input_node->name());
|
||||
string node_name = AddPrefixToNodeName(base_name, "LayoutOptimizer", "-");
|
||||
string node_name = LayoutOptimizerNode(base_name);
|
||||
added_node->set_name(node_name);
|
||||
*node_->mutable_input(input_index) = node_name;
|
||||
node_map_->AddNode(node_name, added_node);
|
||||
@ -612,8 +605,8 @@ class NodeProcessor : public GraphProcessor {
|
||||
virtual Status AddLayoutTransposeToInputs() {
|
||||
std::vector<int> input_pos = GetInputPos();
|
||||
for (const auto& pos : input_pos) {
|
||||
string node_name =
|
||||
strings::StrCat(kTransposeNHWCToNCHW, "-", node_->name(), "-", pos);
|
||||
string node_name = LayoutOptimizerNode(
|
||||
strings::StrCat(node_->name(), "-", pos, "-", kTransposeNHWCToNCHW));
|
||||
TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
|
||||
auto input_node = node_map_->GetNode(node_->input(pos));
|
||||
TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
|
||||
@ -652,8 +645,8 @@ class NodeProcessor : public GraphProcessor {
|
||||
strings::StrCat(node_->name(), "-", output_count, "-", i);
|
||||
string added_node_name;
|
||||
if (op == "Transpose") {
|
||||
added_node_name = AddPrefixToNodeName(added_node_base_name,
|
||||
kTransposeNCHWToNHWC, "-");
|
||||
added_node_name = LayoutOptimizerNode(strings::StrCat(
|
||||
added_node_base_name, "-", kTransposeNCHWToNHWC));
|
||||
DataType dtype;
|
||||
if (IsAngle(*node_) || IsComplex(*node_) ||
|
||||
IsComplexAbs(*node_) || IsImag(*node_) || IsReal(*node_)) {
|
||||
@ -674,8 +667,8 @@ class NodeProcessor : public GraphProcessor {
|
||||
node_->attr().at("_output_shapes").list().shape(input_port),
|
||||
false);
|
||||
} else if (op == "DataFormatVecPermute") {
|
||||
added_node_name = AddPrefixToNodeName(added_node_base_name,
|
||||
kVecPermuteNCHWToNHWC, "-");
|
||||
added_node_name = LayoutOptimizerNode(strings::StrCat(
|
||||
added_node_base_name, "-", kVecPermuteNCHWToNHWC));
|
||||
TF_RETURN_IF_ERROR(HasAttribute(*node_, "out_type"));
|
||||
DataType dtype = (IsSplit(*node_) || IsSplitV(*node_))
|
||||
? DT_INT32
|
||||
@ -835,10 +828,11 @@ class NodeProcessor : public GraphProcessor {
|
||||
return node;
|
||||
}
|
||||
|
||||
NodeDef* AddNodePermNHWCToNCHW(const string& suffix,
|
||||
NodeDef* AddNodePermNHWCToNCHW(const string& base_name,
|
||||
const string& depended_node,
|
||||
const string& device) {
|
||||
string name = strings::StrCat(kPermNHWCToNCHW, "-", suffix);
|
||||
string name =
|
||||
LayoutOptimizerNode(strings::StrCat(base_name, "-", kPermNHWCToNCHW));
|
||||
auto const_node = AddNodePermConst(name, device, {0, 3, 1, 2});
|
||||
// This is to ensure the transpose node and the const node are in the
|
||||
// same frame.
|
||||
@ -846,11 +840,12 @@ class NodeProcessor : public GraphProcessor {
|
||||
return const_node;
|
||||
}
|
||||
|
||||
NodeDef* AddNodePermNCHWToNHWC(const string& suffix,
|
||||
NodeDef* AddNodePermNCHWToNHWC(const string& base_name,
|
||||
const string& depended_node,
|
||||
const string& device) {
|
||||
auto const_node = AddNodePermConst(
|
||||
strings::StrCat(kPermNCHWToNHWC, "-", suffix), device, {0, 2, 3, 1});
|
||||
LayoutOptimizerNode(strings::StrCat(base_name, "-", kPermNCHWToNHWC)),
|
||||
device, {0, 2, 3, 1});
|
||||
// This is to ensure the transpose node and the const node are in the same
|
||||
// frame.
|
||||
*const_node->add_input() = AsControlDependency(depended_node);
|
||||
@ -860,7 +855,7 @@ class NodeProcessor : public GraphProcessor {
|
||||
string GetOrAddNodePermNHWCToNCHW(int pos) {
|
||||
string const_name;
|
||||
if (is_in_frame_) {
|
||||
string suffix = strings::StrCat(node_->name(), "_", pos);
|
||||
string base_name = strings::StrCat(node_->name(), "-", pos);
|
||||
string input = NodeName(node_->input(pos));
|
||||
string depended_node;
|
||||
if (!IsTransposeNCHWToNHWC(input)) {
|
||||
@ -870,10 +865,10 @@ class NodeProcessor : public GraphProcessor {
|
||||
depended_node = NodeName(input_node->input(0));
|
||||
}
|
||||
auto const_node =
|
||||
AddNodePermNHWCToNCHW(suffix, depended_node, node_->device());
|
||||
AddNodePermNHWCToNCHW(base_name, depended_node, node_->device());
|
||||
const_name = const_node->name();
|
||||
} else {
|
||||
const_name = kPermNHWCToNCHW;
|
||||
const_name = LayoutOptimizerNode(kPermNHWCToNCHW);
|
||||
}
|
||||
return const_name;
|
||||
}
|
||||
@ -885,7 +880,7 @@ class NodeProcessor : public GraphProcessor {
|
||||
AddNodePermNCHWToNHWC(node_->name(), node_->name(), node_->device());
|
||||
const_name = const_node->name();
|
||||
} else {
|
||||
const_name = kPermNCHWToNHWC;
|
||||
const_name = LayoutOptimizerNode(kPermNCHWToNHWC);
|
||||
}
|
||||
return const_name;
|
||||
}
|
||||
@ -923,9 +918,10 @@ class NodeProcessor : public GraphProcessor {
|
||||
|
||||
void AddDataFormatTranformToParamInput(const string& op, int input_pos,
|
||||
DataType dtype) {
|
||||
string prefix = (op == "DataFormatVecPermute") ? kVecPermuteNHWCToNCHW
|
||||
string suffix = (op == "DataFormatVecPermute") ? kVecPermuteNHWCToNCHW
|
||||
: kDimMapNHWCToNCHW;
|
||||
string name = strings::StrCat(prefix, "_", node_->name(), "_", input_pos);
|
||||
string name = LayoutOptimizerNode(
|
||||
strings::StrCat(node_->name(), "-", input_pos, "-", suffix));
|
||||
auto added_node =
|
||||
AddNodeDataFormatOp(name, node_->input(input_pos), op, dtype, true);
|
||||
*node_->mutable_input(input_pos) = added_node->name();
|
||||
@ -1320,10 +1316,10 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
|
||||
}
|
||||
if (vector_index != -1) {
|
||||
string base_name = strings::StrCat(node_->name(), "-", vector_index);
|
||||
string reshape_node_name =
|
||||
AddPrefixToNodeName(base_name, kReshapeNHWCToNCHW, "-");
|
||||
string reshape_node_name = LayoutOptimizerNode(
|
||||
strings::StrCat(base_name, "-", kReshapeNHWCToNCHW));
|
||||
string shape_const_node_name =
|
||||
AddPrefixToNodeName(base_name, kReshapeConst, "-");
|
||||
LayoutOptimizerNode(strings::StrCat(base_name, "-", kReshapeConst));
|
||||
auto input_node = node_map_->GetNode(node_->input(vector_index));
|
||||
TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
|
||||
int port;
|
||||
@ -1839,11 +1835,13 @@ class DataLayoutOptimizer : GraphProcessor {
|
||||
|
||||
private:
|
||||
NodeDef* AddNodePermNHWCToNCHW() {
|
||||
return AddNodePermConst(kPermNHWCToNCHW, "", {0, 3, 1, 2});
|
||||
return AddNodePermConst(LayoutOptimizerNode(kPermNHWCToNCHW), "",
|
||||
{0, 3, 1, 2});
|
||||
}
|
||||
|
||||
NodeDef* AddNodePermNCHWToNHWC() {
|
||||
return AddNodePermConst(kPermNCHWToNHWC, "", {0, 2, 3, 1});
|
||||
return AddNodePermConst(LayoutOptimizerNode(kPermNCHWToNHWC), "",
|
||||
{0, 2, 3, 1});
|
||||
}
|
||||
|
||||
// Expand all nodes which are in NHWC, but support NCHW or are layout agnostic.
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/device_properties.pb.h"
|
||||
|
||||
@ -171,8 +172,8 @@ TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) {
|
||||
|
||||
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
|
||||
NodeMap node_map(&output);
|
||||
string input_name = AddPrefixToNodeName("Conv2DBackpropInput-InputSizes",
|
||||
"LayoutOptimizer", "-");
|
||||
string input_name =
|
||||
strings::StrCat("Conv2DBackpropInput-InputSizes", "-", "LayoutOptimizer");
|
||||
auto input_sizes_node = node_map.GetNode(input_name);
|
||||
CHECK(input_sizes_node);
|
||||
auto conv2d_backprop_node = node_map.GetNode("Conv2DBackpropInput");
|
||||
@ -198,9 +199,9 @@ TEST_F(LayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
|
||||
auto conv2d_backprop_node = node_map.GetNode("Conv2DBackpropInput");
|
||||
CHECK(conv2d_backprop_node);
|
||||
EXPECT_EQ(conv2d_backprop_node->input(0),
|
||||
"LayoutOptimizerVecPermuteNHWCToNCHW_Conv2DBackpropInput_0");
|
||||
"Conv2DBackpropInput-0-VecPermuteNHWCToNCHW-LayoutOptimizer");
|
||||
auto input_sizes_node = node_map.GetNode(
|
||||
"LayoutOptimizerVecPermuteNHWCToNCHW_Conv2DBackpropInput_0");
|
||||
"Conv2DBackpropInput-0-VecPermuteNHWCToNCHW-LayoutOptimizer");
|
||||
CHECK(input_sizes_node);
|
||||
EXPECT_EQ(input_sizes_node->input(0), "InputSizesIdentity");
|
||||
EXPECT_EQ(input_sizes_node->op(), "DataFormatVecPermute");
|
||||
@ -216,8 +217,7 @@ TEST_F(LayoutOptimizerTest, FilterSizeIsOne) {
|
||||
GraphDef output;
|
||||
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
|
||||
NodeMap node_map(&output);
|
||||
EXPECT_FALSE(
|
||||
node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input"));
|
||||
EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
|
||||
}
|
||||
|
||||
TEST_F(LayoutOptimizerTest, FilterSizeNotOne) {
|
||||
@ -230,8 +230,7 @@ TEST_F(LayoutOptimizerTest, FilterSizeNotOne) {
|
||||
GraphDef output;
|
||||
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
|
||||
NodeMap node_map(&output);
|
||||
EXPECT_FALSE(
|
||||
node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input"));
|
||||
EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
|
||||
}
|
||||
|
||||
TEST_F(LayoutOptimizerTest, EqualSizeWithValidPadding) {
|
||||
@ -244,8 +243,7 @@ TEST_F(LayoutOptimizerTest, EqualSizeWithValidPadding) {
|
||||
GraphDef output;
|
||||
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
|
||||
NodeMap node_map(&output);
|
||||
EXPECT_FALSE(
|
||||
node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input"));
|
||||
EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
|
||||
}
|
||||
|
||||
TEST_F(LayoutOptimizerTest, EqualSizeWithSamePadding) {
|
||||
@ -258,7 +256,7 @@ TEST_F(LayoutOptimizerTest, EqualSizeWithSamePadding) {
|
||||
GraphDef output;
|
||||
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
|
||||
NodeMap node_map(&output);
|
||||
EXPECT_TRUE(node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0"));
|
||||
EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
|
||||
}
|
||||
|
||||
TEST_F(LayoutOptimizerTest, NotEqualSizeWithValidPadding) {
|
||||
@ -271,7 +269,7 @@ TEST_F(LayoutOptimizerTest, NotEqualSizeWithValidPadding) {
|
||||
GraphDef output;
|
||||
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
|
||||
NodeMap node_map(&output);
|
||||
EXPECT_TRUE(node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0"));
|
||||
EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
|
||||
}
|
||||
|
||||
TEST_F(LayoutOptimizerTest, Pad) {
|
||||
@ -290,7 +288,7 @@ TEST_F(LayoutOptimizerTest, Pad) {
|
||||
auto pad = node_map.GetNode("p");
|
||||
EXPECT_EQ(pad->input(0), "Conv2D");
|
||||
|
||||
auto pad_const = node_map.GetNode("LayoutOptimizer-p-c");
|
||||
auto pad_const = node_map.GetNode("p-c-LayoutOptimizer");
|
||||
EXPECT_TRUE(pad_const);
|
||||
EXPECT_TRUE(pad_const->attr().find("value") != pad_const->attr().end());
|
||||
Tensor tensor;
|
||||
@ -478,9 +476,9 @@ TEST_F(LayoutOptimizerTest, SplitDimC) {
|
||||
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
|
||||
NodeMap node_map(&output);
|
||||
auto split_node = node_map.GetNode("split");
|
||||
EXPECT_EQ(split_node->input(0), "LayoutOptimizer-split-c");
|
||||
EXPECT_EQ(split_node->input(0), "split-c-LayoutOptimizer");
|
||||
EXPECT_EQ(split_node->input(1), "Conv2D");
|
||||
auto split_const = node_map.GetNode("LayoutOptimizer-split-c");
|
||||
auto split_const = node_map.GetNode("split-c-LayoutOptimizer");
|
||||
EXPECT_EQ(split_const->op(), "Const");
|
||||
EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 1);
|
||||
}
|
||||
@ -498,9 +496,9 @@ TEST_F(LayoutOptimizerTest, SplitDimH) {
|
||||
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
|
||||
NodeMap node_map(&output);
|
||||
auto split_node = node_map.GetNode("split");
|
||||
EXPECT_EQ(split_node->input(0), "LayoutOptimizer-split-c");
|
||||
EXPECT_EQ(split_node->input(0), "split-c-LayoutOptimizer");
|
||||
EXPECT_EQ(split_node->input(1), "Conv2D");
|
||||
auto split_const = node_map.GetNode("LayoutOptimizer-split-c");
|
||||
auto split_const = node_map.GetNode("split-c-LayoutOptimizer");
|
||||
EXPECT_EQ(split_const->op(), "Const");
|
||||
EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 2);
|
||||
}
|
||||
@ -518,9 +516,9 @@ TEST_F(LayoutOptimizerTest, SplitDimW) {
|
||||
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
|
||||
NodeMap node_map(&output);
|
||||
auto split_node = node_map.GetNode("split");
|
||||
EXPECT_EQ(split_node->input(0), "LayoutOptimizer-split-c");
|
||||
EXPECT_EQ(split_node->input(0), "split-c-LayoutOptimizer");
|
||||
EXPECT_EQ(split_node->input(1), "Conv2D");
|
||||
auto split_const = node_map.GetNode("LayoutOptimizer-split-c");
|
||||
auto split_const = node_map.GetNode("split-c-LayoutOptimizer");
|
||||
EXPECT_EQ(split_const->op(), "Const");
|
||||
EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 3);
|
||||
}
|
||||
@ -538,9 +536,9 @@ TEST_F(LayoutOptimizerTest, SplitDimN) {
|
||||
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
|
||||
NodeMap node_map(&output);
|
||||
auto split_node = node_map.GetNode("split");
|
||||
EXPECT_EQ(split_node->input(0), "LayoutOptimizer-split-c");
|
||||
EXPECT_EQ(split_node->input(0), "split-c-LayoutOptimizer");
|
||||
EXPECT_EQ(split_node->input(1), "Conv2D");
|
||||
auto split_const = node_map.GetNode("LayoutOptimizer-split-c");
|
||||
auto split_const = node_map.GetNode("split-c-LayoutOptimizer");
|
||||
EXPECT_EQ(split_const->op(), "Const");
|
||||
EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 0);
|
||||
}
|
||||
@ -559,9 +557,9 @@ TEST_F(LayoutOptimizerTest, SplitNonConstDim) {
|
||||
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
|
||||
NodeMap node_map(&output);
|
||||
auto split_node = node_map.GetNode("split");
|
||||
EXPECT_EQ(split_node->input(0), "LayoutOptimizerDimMapNHWCToNCHW_split_0");
|
||||
EXPECT_EQ(split_node->input(0), "split-0-DimMapNHWCToNCHW-LayoutOptimizer");
|
||||
EXPECT_EQ(split_node->input(1), "Conv2D");
|
||||
auto map_node = node_map.GetNode("LayoutOptimizerDimMapNHWCToNCHW_split_0");
|
||||
auto map_node = node_map.GetNode("split-0-DimMapNHWCToNCHW-LayoutOptimizer");
|
||||
EXPECT_EQ(map_node->op(), "DataFormatDimMap");
|
||||
EXPECT_EQ(map_node->input(0), "i1");
|
||||
}
|
||||
@ -584,8 +582,8 @@ TEST_F(LayoutOptimizerTest, SplitSamePortToMultipleInputsOfSameNode) {
|
||||
EXPECT_EQ(concat_node->input(0), "split:1");
|
||||
EXPECT_EQ(concat_node->input(1), "split:1");
|
||||
EXPECT_EQ(concat_node->input(2), "split:1");
|
||||
EXPECT_EQ(concat_node->input(3), "LayoutOptimizer-concat-axis");
|
||||
auto concat_dim = node_map.GetNode("LayoutOptimizer-concat-axis");
|
||||
EXPECT_EQ(concat_node->input(3), "concat-axis-LayoutOptimizer");
|
||||
auto concat_dim = node_map.GetNode("concat-axis-LayoutOptimizer");
|
||||
EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
|
||||
}
|
||||
|
||||
@ -605,8 +603,8 @@ TEST_F(LayoutOptimizerTest, ConcatDimH) {
|
||||
auto concat_node = node_map.GetNode("concat");
|
||||
EXPECT_EQ(concat_node->input(0), "split");
|
||||
EXPECT_EQ(concat_node->input(1), "split:1");
|
||||
EXPECT_EQ(concat_node->input(2), "LayoutOptimizer-concat-axis");
|
||||
auto concat_dim = node_map.GetNode("LayoutOptimizer-concat-axis");
|
||||
EXPECT_EQ(concat_node->input(2), "concat-axis-LayoutOptimizer");
|
||||
auto concat_dim = node_map.GetNode("concat-axis-LayoutOptimizer");
|
||||
EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 2);
|
||||
}
|
||||
|
||||
@ -627,9 +625,9 @@ TEST_F(LayoutOptimizerTest, ConcatNonConst) {
|
||||
auto concat_node = node_map.GetNode("concat");
|
||||
EXPECT_EQ(concat_node->input(0), "split");
|
||||
EXPECT_EQ(concat_node->input(1), "split:1");
|
||||
EXPECT_EQ(concat_node->input(2), "LayoutOptimizerDimMapNHWCToNCHW_concat_2");
|
||||
EXPECT_EQ(concat_node->input(2), "concat-2-DimMapNHWCToNCHW-LayoutOptimizer");
|
||||
auto concat_dim =
|
||||
node_map.GetNode("LayoutOptimizerDimMapNHWCToNCHW_concat_2");
|
||||
node_map.GetNode("concat-2-DimMapNHWCToNCHW-LayoutOptimizer");
|
||||
EXPECT_EQ(concat_dim->op(), "DataFormatDimMap");
|
||||
EXPECT_EQ(concat_dim->input(0), "i");
|
||||
}
|
||||
@ -650,8 +648,8 @@ TEST_F(LayoutOptimizerTest, ConcatDimW) {
|
||||
auto concat_node = node_map.GetNode("concat");
|
||||
EXPECT_EQ(concat_node->input(0), "split");
|
||||
EXPECT_EQ(concat_node->input(1), "split:1");
|
||||
EXPECT_EQ(concat_node->input(2), "LayoutOptimizer-concat-axis");
|
||||
auto concat_dim = node_map.GetNode("LayoutOptimizer-concat-axis");
|
||||
EXPECT_EQ(concat_node->input(2), "concat-axis-LayoutOptimizer");
|
||||
auto concat_dim = node_map.GetNode("concat-axis-LayoutOptimizer");
|
||||
EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 3);
|
||||
}
|
||||
|
||||
@ -671,8 +669,8 @@ TEST_F(LayoutOptimizerTest, ConcatDimN) {
|
||||
auto concat_node = node_map.GetNode("concat");
|
||||
EXPECT_EQ(concat_node->input(0), "split");
|
||||
EXPECT_EQ(concat_node->input(1), "split:1");
|
||||
EXPECT_EQ(concat_node->input(2), "LayoutOptimizer-concat-axis");
|
||||
auto concat_dim = node_map.GetNode("LayoutOptimizer-concat-axis");
|
||||
EXPECT_EQ(concat_node->input(2), "concat-axis-LayoutOptimizer");
|
||||
auto concat_dim = node_map.GetNode("concat-axis-LayoutOptimizer");
|
||||
EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 0);
|
||||
}
|
||||
|
||||
@ -692,8 +690,8 @@ TEST_F(LayoutOptimizerTest, ConcatDimC) {
|
||||
auto concat_node = node_map.GetNode("concat");
|
||||
EXPECT_EQ(concat_node->input(0), "split");
|
||||
EXPECT_EQ(concat_node->input(1), "split:1");
|
||||
EXPECT_EQ(concat_node->input(2), "LayoutOptimizer-concat-axis");
|
||||
auto concat_dim = node_map.GetNode("LayoutOptimizer-concat-axis");
|
||||
EXPECT_EQ(concat_node->input(2), "concat-axis-LayoutOptimizer");
|
||||
auto concat_dim = node_map.GetNode("concat-axis-LayoutOptimizer");
|
||||
EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
|
||||
}
|
||||
|
||||
@ -779,7 +777,7 @@ TEST_F(LayoutOptimizerTest, Mul4DAndUnknownRank) {
|
||||
// Node mul should not be processed by layout optimizer, because one of its
|
||||
// inputs is of unknown rank.
|
||||
EXPECT_EQ(mul_node->input(0),
|
||||
"LayoutOptimizerTransposeNCHWToNHWC-Conv2D-0-0");
|
||||
"Conv2D-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
|
||||
EXPECT_EQ(mul_node->input(1), "unknown");
|
||||
}
|
||||
|
||||
@ -814,8 +812,8 @@ TEST_F(LayoutOptimizerTest, Mul4DAndVector) {
|
||||
NodeMap node_map(&output);
|
||||
auto mul_node = node_map.GetNode("mul");
|
||||
EXPECT_EQ(mul_node->input(0), "Conv2D");
|
||||
EXPECT_EQ(mul_node->input(1), "LayoutOptimizerReshapeNHWCToNCHW-mul-1");
|
||||
auto mul_const = node_map.GetNode("LayoutOptimizerReshapeConst-mul-1");
|
||||
EXPECT_EQ(mul_node->input(1), "mul-1-ReshapeNHWCToNCHW-LayoutOptimizer");
|
||||
auto mul_const = node_map.GetNode("mul-1-ReshapeConst-LayoutOptimizer");
|
||||
Tensor tensor;
|
||||
EXPECT_TRUE(
|
||||
tensor.FromProto(mul_const->mutable_attr()->at({"value"}).tensor()));
|
||||
@ -837,9 +835,9 @@ TEST_F(LayoutOptimizerTest, MulVectorAnd4D) {
|
||||
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
|
||||
NodeMap node_map(&output);
|
||||
auto mul_node = node_map.GetNode("mul");
|
||||
EXPECT_EQ(mul_node->input(0), "LayoutOptimizerReshapeNHWCToNCHW-mul-0");
|
||||
EXPECT_EQ(mul_node->input(0), "mul-0-ReshapeNHWCToNCHW-LayoutOptimizer");
|
||||
EXPECT_EQ(mul_node->input(1), "Conv2D");
|
||||
auto mul_const = node_map.GetNode("LayoutOptimizerReshapeConst-mul-0");
|
||||
auto mul_const = node_map.GetNode("mul-0-ReshapeConst-LayoutOptimizer");
|
||||
Tensor tensor;
|
||||
EXPECT_TRUE(
|
||||
tensor.FromProto(mul_const->mutable_attr()->at({"value"}).tensor()));
|
||||
@ -863,10 +861,10 @@ TEST_F(LayoutOptimizerTest, SliceConst) {
|
||||
NodeMap node_map(&output);
|
||||
auto slice_node = node_map.GetNode("slice");
|
||||
EXPECT_EQ(slice_node->input(0), "Conv2D");
|
||||
EXPECT_EQ(slice_node->input(1), "LayoutOptimizer-slice-begin");
|
||||
EXPECT_EQ(slice_node->input(2), "LayoutOptimizer-slice-size");
|
||||
EXPECT_EQ(slice_node->input(1), "slice-begin-LayoutOptimizer");
|
||||
EXPECT_EQ(slice_node->input(2), "slice-size-LayoutOptimizer");
|
||||
|
||||
auto begin_const = node_map.GetNode("LayoutOptimizer-slice-begin");
|
||||
auto begin_const = node_map.GetNode("slice-begin-LayoutOptimizer");
|
||||
Tensor begin_tensor;
|
||||
EXPECT_TRUE(begin_tensor.FromProto(
|
||||
begin_const->mutable_attr()->at({"value"}).tensor()));
|
||||
@ -874,7 +872,7 @@ TEST_F(LayoutOptimizerTest, SliceConst) {
|
||||
test::FillValues<int>(&begin_tensor_expected, {0, 1, 2, 3});
|
||||
test::ExpectTensorEqual<int>(begin_tensor_expected, begin_tensor);
|
||||
|
||||
auto size_const = node_map.GetNode("LayoutOptimizer-slice-size");
|
||||
auto size_const = node_map.GetNode("slice-size-LayoutOptimizer");
|
||||
Tensor size_tensor;
|
||||
EXPECT_TRUE(size_tensor.FromProto(
|
||||
size_const->mutable_attr()->at({"value"}).tensor()));
|
||||
@ -901,13 +899,13 @@ TEST_F(LayoutOptimizerTest, SliceNonConst) {
|
||||
auto slice_node = node_map.GetNode("slice");
|
||||
EXPECT_EQ(slice_node->input(0), "Conv2D");
|
||||
EXPECT_EQ(slice_node->input(1),
|
||||
"LayoutOptimizerVecPermuteNHWCToNCHW_slice_1");
|
||||
"slice-1-VecPermuteNHWCToNCHW-LayoutOptimizer");
|
||||
EXPECT_EQ(slice_node->input(2),
|
||||
"LayoutOptimizerVecPermuteNHWCToNCHW_slice_2");
|
||||
auto perm1 = node_map.GetNode("LayoutOptimizerVecPermuteNHWCToNCHW_slice_1");
|
||||
"slice-2-VecPermuteNHWCToNCHW-LayoutOptimizer");
|
||||
auto perm1 = node_map.GetNode("slice-1-VecPermuteNHWCToNCHW-LayoutOptimizer");
|
||||
EXPECT_EQ(perm1->op(), "DataFormatVecPermute");
|
||||
EXPECT_EQ(perm1->input(0), "ibegin");
|
||||
auto perm2 = node_map.GetNode("LayoutOptimizerVecPermuteNHWCToNCHW_slice_2");
|
||||
auto perm2 = node_map.GetNode("slice-2-VecPermuteNHWCToNCHW-LayoutOptimizer");
|
||||
EXPECT_EQ(perm1->op(), "DataFormatVecPermute");
|
||||
EXPECT_EQ(perm2->input(0), "isize");
|
||||
}
|
||||
@ -915,7 +913,7 @@ TEST_F(LayoutOptimizerTest, SliceNonConst) {
|
||||
TEST_F(LayoutOptimizerTest, DoNotApplyOptimizerTwice) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
auto scalar =
|
||||
ops::Const(s.WithOpName("LayoutOptimizerAlreadyApplied"), 3.0f, {});
|
||||
ops::Const(s.WithOpName("AlreadyApplied-LayoutOptimizer"), 3.0f, {});
|
||||
auto mul = ops::Mul(s.WithOpName("mul"), scalar, scalar);
|
||||
auto o = ops::Identity(s.WithOpName("o"), mul);
|
||||
GrapplerItem item;
|
||||
@ -942,15 +940,15 @@ TEST_F(LayoutOptimizerTest, ShapeNWithInputs4DAnd4D) {
|
||||
EXPECT_EQ(shapen_node->input(1), "Conv2D");
|
||||
auto add_node = node_map.GetNode("add");
|
||||
EXPECT_EQ(add_node->input(0),
|
||||
"LayoutOptimizerVecPermuteNCHWToNHWC-shapen-0-0");
|
||||
"shapen-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
|
||||
EXPECT_EQ(add_node->input(1),
|
||||
"LayoutOptimizerVecPermuteNCHWToNHWC-shapen-0-1");
|
||||
"shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
|
||||
auto vec_permute1 =
|
||||
node_map.GetNode("LayoutOptimizerVecPermuteNCHWToNHWC-shapen-0-0");
|
||||
node_map.GetNode("shapen-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
|
||||
EXPECT_EQ(vec_permute1->input(0), "shapen");
|
||||
EXPECT_EQ(vec_permute1->op(), "DataFormatVecPermute");
|
||||
auto vec_permute2 =
|
||||
node_map.GetNode("LayoutOptimizerVecPermuteNCHWToNHWC-shapen-0-1");
|
||||
node_map.GetNode("shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
|
||||
EXPECT_EQ(vec_permute2->input(0), "shapen:1");
|
||||
EXPECT_EQ(vec_permute2->op(), "DataFormatVecPermute");
|
||||
}
|
||||
@ -973,9 +971,9 @@ TEST_F(LayoutOptimizerTest, ShapeNWithInputsVectorAnd4D) {
|
||||
auto add_node = node_map.GetNode("add");
|
||||
EXPECT_EQ(add_node->input(0), "shapen");
|
||||
EXPECT_EQ(add_node->input(1),
|
||||
"LayoutOptimizerVecPermuteNCHWToNHWC-shapen-0-1");
|
||||
"shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
|
||||
auto vec_permute =
|
||||
node_map.GetNode("LayoutOptimizerVecPermuteNCHWToNHWC-shapen-0-1");
|
||||
node_map.GetNode("shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
|
||||
EXPECT_EQ(vec_permute->input(0), "shapen:1");
|
||||
EXPECT_EQ(vec_permute->op(), "DataFormatVecPermute");
|
||||
}
|
||||
@ -1039,9 +1037,9 @@ TEST_F(LayoutOptimizerTest, MergeBothInputsConvertible) {
|
||||
EXPECT_EQ(merge_node->input(0), "Conv2D");
|
||||
EXPECT_EQ(merge_node->input(1), "i1");
|
||||
auto i2_node = node_map.GetNode("i2");
|
||||
EXPECT_EQ(i2_node->input(0), "LayoutOptimizerTransposeNCHWToNHWC-merge-0-0");
|
||||
EXPECT_EQ(i2_node->input(0), "merge-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
|
||||
auto transpose =
|
||||
node_map.GetNode("LayoutOptimizerTransposeNCHWToNHWC-merge-0-0");
|
||||
node_map.GetNode("merge-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
|
||||
EXPECT_EQ(transpose->input(0), "merge");
|
||||
}
|
||||
|
||||
@ -1060,7 +1058,7 @@ TEST_F(LayoutOptimizerTest, MergeOneInputNotConvertible) {
|
||||
auto merge_node = node_map.GetNode("merge");
|
||||
EXPECT_EQ(merge_node->input(0), "tensor_4d");
|
||||
EXPECT_EQ(merge_node->input(1),
|
||||
"LayoutOptimizerTransposeNCHWToNHWC-Conv2D-0-1");
|
||||
"Conv2D-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
|
||||
}
|
||||
|
||||
TEST_F(LayoutOptimizerTest, Complex) {
|
||||
@ -1078,7 +1076,7 @@ TEST_F(LayoutOptimizerTest, Complex) {
|
||||
EXPECT_EQ(merge_node->input(0), "Conv2D");
|
||||
EXPECT_EQ(merge_node->input(1), "Conv2D");
|
||||
auto trans =
|
||||
node_map.GetNode("LayoutOptimizerTransposeNCHWToNHWC-complex-0-0");
|
||||
node_map.GetNode("complex-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
|
||||
EXPECT_EQ(trans->attr().at("T").type(), DT_COMPLEX64);
|
||||
}
|
||||
|
||||
@ -1098,12 +1096,12 @@ TEST_F(LayoutOptimizerTest, IdentityNWithInputsVectorAnd4D) {
|
||||
EXPECT_EQ(i->input(0), "vector");
|
||||
EXPECT_EQ(i->input(1), "Conv2D");
|
||||
auto trans =
|
||||
node_map.GetNode("LayoutOptimizerTransposeNCHWToNHWC-identity_n-0-1");
|
||||
node_map.GetNode("identity_n-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
|
||||
EXPECT_EQ(trans->input(0), "identity_n:1");
|
||||
auto add_node = node_map.GetNode("add");
|
||||
EXPECT_EQ(add_node->input(0), "identity_n");
|
||||
EXPECT_EQ(add_node->input(1),
|
||||
"LayoutOptimizerTransposeNCHWToNHWC-identity_n-0-1");
|
||||
"identity_n-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
|
||||
}
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
|
@ -11782,6 +11782,24 @@ op {
|
||||
}
|
||||
allows_uninitialized_input: true
|
||||
}
|
||||
op {
|
||||
name: "DebugGradientRefIdentity"
|
||||
input_arg {
|
||||
name: "input"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
}
|
||||
allows_uninitialized_input: true
|
||||
}
|
||||
op {
|
||||
name: "DebugIdentity"
|
||||
input_arg {
|
||||
|
@ -6545,7 +6545,27 @@ op {
|
||||
type: "type"
|
||||
}
|
||||
summary: "Identity op for gradient debugging."
|
||||
description: "This op is hidden from public in Python. It is used by TensorFlow Debugger to\nregister gradient tensors for gradient debugging."
|
||||
description: "This op is hidden from public in Python. It is used by TensorFlow Debugger to\nregister gradient tensors for gradient debugging.\nThis op operates on non-reference-type tensors."
|
||||
allows_uninitialized_input: true
|
||||
}
|
||||
op {
|
||||
name: "DebugGradientRefIdentity"
|
||||
input_arg {
|
||||
name: "input"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "T"
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
}
|
||||
summary: "Identity op for gradient debugging."
|
||||
description: "This op is hidden from public in Python. It is used by TensorFlow Debugger to\nregister gradient tensors for gradient debugging.\nThis op operates on reference-type tensors."
|
||||
allows_uninitialized_input: true
|
||||
}
|
||||
op {
|
||||
|
@ -51,6 +51,9 @@ const int kTensorBundleMinProducer = 0;
|
||||
const int kTensorBundleMinConsumer = 0;
|
||||
const int kTensorBundleVersion = 1;
|
||||
|
||||
// Size of our input buffer for streaming reads
|
||||
static const int kBufferSize = 1024 * 1024;
|
||||
|
||||
// Key to the special BundleHeaderProto entry. Do not change this, as clients
|
||||
// can make the assumption that the header is always the first entry in the
|
||||
// bundle.
|
||||
@ -810,8 +813,7 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
|
||||
std::unique_ptr<RandomAccessFile> file = nullptr;
|
||||
TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(
|
||||
DataFilename(prefix_, entry.shard_id(), num_shards_), &file));
|
||||
buffered_file =
|
||||
new io::InputBuffer(file.release(), 1024 << 10 /* 1024KB buffer */);
|
||||
buffered_file = new io::InputBuffer(file.release(), kBufferSize);
|
||||
// The InputBuffer and RandomAccessFile objects are both released in dtor.
|
||||
data_[entry.shard_id()] = buffered_file;
|
||||
}
|
||||
@ -819,11 +821,21 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
|
||||
|
||||
TF_RETURN_IF_ERROR(buffered_file->Seek(entry.offset()));
|
||||
uint32 actual_crc32c = 0;
|
||||
|
||||
if (DataTypeCanUseMemcpy(entry.dtype())) {
|
||||
char* backing_buffer = const_cast<char*>((ret->tensor_data().data()));
|
||||
size_t unused_bytes_read;
|
||||
TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(entry.size(), backing_buffer,
|
||||
&unused_bytes_read));
|
||||
if (entry.size() > kBufferSize) {
|
||||
StringPiece sp;
|
||||
TF_RETURN_IF_ERROR(buffered_file->file()->Read(
|
||||
entry.offset(), entry.size(), &sp, backing_buffer));
|
||||
if (sp.data() != backing_buffer) {
|
||||
memmove(backing_buffer, sp.data(), entry.size());
|
||||
}
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(entry.size(), backing_buffer,
|
||||
&unused_bytes_read));
|
||||
}
|
||||
actual_crc32c = crc32c::Value(backing_buffer, entry.size());
|
||||
} else if (entry.dtype() == DT_VARIANT) {
|
||||
// Relies on io::InputBuffer's buffering, because we issue many neighboring
|
||||
|
@ -1,4 +1,4 @@
# Benchmarks
# Defining and Running Benchmarks

This guide contains instructions for defining and running a TensorFlow benchmark. These benchmarks store output in [TestResults](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/test_log.proto) format. If these benchmarks are added to TensorFlow github repo, then we will run them daily with our continuous build and display a graph on our dashboard: https://benchmarks-dot-tensorflow-testing.appspot.com/.

@ -52,6 +52,19 @@ Key points to note in the example above:
* Benchmark method calls `report_benchmark` to report the metric value.


## Running with Python

Use the `--benchmarks` flag to run the benchmark with Python. A [BenchmarkEntries](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/util/test_log.proto) proto will be printed.

```
python sample_benchmark.py --benchmarks=SampleBenchmark
```

Setting the flag to `--benchmarks=.` or `--benchmarks=all` works as well.

(Please ensure that TensorFlow is installed so that `import tensorflow as tf` succeeds. For installation instructions, see [Installing TensorFlow](https://www.tensorflow.org/install/). This step is not necessary when running with bazel.)
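For reference, a minimal benchmark along these lines might look like the sketch below (the real `sample_benchmark.py` may differ; the op being timed here is arbitrary):

```
import time
import tensorflow as tf

class SampleBenchmark(tf.test.Benchmark):

  def benchmarkSum(self):
    with tf.Session() as sess:
      s = tf.reduce_sum(tf.constant(list(range(1000))))
      start = time.time()
      for _ in range(100):
        sess.run(s)
      self.report_benchmark(
          name="sum_wall_time", wall_time=(time.time() - start) / 100, iters=100)

if __name__ == "__main__":
  tf.test.main()
```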


## Adding a `bazel` Target

We have a special target called `tf_py_logged_benchmark` for benchmarks defined under TensorFlow github repo. `tf_py_logged_benchmark` should wrap around a regular `py_test` target. Running a `tf_py_logged_benchmark` would print a [TestResults](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/test_log.proto) proto. Defining a `tf_py_logged_benchmark` also lets us run it with TensorFlow continuous build.

@ -84,7 +97,7 @@ load("//tensorflow/tools/test:performance.bzl", "tf_py_logged_benchmark")

tf_py_logged_benchmark(
    name = "sample_logged_benchmark",
    target = "//tensorflow/tools/test:sample_benchmark",
    target = "//tensorflow/examples/benchmark:sample_benchmark",
)
```
@ -59,14 +59,14 @@ filegroup(
            "**/OWNERS",
        ],
    ),
    visibility = ["//third_party/tensorflow:__subpackages__"],
    visibility = ["//tensorflow:__subpackages__"],
)
```

* When adding a new BUILD file, add this line to the `tensorflow/BUILD` file, in the `all_opensource_files` target.

```
"//third_party/tensorflow/<directory>:all_files",
"//tensorflow/<directory>:all_files",
```

* For all Python BUILD targets (libraries and tests) add the following line:
@ -305,7 +305,7 @@ features, we can build the estimator.


## Instantiate an Estimator

The Iris problem is a classic classifier problem. Fortunately, TensorFlow
The Iris problem is a classic classification problem. Fortunately, TensorFlow
provides several pre-made classifier Estimators, including:

* @{tf.estimator.DNNClassifier}—for deep models that perform multi-class
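For example, one of these pre-made Estimators can be instantiated as follows (a sketch; `my_feature_columns` is assumed to be the feature columns built earlier in this guide):

```
classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,  # defined earlier in the guide
    hidden_units=[10, 10],               # two hidden layers of 10 nodes each
    n_classes=3)                         # the three Iris species
```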
@ -203,8 +203,8 @@ bn = tf.contrib.layers.batch_norm(input_layer, fused=True, data_format='NCHW')

### RNN Performance

There are many ways to specify an RNN computation in Tensorflow and they have
have trade-offs with respect to model flexibility and performance. The
There are many ways to specify an RNN computation in TensorFlow and they have
trade-offs with respect to model flexibility and performance. The
@{tf.nn.rnn_cell.BasicLSTMCell} should be considered a reference implementation
and used only as a last resort when no other options will work.

@ -230,7 +230,7 @@ If you need to run one step of the RNN at a time, as might be the case in
reinforcement learning with a recurrent policy, then you should use the
@{tf.contrib.rnn.LSTMBlockCell} with your own environment interaction loop
inside a @{tf.while_loop} construct. Running one step of the RNN at a time and
returning to python is possible but it will be slower.
returning to Python is possible, but it will be slower.
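As a rough illustration of the recommended pattern, a single RNN step looks like the sketch below; `batch_size` and `observation` are placeholders for your own setup, and in practice this call sits inside the body of the @{tf.while_loop} together with the environment interaction rather than being driven step-by-step from Python.

```
cell = tf.contrib.rnn.LSTMBlockCell(num_units=128)
state = cell.zero_state(batch_size, tf.float32)
# One step: consumes the current observation, emits an output and new state.
output, state = cell(observation, state)
```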

On CPUs, mobile devices, and if @{tf.contrib.cudnn_rnn} is not available on
your GPU, the fastest and most memory efficient option is
@ -15,6 +15,6 @@ limitations under the License.
|
||||
*/
|
||||
|
||||
//go:generate go generate ../genop
|
||||
//go:generate go run ../genop/main.go -outfile wrappers.go
|
||||
//go:generate go run ../genop/main.go -outfile wrappers.go -api_def_dirs ../../core/api_def/base_api/
|
||||
|
||||
package op
|
||||
|
@ -641,7 +641,10 @@ py_test(
|
||||
size = "small",
|
||||
srcs = ["cli/curses_ui_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_windows"],
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
":curses_ui",
|
||||
":debugger_cli_common",
|
||||
@ -831,6 +834,9 @@ py_test(
|
||||
size = "small",
|
||||
srcs = ["cli/tensor_format_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
],
|
||||
deps = [
|
||||
":debug_data",
|
||||
":tensor_format",
|
||||
@ -901,6 +907,9 @@ cuda_py_test(
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
tags = [
|
||||
"no_oss",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
@ -173,7 +173,7 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
||||
self.assertIn("id=%d, shape=%s, dtype=%s, numpy=\n%r" %
|
||||
(t._id, t.shape, t.dtype.name, t.numpy()), tensor_repr)
|
||||
|
||||
def testTensorStrReprObeyNumpyPrintOptions(self):
|
||||
def disabled_testTensorStrReprObeyNumpyPrintOptions(self):
|
||||
orig_threshold = np.get_printoptions()["threshold"]
|
||||
orig_edgeitems = np.get_printoptions()["edgeitems"]
|
||||
np.set_printoptions(threshold=2, edgeitems=1)
|
||||
|
@ -420,5 +420,10 @@ def _warmstart(warmstart_settings):
|
||||
if warmstart_settings.vars_to_warmstart:
|
||||
logging.info("Warm-starting variable: {}; prev_var_name: {}".format(
|
||||
var_name, prev_var_name or "Unchanged"))
|
||||
# Because we use a default empty list in grouped_variables, single
|
||||
# unpartitioned variables will be lists here, which we rectify in order
|
||||
# for init_from_checkpoint logic to work correctly.
|
||||
if len(variable) == 1:
|
||||
variable = variable[0]
|
||||
_warmstart_var(variable, warmstart_settings.ckpt_to_initialize_from,
|
||||
prev_var_name)
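A compact usage sketch of the settings this function consumes, mirroring the test added below (paths, vocab files, and variable names are placeholders):

```
vocab_info = ws_util._VocabInfo(
    new_vocab='new_vocab.txt', new_vocab_size=6, num_oov_buckets=0,
    old_vocab='old_vocab.txt')
ws_settings = ws_util._WarmStartSettings(
    '/tmp/prev_ckpt_dir',
    vars_to_warmstart='.*(sc_keys|sc_vocab).*',
    var_name_to_vocab_info={'linear_model/sc_vocab/weights': vocab_info},
    var_name_to_prev_var_name={'linear_model/sc_keys/weights': 'some_other_name'})
ws_util._warmstart(ws_settings)
```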
|
||||
|
@ -659,6 +659,67 @@ class WarmStartingUtilTest(test.TestCase):
|
||||
]
|
||||
}, sess)
|
||||
|
||||
def testWarmStartMoreSettingsNoPartitioning(self):
|
||||
# Create old and new vocabs for sparse column "sc_vocab".
|
||||
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
|
||||
"old_vocab")
|
||||
new_vocab_path = self._write_vocab(
|
||||
["orange", "guava", "banana", "apple", "raspberry",
|
||||
"blueberry"], "new_vocab")
|
||||
# Create feature columns.
|
||||
sc_hash = fc.categorical_column_with_hash_bucket(
|
||||
"sc_hash", hash_bucket_size=15)
|
||||
sc_keys = fc.categorical_column_with_vocabulary_list(
|
||||
"sc_keys", vocabulary_list=["a", "b", "c", "e"])
|
||||
sc_vocab = fc.categorical_column_with_vocabulary_file(
|
||||
"sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6)
|
||||
all_linear_cols = [sc_hash, sc_keys, sc_vocab]
|
||||
|
||||
# Save checkpoint from which to warm-start.
|
||||
with ops.Graph().as_default() as g:
|
||||
with self.test_session(graph=g) as sess:
|
||||
variable_scope.get_variable(
|
||||
"linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
|
||||
sc_keys_weights = variable_scope.get_variable(
|
||||
"some_other_name", shape=[4, 1], initializer=rand())
|
||||
variable_scope.get_variable(
|
||||
"linear_model/sc_vocab/weights",
|
||||
initializer=[[0.5], [1.], [2.], [3.]])
|
||||
self._write_checkpoint(sess)
|
||||
prev_keys_val = sess.run(sc_keys_weights)
|
||||
|
||||
# New graph, new session with warmstarting.
|
||||
with ops.Graph().as_default() as g:
|
||||
with self.test_session(graph=g) as sess:
|
||||
cols_to_vars = self._create_linear_model(all_linear_cols,
|
||||
partitioner=None)
|
||||
vocab_info = ws_util._VocabInfo(
|
||||
new_vocab=sc_vocab.vocabulary_file,
|
||||
new_vocab_size=sc_vocab.vocabulary_size,
|
||||
num_oov_buckets=sc_vocab.num_oov_buckets,
|
||||
old_vocab=prev_vocab_path
|
||||
)
|
||||
ws_settings = ws_util._WarmStartSettings(
|
||||
self.get_temp_dir(),
|
||||
vars_to_warmstart=".*(sc_keys|sc_vocab).*",
|
||||
var_name_to_vocab_info={
|
||||
ws_util._infer_var_name(cols_to_vars[sc_vocab]): vocab_info
|
||||
},
|
||||
var_name_to_prev_var_name={
|
||||
ws_util._infer_var_name(cols_to_vars[sc_keys]):
|
||||
"some_other_name"
|
||||
})
|
||||
ws_util._warmstart(ws_settings)
|
||||
sess.run(variables.global_variables_initializer())
|
||||
# Verify weights were correctly warmstarted. Var corresponding to
|
||||
# sc_hash should not be warm-started. Var corresponding to sc_vocab
|
||||
# should be correctly warmstarted after vocab remapping.
|
||||
self._assert_cols_to_vars(cols_to_vars, {
|
||||
sc_keys: [prev_keys_val],
|
||||
sc_hash: [np.zeros([15, 1])],
|
||||
sc_vocab: [np.array([[3.], [2.], [1.], [0.5], [0.], [0.]])]
|
||||
}, sess)
|
||||
|
||||
def testWarmStartVarsToWarmstartIsNone(self):
|
||||
# Create old and new vocabs for sparse column "sc_vocab".
|
||||
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
|
||||
|
@ -2950,23 +2950,42 @@ class Graph(object):
|
||||
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
with self._lock:
|
||||
graph = graph_pb2.GraphDef()
|
||||
graph.versions.CopyFrom(self._graph_def_versions)
|
||||
bytesize = 0
|
||||
for op_id in sorted(self._nodes_by_id):
|
||||
op = self._nodes_by_id[op_id]
|
||||
if from_version is None or op_id > from_version:
|
||||
graph.node.extend([op.node_def])
|
||||
if op.outputs and add_shapes:
|
||||
assert "_output_shapes" not in graph.node[-1].attr
|
||||
graph.node[-1].attr["_output_shapes"].list.shape.extend(
|
||||
[output.get_shape().as_proto() for output in op.outputs])
|
||||
bytesize += op.node_def.ByteSize()
|
||||
if bytesize >= (1 << 31) or bytesize < 0:
|
||||
raise ValueError("GraphDef cannot be larger than 2GB.")
|
||||
self._copy_functions_to_graph_def(graph, bytesize)
|
||||
return graph, self._version
|
||||
if _USE_C_API:
|
||||
with self._lock:
|
||||
with c_api_util.tf_buffer() as buf:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
c_api.TF_GraphToGraphDef(self._c_graph, buf, status)
|
||||
data = c_api.TF_GetBuffer(buf)
|
||||
graph = graph_pb2.GraphDef()
|
||||
graph.ParseFromString(compat.as_bytes(data))
|
||||
# Strip the experimental library field iff it's empty.
|
||||
if not graph.library.function:
|
||||
graph.ClearField("library")
|
||||
|
||||
if add_shapes:
|
||||
for node in graph.node:
|
||||
op = self._nodes_by_name[node.name]
|
||||
if op.outputs:
|
||||
node.attr["_output_shapes"].list.shape.extend(
|
||||
[output.get_shape().as_proto() for output in op.outputs])
|
||||
else:
|
||||
with self._lock:
|
||||
graph = graph_pb2.GraphDef()
|
||||
graph.versions.CopyFrom(self._graph_def_versions)
|
||||
bytesize = 0
|
||||
for op_id in sorted(self._nodes_by_id):
|
||||
op = self._nodes_by_id[op_id]
|
||||
if from_version is None or op_id > from_version:
|
||||
graph.node.extend([op.node_def])
|
||||
if op.outputs and add_shapes:
|
||||
assert "_output_shapes" not in graph.node[-1].attr
|
||||
graph.node[-1].attr["_output_shapes"].list.shape.extend(
|
||||
[output.get_shape().as_proto() for output in op.outputs])
|
||||
bytesize += op.node_def.ByteSize()
|
||||
if bytesize >= (1 << 31) or bytesize < 0:
|
||||
raise ValueError("GraphDef cannot be larger than 2GB.")
|
||||
self._copy_functions_to_graph_def(graph, bytesize)
|
||||
return graph, self._version
|
||||
|
||||
def as_graph_def(self, from_version=None, add_shapes=False):
|
||||
# pylint: disable=line-too-long
|
||||
|
@ -2290,6 +2290,8 @@ class AsGraphDefTest(test_util.TensorFlowTestCase):
|
||||
t4.set_shape([43, 37])
|
||||
t5.set_shape([43, None])
|
||||
|
||||
b = constant_op.constant(1.0) # pylint: disable=unused-variable
|
||||
|
||||
gd = g.as_graph_def(add_shapes=True)
|
||||
self.assertProtoEqualsVersion("""
|
||||
node { name: "FiveFloatOutputs" op: "FiveFloatOutputs"
|
||||
@ -2306,6 +2308,26 @@ class AsGraphDefTest(test_util.TensorFlowTestCase):
|
||||
}
|
||||
}
|
||||
}
|
||||
node { name: "Const" op: "Const"
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape { }
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "dtype"
|
||||
value { type: DT_FLOAT }
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape { }
|
||||
float_val: 1.0 } } } }
|
||||
""", gd)
|
||||
|
||||
|
||||
|
@ -184,9 +184,34 @@ def _get_cluster():
  return cluster


def _is_transpose(node):
  return node.endswith('TransposeNHWCToNCHW-LayoutOptimizer') or node.endswith(
      'TransposeNCHWToNHWC-LayoutOptimizer')


def _is_permute(node):
  return node.endswith('VecPermuteNHWCToNCHW-LayoutOptimizer') or node.endswith(
      'VecPermuteNCHWToNHWC-LayoutOptimizer')


class LayoutOptimizerTest(test.TestCase):
  """Tests the Grappler layout optimizer."""

  def _assert_trans_nchw_to_nhwc(self, name, nodes):
    self.assertIn(name + '-TransposeNCHWToNHWC-LayoutOptimizer', nodes)

  def _assert_trans_nhwc_to_nchw(self, name, nodes):
    self.assertIn(name + '-TransposeNHWCToNCHW-LayoutOptimizer', nodes)

  def _assert_map_nhwc_to_nchw(self, name, nodes):
    self.assertIn(name + '-DimMapNHWCToNCHW-LayoutOptimizer', nodes)

  def _assert_vec_nchw_to_nhwc(self, name, nodes):
    self.assertIn(name + '-VecPermuteNCHWToNHWC-LayoutOptimizer', nodes)

  def _assert_vec_nhwc_to_nchw(self, name, nodes):
    self.assertIn(name + '-VecPermuteNHWCToNCHW-LayoutOptimizer', nodes)

  def _train(self, checkpoint_path, layout_optimizer=False, restore=False):
    ops.reset_default_graph()
    graph = ops.get_default_graph()
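The helpers above encode a rename: nodes inserted by the layout optimizer now carry a `-LayoutOptimizer` suffix (for example `Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer`) instead of a `LayoutOptimizer...` prefix, which is why the tests below replace the old `startswith('LayoutOptimizerTranspose')` checks with these suffix-based helpers. A tiny illustrative check of the same convention (hypothetical helper, not part of the commit):

    # Illustrative only: any node the optimizer inserts ends with this suffix.
    def _added_by_layout_optimizer(node_name):
      return node_name.endswith('-LayoutOptimizer')

    assert _added_by_layout_optimizer(
        'Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer')
    assert not _added_by_layout_optimizer('Conv2D')
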
@ -238,7 +263,7 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

@ -246,8 +271,8 @@ class LayoutOptimizerTest(test.TestCase):
      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Relu_1-0-0', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('Relu_1-0-0', nodes)

      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

@ -270,7 +295,7 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

@ -278,9 +303,9 @@ class LayoutOptimizerTest(test.TestCase):
      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-split-0-0', nodes)
      self.assertIn('LayoutOptimizerDimMapNHWCToNCHW_split_0', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('split-0-0', nodes)
      self._assert_map_nhwc_to_nchw('split-0', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testSplitVWithNonConstAxis(self):
@ -304,7 +329,7 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

@ -312,9 +337,9 @@ class LayoutOptimizerTest(test.TestCase):
      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-SplitV-0-0', nodes)
      self.assertIn('LayoutOptimizerDimMapNHWCToNCHW_SplitV_2', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('SplitV-0-0', nodes)
      self._assert_map_nhwc_to_nchw('SplitV-2', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testPadWithConstPaddings(self):
@ -338,7 +363,7 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

@ -346,9 +371,9 @@ class LayoutOptimizerTest(test.TestCase):
      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Pad-0-0', nodes)
      self.assertIn('LayoutOptimizer-Pad-PaddingsConst', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('Pad-0-0', nodes)
      self.assertIn('Pad-PaddingsConst-LayoutOptimizer', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testReduceSum(self):
@ -369,7 +394,7 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

@ -377,7 +402,7 @@ class LayoutOptimizerTest(test.TestCase):
      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
      expected_num_transposes = 1
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testReduceSumAlongHWC(self):
@ -398,7 +423,7 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

@ -406,7 +431,7 @@ class LayoutOptimizerTest(test.TestCase):
      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
      expected_num_transposes = 1
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testReduceSumAlongNHW(self):
@ -427,7 +452,7 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

@ -435,7 +460,7 @@ class LayoutOptimizerTest(test.TestCase):
      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
      expected_num_transposes = 1
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testReduceSumAlongC(self):
@ -456,7 +481,7 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

@ -464,7 +489,7 @@ class LayoutOptimizerTest(test.TestCase):
      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
      expected_num_transposes = 1
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testConcatWithControlDependency(self):
@ -489,7 +514,7 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

@ -497,9 +522,9 @@ class LayoutOptimizerTest(test.TestCase):
      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-concat-0-0', nodes)
      self.assertIn('LayoutOptimizer-concat-Const_2', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('concat-0-0', nodes)
      self.assertIn('concat-Const_2-LayoutOptimizer', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testFill(self):
@ -527,9 +552,9 @@ class LayoutOptimizerTest(test.TestCase):
      num_transposes = 0
      num_vec_permute = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        if node.name.startswith('LayoutOptimizerVecPermute'):
        if _is_permute(node.name):
          num_vec_permute += 1
        nodes.append(node.name)

@ -541,8 +566,8 @@ class LayoutOptimizerTest(test.TestCase):
      # LayoutOptimizer; they cancelled out each other in the Collapse phase.
      expected_vec_permute = 0
      self.assertEqual(expected_vec_permute, num_vec_permute)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Fill-0-0', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('Fill-0-0', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testTile(self):
@ -568,7 +593,7 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

@ -576,9 +601,9 @@ class LayoutOptimizerTest(test.TestCase):
      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Tile-0-0', nodes)
      self.assertIn('LayoutOptimizerVecPermuteNHWCToNCHW_Tile_1', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('Tile-0-0', nodes)
      self._assert_vec_nhwc_to_nchw('Tile-1', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testReverseWithConstDims(self):
@ -600,7 +625,7 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

@ -608,9 +633,9 @@ class LayoutOptimizerTest(test.TestCase):
      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-ReverseV2-0-0', nodes)
      self.assertIn('LayoutOptimizer-ReverseV2-DimsConst', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('ReverseV2-0-0', nodes)
      self.assertIn('ReverseV2-DimsConst-LayoutOptimizer', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testReverseWithNonConstDims(self):
@ -636,7 +661,7 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

@ -644,9 +669,9 @@ class LayoutOptimizerTest(test.TestCase):
      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-ReverseV2-0-0', nodes)
      self.assertIn('LayoutOptimizerDimMapNHWCToNCHW_ReverseV2_1', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('ReverseV2-0-0', nodes)
      self._assert_map_nhwc_to_nchw('ReverseV2-1', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testSelectOp(self):
@ -670,14 +695,14 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Select-0-0', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('Select-0-0', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testSelectOpScalarCondition(self):
@ -700,14 +725,14 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Select-0-0', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('Select-0-0', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testPadWithNonConstPaddings(self):
@ -733,7 +758,7 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

@ -741,9 +766,9 @@ class LayoutOptimizerTest(test.TestCase):
      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Pad-0-0', nodes)
      self.assertIn('LayoutOptimizerVecPermuteNHWCToNCHW_Pad_1', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('Pad-0-0', nodes)
      self._assert_vec_nhwc_to_nchw('Pad-1', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testMaxPoolV2(self):
@ -770,16 +795,16 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-MaxPoolV2-0-0', nodes)
      self.assertIn('LayoutOptimizerVecPermuteNHWCToNCHW_MaxPoolV2_2', nodes)
      self.assertIn('LayoutOptimizer-MaxPoolV2-Const_2', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('MaxPoolV2-0-0', nodes)
      self._assert_vec_nhwc_to_nchw('MaxPoolV2-2', nodes)
      self.assertIn('MaxPoolV2-Const_2-LayoutOptimizer', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testMaxPoolGradV2(self):
@ -807,18 +832,16 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-MaxPoolGradV2-0-0',
                    nodes)
      self.assertIn('LayoutOptimizerVecPermuteNHWCToNCHW_MaxPoolGradV2_4',
                    nodes)
      self.assertIn('LayoutOptimizer-MaxPoolGradV2-Const_2', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('MaxPoolGradV2-0-0', nodes)
      self._assert_vec_nhwc_to_nchw('MaxPoolGradV2-4', nodes)
      self.assertIn('MaxPoolGradV2-Const_2-LayoutOptimizer', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testSliceWithNonConstAxis(self):
@ -844,7 +867,7 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

@ -852,9 +875,9 @@ class LayoutOptimizerTest(test.TestCase):
      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Slice-0-0', nodes)
      self.assertIn('LayoutOptimizerVecPermuteNHWCToNCHW_Slice_2', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('Slice-0-0', nodes)
      self._assert_vec_nhwc_to_nchw('Slice-2', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testStridedSliceWithNonConstAxis(self):
@ -880,7 +903,7 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

@ -888,12 +911,11 @@ class LayoutOptimizerTest(test.TestCase):
      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-StridedSlice-0-0',
                    nodes)
      self.assertIn('LayoutOptimizerVecPermuteNHWCToNCHW_StridedSlice_2', nodes)
      self.assertIn('LayoutOptimizer-StridedSlice-StridedSlice/begin', nodes)
      self.assertIn('LayoutOptimizer-StridedSlice-StridedSlice/strides', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('StridedSlice-0-0', nodes)
      self._assert_vec_nhwc_to_nchw('StridedSlice-2', nodes)
      self.assertIn('StridedSlice-StridedSlice/begin-LayoutOptimizer', nodes)
      self.assertIn('StridedSlice-StridedSlice/strides-LayoutOptimizer', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testStridedSliceWithMask(self):
@ -915,7 +937,7 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

@ -923,13 +945,12 @@ class LayoutOptimizerTest(test.TestCase):
      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-strided_slice-0-0',
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('strided_slice-0-0', nodes)
      self.assertIn('strided_slice-strided_slice/stack-LayoutOptimizer', nodes)
      self.assertIn('strided_slice-strided_slice/stack_1-LayoutOptimizer',
                    nodes)
      self.assertIn('LayoutOptimizer-strided_slice-strided_slice/stack', nodes)
      self.assertIn('LayoutOptimizer-strided_slice-strided_slice/stack_1',
                    nodes)
      self.assertIn('LayoutOptimizer-strided_slice-strided_slice/stack_2',
      self.assertIn('strided_slice-strided_slice/stack_2-LayoutOptimizer',
                    nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

@ -960,7 +981,7 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

@ -968,14 +989,12 @@ class LayoutOptimizerTest(test.TestCase):
      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-StridedSliceGrad-0-0',
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('StridedSliceGrad-0-0', nodes)
      self._assert_vec_nhwc_to_nchw('StridedSliceGrad-2', nodes)
      self.assertIn('StridedSlice-StridedSliceGrad/begin-LayoutOptimizer',
                    nodes)
      self.assertIn('LayoutOptimizerVecPermuteNHWCToNCHW_StridedSliceGrad_2',
                    nodes)
      self.assertIn('LayoutOptimizer-StridedSlice-StridedSliceGrad/begin',
                    nodes)
      self.assertIn('LayoutOptimizer-StridedSlice-StridedSliceGrad/strides',
      self.assertIn('StridedSlice-StridedSliceGrad/strides-LayoutOptimizer',
                    nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

@ -1000,14 +1019,14 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

      expected_num_transposes = 1
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
      self.assertIn('LayoutOptimizerVecPermuteNCHWToNHWC-ShapeN-0-0', nodes)
      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
      self._assert_vec_nchw_to_nhwc('ShapeN-0-0', nodes)
      self.assertAllEqual(output_val_ref, output_val)

  def testLoop(self):
@ -1024,7 +1043,7 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

@ -1033,10 +1052,8 @@ class LayoutOptimizerTest(test.TestCase):
      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-map/while/Conv2D-0',
                    nodes)
      self.assertIn(
          'LayoutOptimizerTransposeNCHWToNHWC-map/while/MaxPool_1-0-2', nodes)
      self._assert_trans_nhwc_to_nchw('map/while/Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('map/while/MaxPool_1-0-2', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testLoopWithBranch(self):
@ -1053,16 +1070,14 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-map/while/Conv2D-0',
                    nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-map/while/Add-0-2',
                    nodes)
      self._assert_trans_nhwc_to_nchw('map/while/Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('map/while/Add-0-2', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testLoopWithVecAnd4D(self):
@ -1079,16 +1094,14 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-map/while/Conv2D-0',
                    nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-map/while/Add-0-2',
                    nodes)
      self._assert_trans_nhwc_to_nchw('map/while/Conv2D-0', nodes)
      self._assert_trans_nchw_to_nhwc('map/while/Add-0-2', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testBinaryOpSecondPort(self):
@ -1105,15 +1118,14 @@ class LayoutOptimizerTest(test.TestCase):
      nodes = []
      num_transposes = 0
      for node in metadata.cost_graph.node:
        if node.name.startswith('LayoutOptimizerTranspose'):
        if _is_transpose(node.name):
          num_transposes += 1
        nodes.append(node.name)

      expected_num_transposes = 2
      self.assertEqual(expected_num_transposes, num_transposes)
      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-FusedBatchNorm-0',
                    nodes)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Add-0-0', nodes)
      self._assert_trans_nhwc_to_nchw('FusedBatchNorm-0', nodes)
      self._assert_trans_nchw_to_nhwc('Add-0-0', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

  def testGradient(self):

@ -74,12 +74,12 @@ def load_data(path='imdb.npz',
  f.close()

  np.random.seed(seed)
  indices = np.arrange(len(x_train))
  indices = np.arange(len(x_train))
  np.random.shuffle(indices)
  x_train = x_train[indices]
  labels_train = labels_train[indices]

  indices = np.arrange(len(x_test))
  indices = np.arange(len(x_test))
  np.random.shuffle(indices)
  x_test = x_test[indices]
  labels_test = labels_test[indices]

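This hunk and the next fix the same typo (`np.arrange` → `np.arange`): NumPy has no `arrange`, so the old code raised an AttributeError before the data could be shuffled. The pattern being repaired is the usual shuffle-two-arrays-in-lockstep idiom, sketched here with made-up arrays for illustration only:

    import numpy as np

    x = np.array([[1], [2], [3]])
    y = np.array([10, 20, 30])

    np.random.seed(0)
    indices = np.arange(len(x))   # np.arrange does not exist
    np.random.shuffle(indices)
    x, y = x[indices], y[indices]  # rows stay paired with their labels
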
@ -73,7 +73,7 @@ def load_data(path='reuters.npz',
  npzfile.close()

  np.random.seed(seed)
  indices = np.arrange(len(xs))
  indices = np.arange(len(xs))
  np.random.shuffle(indices)
  xs = xs[indices]
  labels = labels[indices]

@ -323,7 +323,7 @@ class TemplateTest(test.TestCase):
    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertTrue(v1, v2)
    self.assertEqual(v1, v2)
    self.assertNotEqual(v1, v3)
    self.assertEqual("s1/nested_1/dummy:0", v1.name)
    self.assertEqual("s1_1/nested_1/dummy:0", v3.name)

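Context for the change above: `assertTrue(v1, v2)` does not compare the two values, because the second positional argument of `unittest.TestCase.assertTrue` is only the failure message, so the old assertion passed whenever `v1` was truthy. A minimal standalone sketch of the difference (illustrative only):

    import unittest

    class Demo(unittest.TestCase):
      def test_demo(self):
        v1, v2 = "a", "b"
        self.assertTrue(v1, v2)   # passes: only checks bool(v1); v2 is the msg
        self.assertEqual(v1, v1)  # assertEqual actually compares its arguments
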
@ -223,7 +223,7 @@ class EinsumTest(test.TestCase):

  dim_mismatch_cases = [('ijk,jkl->il', [(2, 3, 4), (3, 5, 6)])]

  def test_simple(self):
  def disabled_test_simple(self):
    for case in self.simple_cases:
      self.run_test(case)

@ -366,8 +366,8 @@ port::StatusOr<DriverVersion> Diagnostician::FindKernelDriverVersion() {
  contents[kContentsSize - 1] = '\0';

  if (retcode != 0) {
    LOG(INFO) << "driver version file contents: \"\"\"" << contents.begin()
              << "\"\"\"";
    VLOG(1) << "driver version file contents: \"\"\"" << contents.begin()
            << "\"\"\"";
    fclose(driver_version_file);
    return FindKernelModuleVersion(contents.begin());
  }

@ -169,6 +169,7 @@ sh_binary(
        "//tensorflow/contrib/ndlstm:ndlstm",
        "//tensorflow/contrib/nn:nn_py",
        "//tensorflow/contrib/predictor:predictor_pip",
        "//tensorflow/contrib/py2tf/convert:convert",
        "//tensorflow/contrib/py2tf/pyct:pyct",
        "//tensorflow/contrib/py2tf/pyct/static_analysis:static_analysis",
        "//tensorflow/contrib/receptive_field:receptive_field_pip",