Merge branch 'master' into google_upstream_argmax_op

commit 23d21d078c
Wen-Heng (Jack) Chung, 2019-03-15 14:16:57 -07:00 (committed by GitHub)
163 changed files with 4975 additions and 4020 deletions


@ -50,6 +50,7 @@ _DEFAULT_PROMPT_ASK_ATTEMPTS = 10
_TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
_TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
NCCL_LIB_PATHS = [
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
@ -337,8 +338,8 @@ def get_var(environ_cp,
'Environment variable %s must be set as a boolean indicator.\n'
'The following are accepted as TRUE : %s.\n'
'The following are accepted as FALSE: %s.\n'
'Current value is %s.' % (var_name, ', '.join(true_strings),
', '.join(false_strings), var))
'Current value is %s.' %
(var_name, ', '.join(true_strings), ', '.join(false_strings), var))
while var is None:
user_input_origin = get_input(question)
@ -771,11 +772,12 @@ def check_ndk_level(android_ndk_home_path):
else:
raise Exception('Unable to parse NDK revision.')
if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS:
print('WARNING: The API level of the NDK in %s is %s, which is not '
'supported by Bazel (officially supported versions: %s). Please use '
'another version. Compiling Android targets may result in confusing '
'errors.\n' % (android_ndk_home_path, ndk_api_level,
_SUPPORTED_ANDROID_NDK_VERSIONS))
print(
'WARNING: The API level of the NDK in %s is %s, which is not '
'supported by Bazel (officially supported versions: %s). Please use '
'another version. Compiling Android targets may result in confusing '
'errors.\n' %
(android_ndk_home_path, ndk_api_level, _SUPPORTED_ANDROID_NDK_VERSIONS))
return ndk_api_level
@ -1230,8 +1232,8 @@ def set_tf_nccl_install_path(environ_cp):
# Reset and Retry
print(
'Invalid path to NCCL %s toolkit, %s or %s not found. Please use the '
'O/S agnostic package of NCCL 2' % (tf_nccl_version, nccl_lib_path,
nccl_hdr_path))
'O/S agnostic package of NCCL 2' %
(tf_nccl_version, nccl_lib_path, nccl_hdr_path))
environ_cp['TF_NCCL_VERSION'] = ''
else:
@ -1498,6 +1500,7 @@ def set_other_mpi_vars(environ_cp):
'Cannot find the MPI library file in %s/lib or %s/lib64 or %s/lib32' %
(mpi_home, mpi_home, mpi_home))
def system_specific_test_config(env):
"""Add default test flags required for TF tests to bazelrc."""
write_to_bazelrc('test --flaky_test_attempts=3')
@ -1593,11 +1596,15 @@ def configure_apple_bazel_rules():
existing_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath + '.apple')
renamed_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath)
os.rename(existing_filepath, renamed_filepath)
if _TF_CURRENT_BAZEL_VERSION is None or _TF_CURRENT_BAZEL_VERSION < 23000:
print(
'Building Bazel rules on Apple platforms requires Bazel 0.23 or later.')
def main():
global _TF_WORKSPACE_ROOT
global _TF_BAZELRC
global _TF_CURRENT_BAZEL_VERSION
parser = argparse.ArgumentParser()
parser.add_argument(
@ -1614,7 +1621,8 @@ def main():
# environment variables.
environ_cp = dict(os.environ)
check_bazel_version('0.19.0', '0.23.2')
current_bazel_version = check_bazel_version('0.19.0', '0.23.2')
_TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)
reset_tf_configure_bazelrc()
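For orientation, here is a minimal, hypothetical sketch of the version encoding assumed by the `_TF_CURRENT_BAZEL_VERSION < 23000` check above (it is an assumption that `convert_version_to_int` packs versions this way; the helper name `version_to_int` is invented for the example):

```python
# Illustrative only: assumes "major.minor.patch" is packed as
# major * 1000000 + minor * 1000 + patch, so Bazel 0.23.0 -> 23000,
# the threshold used by configure_apple_bazel_rules above.
def version_to_int(version):
    parts = (version.split('.') + ['0', '0'])[:3]
    major, minor, patch = (int(''.join(ch for ch in p if ch.isdigit()) or '0')
                           for p in parts)
    return major * 1000000 + minor * 1000 + patch

print(version_to_int('0.23.2'))           # 23002
print(version_to_int('0.19.0') < 23000)   # True: older than Bazel 0.23
```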


@ -50,6 +50,7 @@ WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
* https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
* https://github.com/tensorflow/addons
* https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.
"""
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib',


@ -88,15 +88,19 @@ Status ScaleAndTranslateGradHelper(const Scope& scope, const Operation& op,
string kernel_type;
TF_RETURN_IF_ERROR(
GetNodeAttr(op.node()->attrs(), "kernel_type", &kernel_type));
bool antialias;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "antialias", &antialias));
grad_outputs->push_back(internal::ScaleAndTranslateGrad(
scope, grad_inputs[0], op.input(0), op.input(2), op.input(3),
internal::ScaleAndTranslateGrad::KernelType(kernel_type)));
internal::ScaleAndTranslateGrad::KernelType(kernel_type)
.Antialias(antialias)));
grad_outputs->push_back(NoGradient());
grad_outputs->push_back(NoGradient());
grad_outputs->push_back(NoGradient());
return scope.status();
}
REGISTER_GRADIENT_OP("ScaleAndTranslate", ScaleAndTranslateGradHelper);
Status CropAndResizeGradHelper(const Scope& scope, const Operation& op,


@ -196,29 +196,106 @@ class ScaleAndTranslateGradTest : public ::testing::Test {
}
template <typename T>
void MakeOp(const Tensor& x_data, const Input& y_shape, Output* x,
Output* y) {
void MakeOp(const Tensor& x_data, const Input& y_shape, Input scale,
Input translation, const string& kernel_type, bool antialias,
Output* x, Output* y) {
*x = Const<T>(scope_, x_data);
*y = ScaleAndTranslate(scope_, *x, y_shape, {1.8f, 2.1f}, {0.5f, 0.7f});
*y = ScaleAndTranslate(scope_, *x, y_shape, scale, translation,
ScaleAndTranslate::KernelType(kernel_type)
.Antialias(antialias));
TF_ASSERT_OK(scope_.status());
}
template <typename X_T, typename Y_T, typename JAC_T>
void TestResize() {
TensorShape x_shape({1, 2, 3, 1});
void TestScaleAndTranslate(const TensorShape x_shape, const int out_height,
const int out_width, Input scale,
Input translation, const string& kernel_type,
bool antialias) {
Tensor x_data = MakeData<X_T>(x_shape);
Output x, y;
MakeOp<X_T>(x_data, {4, 6}, &x, &y);
MakeOp<X_T>(x_data, {out_height, out_width}, scale, translation,
kernel_type, antialias, &x, &y);
JAC_T max_error;
TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, JAC_T>(
scope_, x, x_data, y, {1, 4, 6, 1}, &max_error)));
EXPECT_LT(max_error, 1e-3);
scope_, x, x_data, y, {1, out_height, out_width, 1}, &max_error)));
EXPECT_LT(max_error, 2e-3);
}
const std::vector<Input> kScales = {Input{1.0f, 1.0f}, Input{0.37f, 0.47f},
Input{2.1f, 2.1f}};
const std::vector<Input> kTranslations = {
Input{0.0f, 0.0f}, Input{3.14f, 1.19f}, Input{2.1f, 3.1f},
Input{100.0f, 200.0f}};
Scope scope_;
};
TEST_F(ScaleAndTranslateGradTest, Works) { TestResize<float, float, float>(); }
TEST_F(ScaleAndTranslateGradTest, TestGrads) {
const std::vector<std::string> kKernelTypes = {"lanczos1", "lanczos3",
"lanczos5", "gaussian"};
constexpr int kOutHeight = 4;
constexpr int kOutWidth = 6;
const TensorShape kXShape = TensorShape({1, 2, 3, 1});
for (const Input scale : kScales) {
for (const Input translation : kTranslations) {
for (const std::string& kernel_type : kKernelTypes) {
TestScaleAndTranslate<float, float, float>(
kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type,
true);
}
}
}
}
TEST_F(ScaleAndTranslateGradTest, TestGradsWithoutAntialias) {
constexpr int kOutHeight = 4;
constexpr int kOutWidth = 6;
const TensorShape kXShape = TensorShape({1, 2, 3, 1});
for (const Input scale : kScales) {
for (const Input translation : kTranslations) {
TestScaleAndTranslate<float, float, float>(kXShape, kOutHeight, kOutWidth,
scale, translation, "lanczos3",
false);
}
}
}
TEST_F(ScaleAndTranslateGradTest, TestGradsWithSameShape) {
const std::vector<std::string> kKernelTypes = {"lanczos3", "gaussian"};
constexpr int kOutHeight = 2;
constexpr int kOutWidth = 3;
const TensorShape kXShape = TensorShape({1, 2, 3, 1});
for (const Input scale : kScales) {
for (const Input translation : kTranslations) {
for (const std::string& kernel_type : kKernelTypes) {
TestScaleAndTranslate<float, float, float>(
kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type,
true);
}
}
}
}
TEST_F(ScaleAndTranslateGradTest, TestGradsWithSmallerShape) {
const std::vector<std::string> kKernelTypes = {"lanczos3", "gaussian"};
constexpr int kOutHeight = 2;
constexpr int kOutWidth = 3;
const TensorShape kXShape = TensorShape({1, 4, 6, 1});
for (const Input scale : kScales) {
for (const Input translation : kTranslations) {
for (const std::string& kernel_type : kKernelTypes) {
TestScaleAndTranslate<float, float, float>(
kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type,
true);
}
}
}
}
class CropAndResizeGradTest : public ::testing::Test {
protected:
@ -237,9 +314,9 @@ class CropAndResizeGradTest : public ::testing::Test {
template <typename T>
void MakeOp(const Tensor& x_data, const Input& boxes, const Input& box_ind,
const Input& crop_szie, Output* x, Output* y) {
const Input& crop_size, Output* x, Output* y) {
*x = Const<T>(scope_, x_data);
*y = CropAndResize(scope_, *x, boxes, box_ind, crop_szie,
*y = CropAndResize(scope_, *x, boxes, box_ind, crop_size,
CropAndResize::Method("bilinear"));
TF_ASSERT_OK(scope_.status());
}


@ -101,6 +101,15 @@ class XlaAssignVariableOp : public OpKernel {
REGISTER_KERNEL_BUILDER( \
Name("Identity").Device(DEVICE).TypeConstraint("T", DT_STRING), \
IdentityOp); \
REGISTER_KERNEL_BUILDER( \
Name("Identity").Device(DEVICE).TypeConstraint<Variant>("T"), \
IdentityOp); \
REGISTER_KERNEL_BUILDER(Name("Identity") \
.Device(DEVICE) \
.TypeConstraint<ResourceHandle>("T") \
.HostMemory("input") \
.HostMemory("output"), \
IdentityOp); \
REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE), IdentityNOp); \
REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), PlaceholderOp); \
REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \
@ -199,9 +208,7 @@ class XlaAssignVariableOp : public OpKernel {
Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); \
\
REGISTER_KERNEL_BUILDER( \
Name(kArgOp).Device(DEVICE).HostMemory("output").TypeConstraint("T", \
TYPES), \
ArgOp); \
Name(kArgOp).Device(DEVICE).TypeConstraint("T", TYPES), ArgOp); \
REGISTER_KERNEL_BUILDER(Name(kArgOp) \
.Device(DEVICE) \
.HostMemory("output") \
@ -210,11 +217,8 @@ class XlaAssignVariableOp : public OpKernel {
REGISTER_KERNEL_BUILDER( \
Name(kArgOp).Device(DEVICE).TypeConstraint<Variant>("T"), ArgOp); \
\
REGISTER_KERNEL_BUILDER(Name(kRetOp) \
.Device(DEVICE) \
.TypeConstraint("T", TYPES) \
.HostMemory("input"), \
RetvalOp); \
REGISTER_KERNEL_BUILDER( \
Name(kRetOp).Device(DEVICE).TypeConstraint("T", TYPES), RetvalOp); \
REGISTER_KERNEL_BUILDER(Name(kRetOp) \
.Device(DEVICE) \
.TypeConstraint<ResourceHandle>("T") \


@ -482,7 +482,7 @@ tf_xla_py_test(
name = "fft_test",
size = "medium",
srcs = ["fft_test.py"],
shard_count = 3,
shard_count = 6,
tags = ["optonly"],
deps = [
":xla_test",


@ -79,6 +79,24 @@ static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
XLA_MAKE_BINARY(DivNoNan,
DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper));
// Implementation of MulNoNan. Pseudo-code:
// if (y == 0) {
// return 0
// } else {
// return x * y;
// }
static xla::XlaOp MulNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
auto y_equals_0 = xla::Eq(y, zero);
auto zeros = xla::ZerosLike(x);
auto result = xla::Select(y_equals_0, zeros, xla::Mul(x, y));
return result;
}
XLA_MAKE_BINARY(MulNoNan,
MulNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper));
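For reference, a minimal NumPy sketch of the MulNoNan semantics implemented by the lowering above (the `mul_no_nan` helper name is invented for illustration and is not part of this change):

```python
import numpy as np

def mul_no_nan(x, y):
    # Mirror of the Select-based lowering: wherever y == 0 the result is
    # exactly 0 (even if x is inf or NaN); elsewhere it is x * y. Like
    # xla::Select, np.where still evaluates x * y everywhere and simply
    # discards the unwanted NaN/Inf lanes.
    with np.errstate(invalid='ignore'):
        return np.where(y == 0, np.zeros_like(x), x * y)

print(mul_no_nan(np.array([np.inf, 2.0]), np.array([0.0, 3.0])))  # [0. 6.]
```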
// Implementation of FloorDiv.
//
// For floating-point values, simply returns floor(x / y). For integers, does:


@ -128,6 +128,20 @@ static void AllocateFlags() {
flag_values->xla_cpu_enable_fast_math(),
"Enable unsafe fast-math optimizations in the CPU compiler; "
"this may produce faster code at the expense of some accuracy."),
tensorflow::Flag(
"xla_cpu_fast_math_honor_nans",
bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_nans),
flag_values->xla_cpu_fast_math_honor_nans(),
"When xla_cpu_enable_fast_math is true then this controls whether we "
"allow operations to produce NaNs. Ignored when "
"xla_cpu_enable_fast_math is false."),
tensorflow::Flag(
"xla_cpu_fast_math_honor_infs",
bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_infs),
flag_values->xla_cpu_fast_math_honor_infs(),
"When xla_cpu_enable_fast_math is true then this controls whether we "
"allow operations to produce infinites. Ignored when "
"xla_cpu_enable_fast_math is false."),
tensorflow::Flag(
"xla_gpu_enable_fast_min_max",
bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max),
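As a hedged usage note (not part of this diff): XLA debug options like these are normally picked up from the `XLA_FLAGS` environment variable, so the new knobs could be exercised roughly as follows:

```python
import os

# Set before TensorFlow/XLA is imported so the flags are parsed at startup.
os.environ['XLA_FLAGS'] = (
    '--xla_cpu_enable_fast_math=true '
    '--xla_cpu_fast_math_honor_nans=true '   # keep NaN semantics under fast math
    '--xla_cpu_fast_math_honor_infs=true'    # keep +/-Inf semantics under fast math
)

import tensorflow as tf  # noqa: E402 -- import after XLA_FLAGS is set
```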


@ -1663,15 +1663,18 @@ Applies a reduction function to one or more arrays in parallel.
<b> `Reduce(operands..., init_values..., computation, dimensions)` </b>
Arguments | Type | Semantics
------------- | --------------------- | ---------------------------------------
`operands` | Sequence of N `XlaOp` | N arrays of types `T_0, ..., T_N`.
`init_values` | Sequence of N `XlaOp` | N scalars of types `T_0, ..., T_N`.
`computation` | `XlaComputation` | computation of type
: : `T_0, ..., T_N, T_0, ..., T_N -> Collate(T_0, ..., T_N)`
`dimensions` | `int64` array | unordered array of dimensions to reduce
| Arguments | Type | Semantics |
| ------------- | --------------------- | ------------------------------------ |
| `operands` | Sequence of N `XlaOp` | N arrays of types `T_0, ..., T_N`. |
| `init_values` | Sequence of N `XlaOp` | N scalars of types `T_0, ..., T_N`. |
| `computation` | `XlaComputation` | computation of type `T_0, ..., T_N, |
: : : T_0, ..., T_N ->` `Collate(T_0, ..., :
: : : T_N)` :
| `dimensions` | `int64` array | unordered array of dimensions to |
: : : reduce :
Where:
* N is required to be greater or equal to 1.
* All input arrays must have the same dimensions.
* If `N = 1`, `Collate(T)` is `T`.
@ -1681,10 +1684,10 @@ The output of the op is `Collate(Q_0, ..., Q_N)` where `Q_i` is an array of type
`T_i`, the dimensions of which are described below.
This operation reduces one or more dimensions of each input array into scalars.
The rank of each returned array is `rank(operand) - len(dimensions)`.
`init_value` is the initial value used for every reduction and may be inserted
The rank of each returned array is `rank(operand) - len(dimensions)`. The
initial value used for every reduction is `init_value`, and it may be inserted
anywhere during computation by the back-end. In most cases, `init_value` is an
identity of the reduction function (for example, 0 for addition). The applied
identity of the reduction function (for example, `0` for addition). The applied
`computation` is always passed the `init_value` on the left-hand side.
The evaluation order of the reduction function is arbitrary and may be
@ -1695,10 +1698,10 @@ Some reduction functions like addition are not strictly associative for floats.
However, if the range of the data is limited, floating-point addition is close
enough to being associative for most practical uses. It is possible to conceive
of some completely non-associative reductions, however, and these will produce
incorrect or unpredictable results in XLA reductions.
incorrect or unpredictable results in XLA.
As an example, when reducing across one dimension in a single 1D array with
values [10, 11, 12, 13], with reduction function `f` (this is `computation`)
values `[10, 11, 12, 13]`, with reduction function `f` (this is `computation`)
then that could be computed as
`f(10, f(11, f(12, f(init_value, 13))))`
@ -1777,16 +1780,27 @@ preserved in the output, but some dimensions may get assigned new numbers (since
the rank changes).
We can also reduce multiple dimensions. Add-reducing dimensions 0 and 1 produces
the 1D array `| 20 28 36 |`.
the 1D array `[20, 28, 36]`.
Reducing the 3D array over all its dimensions produces the scalar `84`.
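A small NumPy sketch of the two reductions just described (it assumes the 3D example array from this section is four copies of `[[1, 2, 3], [4, 5, 6]]`, which reproduces the quoted numbers):

```python
import numpy as np

x = np.tile([[1, 2, 3], [4, 5, 6]], (4, 1, 1))  # shape (4, 2, 3)

print(x.sum(axis=(0, 1)))  # [20 28 36] -- add-reduce over dimensions 0 and 1
print(x.sum())             # 84         -- add-reduce over all dimensions
```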
When `N > 1`, reduce function application is slightly more complex, as it is
applied simultaneously to all inputs. For example, consider the following
reduction function, which can be used to compute the max and the argmax of a
1-D array in parallel:
### Variadic Reduce
```
When `N > 1`, reduce function application is slightly more complex, as it is
applied simultaneously to all inputs. The operands are supplied to the
computation in the following order:
* Running reduced value for the first operand
* ...
* Running reduced value for the N'th operand
* Input value for the first operand
* ...
* Input value for the N'th operand
For example, consider the following reduction function, which can be used to
compute the max and the argmax of a 1-D array in parallel:
```python
f: (Float, Int, Float, Int) -> Float, Int
f(max, argmax, value, index):
if value >= max:
@ -1798,6 +1812,7 @@ f(max, argmax, value, index):
For 1-D Input arrays `V = Float[N], K = Int[N]`, and init values
`I_V = Float, I_K = Int`, the result `f_(N-1)` of reducing across the only
input dimension is equivalent to the following recursive application:
```
f_0 = f(I_V, I_K, V_0, K_0)
f_1 = f(f_0.first, f_0.second, V_1, K_1)


@ -339,15 +339,15 @@ cc_library(
srcs = ["ir_function.cc"],
hdrs = ["ir_function.h"],
deps = [
":cpu_runtime",
":ir_emission_utils",
":shape_partition",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service/cpu:cpu_runtime",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm//:core",
@ -378,6 +378,7 @@ cc_library(
":vector_support_library",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",


@ -409,10 +409,20 @@ auto memory_alignment = [](LogicalBuffer::Color) { return kMemoryAlignment; };
llvm::TargetOptions CompilerTargetOptions(
const HloModuleConfig& module_config) {
llvm::TargetOptions target_options;
llvm_ir::SetTargetOptions(
/*fast_math_enabled=*/module_config.debug_options()
.xla_cpu_enable_fast_math(),
&target_options);
// In LLVM backend flags, UnsafeFPMath does not explicitly imply NoInfs, etc.
if (module_config.debug_options().xla_cpu_enable_fast_math()) {
target_options.UnsafeFPMath = true;
target_options.NoInfsFPMath =
module_config.debug_options().xla_cpu_fast_math_honor_infs();
target_options.NoNaNsFPMath =
module_config.debug_options().xla_cpu_fast_math_honor_nans();
target_options.NoSignedZerosFPMath = true;
} else {
target_options.UnsafeFPMath = false;
target_options.NoInfsFPMath = false;
target_options.NoNaNsFPMath = false;
target_options.NoSignedZerosFPMath = false;
}
return target_options;
}


@ -250,11 +250,6 @@ void DotOpEmitter::EmitTiledLlvmIrGemm() {
std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) =
GetGemmTileSize();
const bool enable_fast_math =
hlo_module_config_.debug_options().xla_cpu_enable_fast_math();
const bool optimize_for_size =
options::OptimizeForSizeRequested(hlo_module_config_);
EmitSmallGemm(
/*scalar_type=*/primitive_type,
/*m=*/m, /*k=*/k, /*n=*/n,
@ -262,9 +257,7 @@ void DotOpEmitter::EmitTiledLlvmIrGemm() {
/*max_vector_count=*/tile_size_n_in_vector_width,
/*min_vectorization_width=*/std::min<int64>(4, max_target_vector_width),
/*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k, /*lhs=*/lhs,
/*rhs=*/rhs, /*result=*/target, b_,
/*enable_fast_math=*/enable_fast_math,
/*optimize_for_size=*/optimize_for_size);
/*rhs=*/rhs, /*result=*/target, b_, hlo_module_config_);
}
void DotOpEmitter::EmitTiledLlvmIrGemv() {
@ -323,11 +316,6 @@ void DotOpEmitter::EmitTiledLlvmIrGemv() {
llvm::Value* rhs_op =
swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();
const bool enable_fast_math =
hlo_module_config_.debug_options().xla_cpu_enable_fast_math();
const bool optimize_for_size =
options::OptimizeForSizeRequested(hlo_module_config_);
const int target_vector_register_element_size =
target_machine_features_.vector_register_num_elements(
*b_->GetInsertBlock()->getParent(), primitive_type);
@ -349,9 +337,7 @@ void DotOpEmitter::EmitTiledLlvmIrGemv() {
/*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor,
/*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
/*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
/*result=*/result_op, b_,
/*enable_fast_math=*/enable_fast_math,
/*optimize_for_size=*/optimize_for_size);
/*result=*/result_op, b_, hlo_module_config_);
} else {
VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
<< " and k = " << k;
@ -361,9 +347,7 @@ void DotOpEmitter::EmitTiledLlvmIrGemv() {
/*tile_cols=*/vector_register_element_size,
/*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
/*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
/*result=*/result_op, b_,
/*enable_fast_math=*/enable_fast_math,
/*optimize_for_size=*/optimize_for_size);
/*result=*/result_op, b_, hlo_module_config_);
}
}


@ -98,9 +98,7 @@ IrEmitter::IrEmitter(
is_top_level_computation_(false),
target_machine_features_(*target_machine_features),
emit_code_for_msan_(emit_code_for_msan) {
b_.setFastMathFlags(llvm_ir::GetFastMathFlags(
/*fast_math_enabled=*/hlo_module_config_.debug_options()
.xla_cpu_enable_fast_math()));
b_.setFastMathFlags(llvm_ir::GetCpuFastMathFlags(hlo_module_config_));
Status s = GatherComputationsByAllocationType(
&hlo_module, &thread_local_computations_, &global_computations_);
absl::c_sort(thread_local_computations_);
@ -159,11 +157,9 @@ void IrEmitter::InitializeIrFunction(const string& function_name) {
is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage
: llvm::GlobalValue::InternalLinkage;
// Create and initialize new IrFunction.
compute_function_.reset(new IrFunction(
function_name, linkage,
options::OptimizeForSizeRequested(hlo_module_config_),
hlo_module_config_.debug_options().xla_cpu_enable_fast_math(), module_,
&b_, num_dynamic_loop_bounds_));
compute_function_.reset(new IrFunction(function_name, linkage,
hlo_module_config_, module_, &b_,
num_dynamic_loop_bounds_));
}
IrEmitter::~IrEmitter() {}
@ -302,7 +298,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
const Shape& shape = get_tuple_element->shape();
emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
GetEmittedValueFor(operand), &b_, module_);
GetEmittedValueFor(operand), &b_);
return Status::OK();
}
@ -322,7 +318,7 @@ Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple_select));
llvm_ir::EmitTupleSelect(GetIrArrayFor(tuple_select), GetIrArrayFor(pred),
GetEmittedValueFor(on_true),
GetEmittedValueFor(on_false), &b_, module_);
GetEmittedValueFor(on_false), &b_);
return Status::OK();
}
@ -345,8 +341,7 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
assignment_.GetUniqueSlice(infeed, {1}));
llvm::Value* token_address = EmitBufferPointer(
token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1));
llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_,
module_);
llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_);
if (data_shape.IsTuple()) {
TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape));
@ -377,7 +372,7 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
}
llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_shape),
tuple_element_addresses, &b_, module_);
tuple_element_addresses, &b_);
} else {
TF_RETURN_IF_ERROR(
EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address));
@ -498,7 +493,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
ShapeUtil::GetTupleElementShape(operand_shape, i);
llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement(
tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape),
value, &b_, module_);
value, &b_);
TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed,
tuple_element_shape, tuple_element));
}
@ -621,8 +616,7 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) {
GetProfileCountersArgument(), less_than_function});
if (sort->values_count() > 0) {
llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_,
module_);
llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_);
}
return Status::OK();
}
@ -633,7 +627,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
for (auto operand : tuple->operands()) {
base_ptrs.push_back(GetEmittedValueFor(operand));
}
llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &b_, module_);
llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &b_);
return Status::OK();
}
@ -1349,7 +1343,7 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr,
/*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape));
}
llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_, module_);
llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_);
return Status::OK();
}
@ -2289,7 +2283,7 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
llvm::Value* addr = EmitBufferPointer(slice, elem_shape);
base_ptrs.push_back(addr);
}
llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_, module_);
llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_);
}
auto* output_address_arg =
PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
@ -2980,7 +2974,7 @@ Status IrEmitter::EmitTargetElementLoop(
for (int64 i = 0; i < output_arrays.size(); ++i) {
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
}
llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &b_, module_);
llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &b_);
} else {
if (ShouldEmitParallelLoopFor(*target_op)) {


@ -43,15 +43,14 @@ static std::vector<llvm::Type*> GetComputeFunctionParams(
IrFunction::IrFunction(const string& function_name,
llvm::Function::LinkageTypes linkage,
const bool optimize_for_size_requested,
const bool enable_fast_math, llvm::Module* llvm_module,
llvm::IRBuilder<>* b, int64 num_dynamic_loop_bounds)
const HloModuleConfig& module_config,
llvm::Module* llvm_module, llvm::IRBuilder<>* b,
int64 num_dynamic_loop_bounds)
: b_(b),
llvm_module_(llvm_module),
caller_insert_point_guard_(*b),
num_dynamic_loop_bounds_(num_dynamic_loop_bounds) {
Initialize(function_name, linkage, optimize_for_size_requested,
enable_fast_math);
Initialize(function_name, linkage, module_config);
}
IrFunction::~IrFunction() {
@ -70,8 +69,7 @@ DynamicLoopBounds IrFunction::GetDynamicLoopBounds() {
void IrFunction::Initialize(const string& function_name,
llvm::Function::LinkageTypes linkage,
const bool optimize_for_size_requested,
const bool enable_fast_math) {
const HloModuleConfig& module_config) {
// The function signature is:
// void function(i8* retval, i8* run_options, i8** params, i8**
// buffer_table,
@ -142,11 +140,8 @@ void IrFunction::Initialize(const string& function_name,
// Functions with local linkage get an inlining bonus. Because we know
// a-priori that embedded functions (non-entry functions) will not have its
// name resolved, give it local linkage.
function_ =
llvm_ir::CreateFunction(function_type, linkage,
/*enable_fast_math=*/enable_fast_math,
/*optimize_for_size=*/optimize_for_size_requested,
function_name, llvm_module_);
function_ = llvm_ir::CreateCpuFunction(function_type, linkage, module_config,
function_name, llvm_module_);
// Set meaningful names for the function's arguments: useful for debugging.
llvm::Function::arg_iterator arg_iter = function_->arg_begin();


@ -22,6 +22,7 @@ limitations under the License.
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@ -52,8 +53,7 @@ namespace cpu {
class IrFunction {
public:
IrFunction(const string& function_name, llvm::Function::LinkageTypes linkage,
const bool optimize_for_size_requested,
const bool enable_fast_math, llvm::Module* llvm_module,
const HloModuleConfig& module_config, llvm::Module* llvm_module,
llvm::IRBuilder<>* b, int64 num_dynamic_loop_bounds);
~IrFunction();
@ -92,7 +92,7 @@ class IrFunction {
// Initialize an llvm::Function with standard signature based on arguments.
void Initialize(const string& function_name,
llvm::Function::LinkageTypes linkage,
bool optimize_for_size_requested, bool enable_fast_math);
const HloModuleConfig& module_config);
// Emit ir to read and return the ir value for the dynamic loop bound at
// 'offset' from the "dynamic_loop_bounds" argument of this function.


@ -991,7 +991,7 @@ void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
int64 tile_cols, int64 m, int64 k, llvm::Value* lhs,
llvm::Value* rhs, llvm::Value* addend,
llvm::Value* result, llvm::IRBuilder<>* b,
bool enable_fast_math, bool optimize_for_size) {
const HloModuleConfig& module_config) {
RowMajorMatrixVectorProductEmitter::Config config(
/*scalar_type=*/scalar_type,
/*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols,
@ -1001,8 +1001,7 @@ void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b);
KernelSupportLibrary::EmitAndCallOutlinedKernel(
/*enable_fast_math=*/enable_fast_math,
/*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(),
module_config, b, config.GetCacheKey(),
canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized,
canonical_inputs.addend_canonicalized,
canonical_inputs.result_canonicalized,
@ -1019,7 +1018,7 @@ void EmitColumnMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
int64 tile_cols, int64 m, int64 k, llvm::Value* lhs,
llvm::Value* rhs, llvm::Value* addend,
llvm::Value* result, llvm::IRBuilder<>* b,
bool enable_fast_math, bool optimize_for_size) {
const HloModuleConfig& module_config) {
ColumnMajorMatrixVectorProductEmitter::Config config(
/*scalar_type=*/scalar_type,
/*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols,
@ -1029,8 +1028,7 @@ void EmitColumnMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b);
KernelSupportLibrary::EmitAndCallOutlinedKernel(
/*enable_fast_math=*/enable_fast_math,
/*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(),
module_config, b, config.GetCacheKey(),
canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized,
canonical_inputs.addend_canonicalized,
canonical_inputs.result_canonicalized,
@ -1048,7 +1046,7 @@ void EmitSmallGemm(PrimitiveType scalar_type, int64 m, int64 k, int64 n,
int64 min_vectorization_width, int64 tile_size_m,
int64 tile_size_k, llvm::Value* lhs, llvm::Value* rhs,
llvm::Value* result, llvm::IRBuilder<>* b,
bool enable_fast_math, bool optimize_for_size) {
const HloModuleConfig& module_config) {
TiledSmallGemmEmitter::Config config(
/*scalar_type=*/scalar_type,
TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
@ -1058,9 +1056,7 @@ void EmitSmallGemm(PrimitiveType scalar_type, int64 m, int64 k, int64 n,
/*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k);
KernelSupportLibrary::EmitAndCallOutlinedKernel(
/*enable_fast_math=*/enable_fast_math,
/*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs,
rhs, result,
module_config, b, config.GetCacheKey(), lhs, rhs, result,
[&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result) {
TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs,
/*rhs=*/rhs,


@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TILED_DOT_EMITTER_H_
#include "llvm/IR/IRBuilder.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
@ -29,15 +30,15 @@ void EmitRowMajorGemv(PrimitiveType scalar_type, tensorflow::int64 tile_rows,
tensorflow::int64 tile_cols, tensorflow::int64 m,
tensorflow::int64 k, llvm::Value* lhs, llvm::Value* rhs,
llvm::Value* addend, llvm::Value* result,
llvm::IRBuilder<>* b, bool enable_fast_math,
bool optimize_for_size);
llvm::IRBuilder<>* b,
const HloModuleConfig& module_config);
void EmitColumnMajorGemv(PrimitiveType scalar_type, tensorflow::int64 tile_rows,
tensorflow::int64 tile_cols, tensorflow::int64 m,
tensorflow::int64 k, llvm::Value* lhs,
llvm::Value* rhs, llvm::Value* addend,
llvm::Value* result, llvm::IRBuilder<>* b,
bool enable_fast_math, bool optimize_for_size);
const HloModuleConfig& module_config);
void EmitSmallGemm(PrimitiveType scalar_type, tensorflow::int64 m,
tensorflow::int64 k, tensorflow::int64 n,
@ -46,8 +47,7 @@ void EmitSmallGemm(PrimitiveType scalar_type, tensorflow::int64 m,
tensorflow::int64 min_vectorization_width,
tensorflow::int64 tile_size_m, tensorflow::int64 tile_size_k,
llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result,
llvm::IRBuilder<>* b, bool enable_fast_math,
bool optimize_for_size);
llvm::IRBuilder<>* b, const HloModuleConfig& module_config);
} // namespace cpu
} // namespace xla


@ -2407,8 +2407,9 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
case HloOpcode::kCopy:
return [hlo, &operand_to_generator](
const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
IrArray::Index source_index = target_index;
source_index.ClearLinearIndex();
IrArray::Index source_index(target_index.multidim(),
hlo->operand(0)->shape(),
target_index.GetType());
TF_ASSIGN_OR_RETURN(
llvm::Value * operand_value,
operand_to_generator.at(hlo->operand(0))(source_index));


@ -135,11 +135,11 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
return llvm_ir::EmitGetTupleElement(
gte->shape(), gte->tuple_index(), /*alignment=*/1,
GetTypedIrValue(*gte->operand(0), {}, base_ptr), b_, module_);
GetTypedIrValue(*gte->operand(0), {}, base_ptr), b_);
}
return llvm_ir::EmitGetTupleElement(
gte->shape(), gte->tuple_index(), /*alignment=*/1,
EmitGetTupleElement(gte->operand(0), base_ptr), b_, module_);
EmitGetTupleElement(gte->operand(0), base_ptr), b_);
}
// Returns true if `value` has a name that should not be changed.


@ -115,7 +115,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
get_tuple_element->shape(), get_tuple_element->tuple_index(),
// TODO(b/26344050): tighten the alignment here
// based on the real element type.
/*alignment=*/1, GetBasePointer(*operand), &b_, module_));
/*alignment=*/1, GetBasePointer(*operand), &b_));
return Status::OK();
}
@ -144,7 +144,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
for (const HloInstruction* operand : tuple->operands()) {
base_ptrs.push_back(GetBasePointer(*operand));
}
llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_, module_);
llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_);
return Status::OK();
}
@ -434,7 +434,7 @@ Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
llvm_ir::EmitTupleSelect(GetIrArray(*tuple_select, *tuple_select),
GetIrArray(*pred, *tuple_select),
GetBasePointer(*on_true), GetBasePointer(*on_false),
&b_, module_);
&b_);
return Status::OK();
}


@ -123,7 +123,7 @@ Status IrEmitterNested::EmitTargetElementLoop(
ConstructIrArrayForOutputs(hlo);
TF_RETURN_IF_ERROR(
llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop());
llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_, module_);
llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_);
return Status::OK();
}
return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_)


@ -2201,7 +2201,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
// kernel *anyway*.
std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(hlo);
KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] {
llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_);
llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_);
});
// For multioutput fusion, we need to emit each operand and the root.
@ -2916,13 +2916,6 @@ void IrEmitterUnnested::EmitTileElementForReduction(
reduction_info->GetKernelMappingScheme()->GetUnnormalizedIndex(
index,
GetFirstReduceInstruction(output_instructions)->operand(0)->shape());
int num_partial_results = reduction_info->GetNumberOfPartialResults();
if (num_partial_results > 1) {
// Clear the linear index field of the IrArray::Index to enable the use of
// GetElementPointer with array types. This enables the vectorization of
// the computation for different partial results.
input_index.ClearLinearIndex();
}
absl::Span<llvm::AllocaInst* const> partial_reduction_result_addresses =
reduction_info->GetPartialResultAddresses();
absl::Span<llvm::AllocaInst* const> reduction_input_addresses =
@ -3103,8 +3096,7 @@ LaunchDimensions IrEmitterUnnested::EmitKernel(
if (!reduction_info && unnested_hlo->IsMultiOutputFusion()) {
KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] {
llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo),
ConstructIrArrayForOutputs(*unnested_hlo), &b_,
module_);
ConstructIrArrayForOutputs(*unnested_hlo), &b_);
});
}


@ -207,6 +207,39 @@ std::vector<HloInstruction*> HloModuleGroupUtil::RootInstructions(
return roots;
}
string HloModuleGroupUtil::CycleToString(HloInstruction* init_instruction) {
std::vector<string> names;
absl::flat_hash_set<HloInstruction*> seen;
std::function<bool(HloInstruction*)> helper =
[&](HloInstruction* instruction) {
if (seen.find(instruction) != seen.end()) {
if (instruction == init_instruction) {
names.push_back(instruction->name());
return true;
}
return false;
}
seen.insert(instruction);
for (HloInstruction* predecessor : GlobalPredecessors(instruction)) {
bool init_found = helper(predecessor);
if (init_found) {
names.push_back(instruction->name());
return true;
}
}
return false;
};
helper(init_instruction);
std::vector<string> pieces;
pieces.reserve(names.size());
for (auto name : names) {
pieces.push_back(name);
}
return absl::StrJoin(pieces, " --> ");
}
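To make the recursion easier to follow, here is a hypothetical Python sketch of the same idea (the graph is just a dict from node to its predecessors; names are appended while unwinding, so the printed chain starts and ends at the initial node):

```python
def cycle_to_string(init_node, predecessors):
    names, seen = [], set()

    def helper(node):
        if node in seen:
            if node == init_node:     # closed the loop back at the start
                names.append(node)
                return True
            return False
        seen.add(node)
        for pred in predecessors.get(node, ()):
            if helper(pred):          # cycle found below: record on unwind
                names.append(node)
                return True
        return False

    helper(init_node)
    return ' --> '.join(names)

# a's predecessor is b, b's is c, c's is a: a three-node cycle.
print(cycle_to_string('a', {'a': ['b'], 'b': ['c'], 'c': ['a']}))
# a --> c --> b --> a
```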
Status HloModuleGroupUtil::VisitTopologicalOrder(
VisitStates* visit_state, const VisitFunction& visit_function,
HloInstruction* root) {
@ -269,22 +302,9 @@ Status HloModuleGroupUtil::VisitTopologicalOrder(
// a cycle. Generate an error with the list of instructions in the
// cycle.
if ((*visit_state)[predecessor] == VisitState::kVisiting) {
string cyclic_instructions;
for (const auto& state : *visit_state) {
if (state.second == VisitState::kVisiting) {
absl::StrAppend(&cyclic_instructions, state.first->ToString(),
"\n");
}
}
// TODO(b/64305524): Improve the error message to print out the
// instructions in a deterministic order that forms the cycle.
return FailedPrecondition(
"Cross-computation cycle detected via communicating nodes. The "
"cycle contains the node %s. The cycle is found among the "
"following nodes. Note that the order of the nodes is arbitrary "
"and that the list may include nodes that are not part of the "
"cycle.\n%s",
predecessor->ToString(), cyclic_instructions);
"Cross-computation cycle detected via communicating nodes.\n%s",
CycleToString(predecessor));
}
stack.push(predecessor);
}


@ -108,6 +108,8 @@ class HloModuleGroupUtil {
HloInstruction* instruction, HloReachabilityMap* reachability_map);
private:
string CycleToString(HloInstruction* instruction);
const HloModuleGroupMetadata& metadata_;
};


@ -71,6 +71,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/compiler/xla/service/cpu:cpu_options",
"//tensorflow/core:lib",
"@com_google_absl//absl/base",
"@com_google_absl//absl/strings",
@ -239,7 +240,7 @@ cc_library(
hdrs = ["kernel_support_library.h"],
deps = [
":llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
":llvm_util",
"@com_google_absl//absl/strings",
"@llvm//:core",
],


@ -121,9 +121,9 @@ Status FusedIrEmitter::HandleGetTupleElement(
}
// Lookup tuple element pointer.
return llvm_ir::EmitGetTupleElement(
get_tuple_element->shape(), get_tuple_element->tuple_index(),
/*alignment=*/1, tuple_ptr, b_, module_);
return llvm_ir::EmitGetTupleElement(get_tuple_element->shape(),
get_tuple_element->tuple_index(),
/*alignment=*/1, tuple_ptr, b_);
};
if (!get_tuple_element->shape().IsTuple()) {


@ -331,7 +331,8 @@ llvm::Value* IrArray::Index::Linearize(absl::Span<const int64> dimensions,
llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
llvm::IRBuilder<>* b,
absl::string_view name) const {
absl::string_view name,
bool use_linear_index) const {
if (ShapeUtil::IsScalar(shape_)) {
// Special handling of scalars: a scalar pretends to have the same value for
// every index, thus effectively implementing broadcasting of its value
@ -340,7 +341,7 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
}
CHECK_EQ(index.size(), shape_.rank());
if (index.LinearValidOnShape(shape_)) {
if (use_linear_index && index.LinearValidOnShape(shape_)) {
llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
return b->CreateInBoundsGEP(
b->CreateBitCast(base_ptr_,
@ -389,16 +390,20 @@ void IrArray::AnnotateLoadStoreInstructionWithMetadata(
llvm::Value* IrArray::EmitReadArrayElement(const Index& index,
llvm::IRBuilder<>* b,
absl::string_view name) const {
llvm::Value* element_address = EmitArrayElementAddress(index, b, name);
absl::string_view name,
bool use_linear_index) const {
llvm::Value* element_address =
EmitArrayElementAddress(index, b, name, use_linear_index);
llvm::LoadInst* load = b->CreateLoad(element_address);
AnnotateLoadStoreInstructionWithMetadata(load);
return load;
}
void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value,
llvm::IRBuilder<>* b) const {
llvm::Value* element_address = EmitArrayElementAddress(index, b);
llvm::IRBuilder<>* b,
bool use_linear_index) const {
llvm::Value* element_address =
EmitArrayElementAddress(index, b, "", use_linear_index);
llvm::StoreInst* store = b->CreateStore(value, element_address);
AnnotateLoadStoreInstructionWithMetadata(store);
}


@ -168,8 +168,6 @@ class IrArray {
return llvm::ConstantInt::get(index_type_, c);
}
void ClearLinearIndex() { linear_ = nullptr; }
private:
// Constructs an index from both a multi-dimensional index and a linear
// index. 'shape' is the shape on which the index is used. 'index_type' is
@ -227,7 +225,8 @@ class IrArray {
// The optional name is useful for debugging when looking at
// the emitted LLVM IR.
llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b,
absl::string_view name = "") const;
absl::string_view name = "",
bool use_linear_index = true) const;
// Attach metadata this IrArray instance knows about to "instruction".
void AnnotateLoadStoreInstructionWithMetadata(
@ -240,15 +239,23 @@ class IrArray {
//
// The optional name is useful for debugging when looking at
// the emitted LLVM IR.
// 'use_linear_index' can be used to specify whether the linear index (if
// available) or the multi-dimensional index should be used.
llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b,
absl::string_view name = "") const;
absl::string_view name = "",
bool use_linear_index = true) const;
// Emit IR to write the given value to the array element at the given index.
// 'use_linear_index' can be used to specify whether the linear index (if
// available) or the multi-dimensional index should be used.
void EmitWriteArrayElement(const Index& index, llvm::Value* value,
llvm::IRBuilder<>* b) const;
llvm::IRBuilder<>* b,
bool use_linear_index = true) const;
// Returns a new IrArray whose shape is "new_shape" and base pointer is a
// bitcast of the base pointer of "this" IrArray.
// 'use_linear_index' can be used to specify whether the linear index (if
// available) or the multi-dimensional index should be used.
IrArray CastToShape(const Shape& new_shape, llvm::IRBuilder<>* b) const;
void AddAliasScopeMetadata(llvm::MDNode* alias_scope) {


@ -70,7 +70,7 @@ Status KernelSupportLibrary::IfWithStatus(
}
void KernelSupportLibrary::EmitAndCallOutlinedKernel(
bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b,
const HloModuleConfig& module_config, llvm::IRBuilder<>* b,
absl::string_view kernel_name,
KernelSupportLibrary::ArgumentVector arguments,
const std::function<void(KernelSupportLibrary::ArgumentVector)>&
@ -101,10 +101,9 @@ void KernelSupportLibrary::EmitAndCallOutlinedKernel(
auto* function_type =
llvm::FunctionType::get(b->getVoidTy(), arg_types, /*isVarArg=*/false);
function = llvm_ir::CreateFunction(
function_type, llvm::GlobalValue::InternalLinkage,
/*enable_fast_math=*/enable_fast_math,
/*optimize_for_size=*/optimize_for_size, kernel_name, module);
function = llvm_ir::CreateCpuFunction(function_type,
llvm::GlobalValue::InternalLinkage,
module_config, kernel_name, module);
llvm::IRBuilder<>::InsertPointGuard guard(*b);


@ -263,33 +263,33 @@ class KernelSupportLibrary {
// in a nullptr llvm::Value* in its position to `kernel_body_generator`.
// Currently we only support at most one nullptr value in `arguments`.
static void EmitAndCallOutlinedKernel(
bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b,
const HloModuleConfig& module_config, llvm::IRBuilder<>* b,
absl::string_view kernel_name, ArgumentVector arguments,
const std::function<void(ArgumentVector)>& kernel_body_generator);
// Thin wrappers around the more general EmitAndCallOutlinedKernel above.
static void EmitAndCallOutlinedKernel(
bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b,
const HloModuleConfig& module_config, llvm::IRBuilder<>* b,
absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1,
llvm::Value* arg2,
const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*)>&
kernel_body_generator) {
EmitAndCallOutlinedKernel(
enable_fast_math, optimize_for_size, b, kernel_name, {arg0, arg1, arg2},
[&](ArgumentVector args) {
kernel_body_generator(args[0], args[1], args[2]);
});
EmitAndCallOutlinedKernel(module_config, b, kernel_name, {arg0, arg1, arg2},
[&](ArgumentVector args) {
kernel_body_generator(args[0], args[1],
args[2]);
});
}
static void EmitAndCallOutlinedKernel(
bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b,
const HloModuleConfig& module_config, llvm::IRBuilder<>* b,
absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1,
llvm::Value* arg2, llvm::Value* arg3,
const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*,
llvm::Value*)>& kernel_body_generator) {
EmitAndCallOutlinedKernel(
enable_fast_math, optimize_for_size, b, kernel_name,
{arg0, arg1, arg2, arg3}, [&](ArgumentVector args) {
module_config, b, kernel_name, {arg0, arg1, arg2, arg3},
[&](ArgumentVector args) {
kernel_body_generator(args[0], args[1], args[2], args[3]);
});
}


@ -31,6 +31,7 @@ limitations under the License.
#include "llvm/Transforms/Utils/Cloning.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
@ -499,24 +500,25 @@ int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) {
return ShapeUtil::ByteSizeOf(shape, pointer_size);
}
llvm::FastMathFlags GetFastMathFlags(bool fast_math_enabled) {
llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config) {
llvm::FastMathFlags flags;
if (fast_math_enabled) {
// Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros,
// AllowReciprocal, AllowContract, and ApproxFunc.
flags.setFast();
if (!module_config.debug_options().xla_cpu_enable_fast_math()) {
return flags;
}
return flags;
}
void SetTargetOptions(bool fast_math_enabled,
llvm::TargetOptions* target_options) {
// In LLVM backend flags, UnsafeFPMath does not explicitly imply
// NoInfs, etc.
target_options->UnsafeFPMath = fast_math_enabled;
target_options->NoInfsFPMath = fast_math_enabled;
target_options->NoNaNsFPMath = fast_math_enabled;
target_options->NoSignedZerosFPMath = fast_math_enabled;
// Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, AllowReciprocal,
// AllowContract, and ApproxFunc.
flags.setFast();
if (module_config.debug_options().xla_cpu_fast_math_honor_nans()) {
flags.setNoNaNs(false);
}
if (module_config.debug_options().xla_cpu_fast_math_honor_infs()) {
flags.setNoInfs(false);
}
return flags;
}
std::map<int, llvm::MDNode*> MergeMetadata(
@ -603,10 +605,11 @@ void DumpIrIfEnabled(const HloModule& hlo_module,
}
}
llvm::Function* CreateFunction(llvm::FunctionType* function_type,
llvm::GlobalValue::LinkageTypes linkage,
bool enable_fast_math, bool optimize_for_size,
absl::string_view name, llvm::Module* module) {
llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
llvm::GlobalValue::LinkageTypes linkage,
const HloModuleConfig& module_config,
absl::string_view name,
llvm::Module* module) {
llvm::Function* function =
llvm::Function::Create(function_type, linkage, AsStringRef(name), module);
function->setCallingConv(llvm::CallingConv::C);
@ -616,17 +619,23 @@ llvm::Function* CreateFunction(llvm::FunctionType* function_type,
// created by the JIT compiled code.
function->setHasUWTable();
if (enable_fast_math) {
if (module_config.debug_options().xla_cpu_enable_fast_math()) {
function->addFnAttr("unsafe-fp-math", "true");
function->addFnAttr("no-infs-fp-math", "true");
function->addFnAttr("no-nans-fp-math", "true");
function->addFnAttr("no-signed-zeros-fp-math", "true");
if (!module_config.debug_options().xla_cpu_fast_math_honor_nans()) {
function->addFnAttr("no-nans-fp-math", "true");
}
if (!module_config.debug_options().xla_cpu_fast_math_honor_infs()) {
function->addFnAttr("no-infs-fp-math", "true");
}
}
// Add the optimize attribute to the function if optimizing for size. This
// controls internal behavior of some optimization passes (e.g. loop
// unrolling).
if (optimize_for_size) {
if (cpu::options::OptimizeForSizeRequested(module_config)) {
function->addFnAttr(llvm::Attribute::OptimizeForSize);
}


@ -263,12 +263,7 @@ int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout);
// Gets an llvm::FastMathFlags that reflects the settings in the given
// module config.
llvm::FastMathFlags GetFastMathFlags(bool fast_math_enabled);
// Sets values in the given TargetOptions struct according to the given
// compilation options.
void SetTargetOptions(bool fast_math_enabled,
llvm::TargetOptions* target_options);
llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config);
// Computes a conservative union of the metadata in "a" and "b". For
// aliasing-related metadata, this means the result can be applied to
@ -287,10 +282,10 @@ std::map<int, llvm::MDNode*> MergeMetadata(
void DumpIrIfEnabled(const HloModule& hlo_module,
const llvm::Module& llvm_module, bool optimized);
llvm::Function* CreateFunction(llvm::FunctionType* function_type,
llvm::GlobalValue::LinkageTypes linkage,
bool enable_fast_math, bool optimize_for_size,
absl::string_view name, llvm::Module* module);
llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
llvm::GlobalValue::LinkageTypes linkage,
const HloModuleConfig& module_config,
absl::string_view name, llvm::Module* module);
// Extracts the xla_backend_extra_options from `config` and passes those that
// don't start with xla_ to LLVM.


@ -29,9 +29,14 @@ limitations under the License.
namespace xla {
namespace llvm_ir {
static llvm::Module* getModuleFromBuilder(llvm::IRBuilder<>* b) {
return b->GetInsertBlock()->getModule();
}
void EmitTupleSelect(const IrArray& select, const IrArray& pred,
llvm::Value* on_true, llvm::Value* on_false,
llvm::IRBuilder<>* b, llvm::Module* module) {
llvm::IRBuilder<>* b) {
llvm::Module* module = getModuleFromBuilder(b);
CHECK(ShapeUtil::IsScalar(pred.GetShape()));
llvm::LoadInst* pred_value =
@ -65,7 +70,8 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred,
}
void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
llvm::IRBuilder<>* b, llvm::Module* module) {
llvm::IRBuilder<>* b) {
llvm::Module* module = getModuleFromBuilder(b);
for (size_t i = 0; i < operands.size(); ++i) {
auto* store = b->CreateStore(
b->CreatePointerCast(operands[i], PrimitiveTypeToIrType(TUPLE, module)),
@ -76,18 +82,19 @@ void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
}
void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
llvm::IRBuilder<>* b, llvm::Module* module) {
llvm::IRBuilder<>* b) {
std::vector<llvm::Value*> buffer_ptrs;
buffer_ptrs.reserve(buffers.size());
absl::c_transform(
buffers, std::back_inserter(buffer_ptrs),
[](const llvm_ir::IrArray& buffer) { return buffer.GetBasePointer(); });
llvm_ir::EmitTuple(tuple, buffer_ptrs, b, module);
llvm_ir::EmitTuple(tuple, buffer_ptrs, b);
}
llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
int alignment, llvm::Value* operand,
llvm::IRBuilder<>* b, llvm::Module* module) {
llvm::IRBuilder<>* b) {
llvm::Module* module = getModuleFromBuilder(b);
llvm::Value* element_ptr =
b->CreateInBoundsGEP(operand, {b->getInt64(0), b->getInt64(index)});
llvm::LoadInst* src_buffer = b->CreateLoad(element_ptr);


@ -61,17 +61,17 @@ namespace llvm_ir {
// output[i] = pred ? tuple_on_true[i] : tuple_on_false[i]
void EmitTupleSelect(const IrArray& select, const IrArray& pred,
llvm::Value* on_true, llvm::Value* on_false,
llvm::IRBuilder<>* b, llvm::Module* module);
llvm::IRBuilder<>* b);
// A tuple is an array of pointers, one for each operand. Each pointer points to
// the output buffer of its corresponding operand.
void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
llvm::IRBuilder<>* b, llvm::Module* module);
llvm::IRBuilder<>* b);
// Similar to EmitTuple above, except that the output buffers are provided in
// the form of IrArray.
void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
llvm::IRBuilder<>* b, llvm::Module* module);
llvm::IRBuilder<>* b);
// A tuple is an array of pointers, one for each operand. Each pointer points to
// the output buffer of its corresponding operand. A GetTupleElement instruction
@ -79,7 +79,7 @@ void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
// Returns an llvm value representing a pointer to the tuple element buffer.
llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
int alignment, llvm::Value* operand,
llvm::IRBuilder<>* b, llvm::Module* module);
llvm::IRBuilder<>* b);
} // namespace llvm_ir
} // namespace xla


@ -156,10 +156,21 @@ message DebugOptions {
//
// - Reducing the precision of operations (e.g. using an approximate sin
// function, or transforming x/y into x * (1/y)).
// - Assuming that operations never produce or consume NaN or +/- Inf.
// - Assuming that operations never produce or consume NaN or +/- Inf (this
// behavior can be adjusted using xla_cpu_fast_math_honor_{nans|infs}).
// - Assuming that +0 and -0 are indistinguishable.
bool xla_cpu_enable_fast_math = 99;
// When xla_cpu_enable_fast_math is true then this controls whether we allow
// operations to produce NaNs. Ignored when xla_cpu_enable_fast_math is
// false.
bool xla_cpu_fast_math_honor_nans = 120;
// When xla_cpu_enable_fast_math is true then this controls whether we allow
// operations to produce infinites. Ignored when xla_cpu_enable_fast_math is
// false.
bool xla_cpu_fast_math_honor_infs = 121;
// When true we lower the Minimum and Maximum hlos in the GPU backend such
// that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if this
// flag is true we don't propagate NaNs through Min and Max.
@ -250,7 +261,7 @@ message DebugOptions {
// END flags controlling dumping HLO modules.
//
// Next id: 119
// Next id: 121
// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.


@ -1149,9 +1149,9 @@ class GbdtTest(test_util.TensorFlowTestCase):
expected_leaf_1 = [-3.4480, -3.4429, 13.8490, -3.45, -3.4508]
expected_leaf_2 = [-1.2547, -1.3145, 1.52, 2.3875, -1.3264]
self.assertArrayNear(expected_leaf_1,
output.trees[0].nodes[1].leaf.vector.value, 2e-3)
output.trees[0].nodes[1].leaf.vector.value, 3e-3)
self.assertArrayNear(expected_leaf_2,
output.trees[0].nodes[2].leaf.vector.value, 2e-3)
output.trees[0].nodes[2].leaf.vector.value, 3e-3)
def testTrainFnMulticlassDiagonalHessian(self):
"""Tests the GBDT train for multiclass diagonal hessian."""

View File

@ -64,9 +64,9 @@ py_library(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
"//tensorflow/python:summary_op_util",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/distribute:summary_op_util",
"//tensorflow/python/estimator:estimator_py",
],
)

View File

@ -25,11 +25,11 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.jit.ops import xla_ops
from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute import summary_op_util
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import summary_op_util
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat

View File

@ -266,7 +266,7 @@ class CollectiveAllReduceStrategyTestBase(
target=master_target) as sess:
with d.scope():
train_op = d.extended.call_for_each_replica(model_fn)
train_op = d.group(d.unwrap(train_op))
train_op = d.group(d.experimental_local_results(train_op))
sess.run(variables.global_variables_initializer())
sess.run(train_op)
@ -294,7 +294,7 @@ class CollectiveAllReduceStrategyTestBase(
x = distribution.extended.call_for_each_replica(model_fn)
reduced_x = distribution.reduce(reduce_util.ReduceOp.MEAN, x)
x = distribution.unwrap(x)[0]
x = distribution.experimental_local_results(x)[0]
sess.run(variables.global_variables_initializer())

View File

@ -749,7 +749,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
callbacks=[keras.callbacks.LearningRateScheduler(schedule)])
grouped_models = distribution.unwrap(
grouped_models = distribution.experimental_local_results(
distributed_training_utils.get_distributed_model(
model, ModeKeys.TRAIN))
with distribution.scope():

View File

@ -220,7 +220,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
def step_fn(ctx, inputs):
del ctx # Unused
fetches = distribution.unwrap(
fetches = distribution.experimental_local_results(
distribution.extended.call_for_each_replica(
model_fn, args=(inputs,)))
if update_ops_in_cross_replica_mode:
@ -419,8 +419,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
# values that are of the same structure as non reduced losses. In
# MirroredStrategy, this will be a list of losses, in TPUStrategy
# it will be single tensor. Using `call_for_each_replica` followed
# by `unwrap` gives us the desired initial value structure.
not_reduced = distribution.unwrap(
# by `experimental_local_results` gives us the desired initial
# value structure.
not_reduced = distribution.experimental_local_results(
distribution.extended.call_for_each_replica(initial_loss))
initial_loop_values = {
"replica_loss_reduced": initial_loss(),
@ -469,11 +470,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
def _verify_loss_output(self, initial_loss, loss_output, reduced,
distribution):
if not reduced:
self.assertLen(distribution.unwrap(loss_output),
self.assertLen(distribution.experimental_local_results(loss_output),
distribution.num_replicas_in_sync)
loss_tensor = distribution.reduce(reduce_util.ReduceOp.MEAN, loss_output)
else:
unwrapped_output = distribution.unwrap(loss_output)
unwrapped_output = distribution.experimental_local_results(loss_output)
self.assertLen(unwrapped_output, 1)
loss_tensor = unwrapped_output[0]
self.assertEqual(initial_loss.dtype, loss_tensor.dtype)

View File

@ -249,7 +249,7 @@ class MirroredStrategyVariableCreatorStackTest(
distribution.scope(), \
variable_scope.variable_creator_scope(main_thread_creator):
result = distribution.extended.call_for_each_replica(model_fn)
result = distribution.unwrap(result)
result = distribution.experimental_local_results(result)
expected = ("main_thread:thread_0", "main_thread:thread_1")
self.assertEqual(expected, result)
@ -269,7 +269,7 @@ class MirroredStrategyCallForEachReplicaTest(test.TestCase):
with distribution.scope():
in_scope = ops.executing_eagerly_outside_functions()
in_model_fn = distribution.extended.call_for_each_replica(model_fn)
unwrapped = distribution.unwrap(in_model_fn)
unwrapped = distribution.experimental_local_results(in_model_fn)
self.assertEqual(in_scope, unwrapped[0])
self.assertEqual(in_scope, originally)
@ -277,7 +277,7 @@ class MirroredStrategyCallForEachReplicaTest(test.TestCase):
with func_graph.FuncGraph("fg").as_default(), distribution.scope():
in_scope = ops.executing_eagerly_outside_functions()
in_model_fn = distribution.extended.call_for_each_replica(model_fn)
unwrapped = distribution.unwrap(in_model_fn)
unwrapped = distribution.experimental_local_results(in_model_fn)
self.assertEqual(in_scope, unwrapped[0])
self.assertEqual(in_scope, originally)
@ -701,7 +701,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
# Apply updates
self.evaluate(variables.global_variables_initializer())
self.evaluate([y for x in ret_ops for y in distribution.unwrap(x)])
self.evaluate([y for x in ret_ops # pylint: disable=g-complex-comprehension
for y in distribution.experimental_local_results(x)])
expected_sum = 0.0
expected_mean = 0.0
for i, d in enumerate(distribution.extended.worker_devices):
@ -747,7 +748,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
self.assertEqual(2, len(result))
for v in result:
self.assertIsInstance(v, values.DistributedValues)
_, v1 = distribution.unwrap(v)
_, v1 = distribution.experimental_local_results(v)
self.assertStartsWith(v1._op.name, "replica_1/")
def testSyncOnReadVariableUpdate(self, distribution):
@ -816,7 +817,7 @@ class MirroredStrategyNameScopeTest(test.TestCase):
self.assertEqual(2, len(result))
for v, name in zip(result, ["a", "b"]):
self.assertIsInstance(v, values.DistributedValues)
v0, v1 = distribution.unwrap(v)
v0, v1 = distribution.experimental_local_results(v)
self.assertEqual("main/foo/" + name + ":0", v0.name)
self.assertEqual("main/replica_1/foo/" + name + ":0", v1.name)
@ -833,7 +834,7 @@ class MirroredStrategyNameScopeTest(test.TestCase):
self.assertEqual(2, len(result))
for v, name in zip(result, ["a", "b"]):
self.assertIsInstance(v, values.DistributedValues)
v0, v1 = distribution.unwrap(v)
v0, v1 = distribution.experimental_local_results(v)
self.assertEqual("foo/" + name + ":0", v0.name)
self.assertEqual("replica_1/foo/" + name + ":0", v1.name)
@ -860,9 +861,9 @@ class MirroredStrategyNameScopeTest(test.TestCase):
result_c = result[1]
self.assertIsInstance(result_b, values.DistributedValues)
self.assertIsInstance(result_c, values.DistributedValues)
a0, a1 = distribution.unwrap(a)
b0, b1 = distribution.unwrap(result_b)
c0, c1 = distribution.unwrap(result_c)
a0, a1 = distribution.experimental_local_results(a)
b0, b1 = distribution.experimental_local_results(result_b)
c0, c1 = distribution.experimental_local_results(result_c)
self.assertEqual("main/a:0", a0.name)
self.assertEqual("main/a/replica_1:0", a1.name)
self.assertEqual("main/b:0", b0.name)
@ -889,9 +890,9 @@ class MirroredStrategyNameScopeTest(test.TestCase):
result_c = result[1]
self.assertIsInstance(result_b, values.DistributedValues)
self.assertIsInstance(result_c, values.DistributedValues)
a0, a1 = distribution.unwrap(a)
b0, b1 = distribution.unwrap(result_b)
c0, c1 = distribution.unwrap(result_c)
a0, a1 = distribution.experimental_local_results(a)
b0, b1 = distribution.experimental_local_results(result_b)
c0, c1 = distribution.experimental_local_results(result_c)
self.assertEqual("a:0", a0.name)
self.assertEqual("a/replica_1:0", a1.name)
self.assertEqual("b:0", b0.name)
@ -961,7 +962,7 @@ class MirroredVariableUpdateTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError, "You must specify an aggregation method to update a "
"MirroredVariable in Replica Context."):
self.evaluate(distribution.unwrap(
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
def testAssignMirroredVarReplicaContextWithSum(self, distribution):
@ -983,7 +984,7 @@ class MirroredVariableUpdateTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError, "A non-DistributedValues value 5.0 cannot be reduced "
"with the given reduce op ReduceOp.SUM."):
self.evaluate(distribution.unwrap(
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
def testAssignMirroredVarCrossDeviceContext(self, distribution):
@ -1015,7 +1016,7 @@ class MirroredVariableUpdateTest(test.TestCase):
mirrored_var.dtype)
return mirrored_var.assign(value)
self.evaluate(distribution.unwrap(
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(0.5, self.evaluate(mirrored_var))
@ -1033,7 +1034,7 @@ class MirroredVariableUpdateTest(test.TestCase):
def model_fn():
return mirrored_var.assign(5.0)
self.evaluate(distribution.unwrap(
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(5.0, self.evaluate(mirrored_var))
@ -1076,7 +1077,7 @@ class MirroredVariableUpdateTest(test.TestCase):
mirrored_var.dtype)
return mirrored_var.assign_add(value)
self.evaluate(distribution.unwrap(
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(1.5, self.evaluate(mirrored_var))
@ -1094,7 +1095,7 @@ class MirroredVariableUpdateTest(test.TestCase):
def model_fn():
return mirrored_var.assign_add(5.0)
self.evaluate(distribution.unwrap(
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(6.0, self.evaluate(mirrored_var))
@ -1129,7 +1130,7 @@ class MirroredVariableUpdateTest(test.TestCase):
mirrored_var.dtype)
return mirrored_var.assign_sub(value)
self.evaluate(distribution.unwrap(
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(4.5, self.evaluate(mirrored_var))
@ -1147,7 +1148,7 @@ class MirroredVariableUpdateTest(test.TestCase):
def model_fn():
return mirrored_var.assign_sub(1.0)
self.evaluate(distribution.unwrap(
self.evaluate(distribution.experimental_local_results(
distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(4.0, self.evaluate(mirrored_var))

View File

@ -56,7 +56,7 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase):
var, assign = distribution.extended.call_for_each_replica(replica_fn)
variables.global_variables_initializer().run()
self.assertAllClose([10.0, 11.0], var.eval())
sess.run(distribution.unwrap(assign))
sess.run(distribution.experimental_local_results(assign))
# Mean of val across calls to replica_fn().
average_val = [1.0 + 0.5 * (replica_id[0] - 1),
2.0 - 0.5 * (replica_id[0] - 1)]
@ -82,7 +82,7 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase):
var, assign_op = distribution.extended.call_for_each_replica(replica_fn)
variables.global_variables_initializer().run()
self.assertAllClose([0.0, 0.0], var.eval())
sess.run(distribution.unwrap(assign_op))
sess.run(distribution.experimental_local_results(assign_op))
# Mean of val across calls to replica_fn().
average_val = [1.0 + 0.5 * (replica_id[0] - 1),
2.0 - 0.5 * (replica_id[0] - 1)]
@ -155,7 +155,7 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase):
var, assign = distribution.extended.call_for_each_replica(replica_fn)
variables.global_variables_initializer().run()
self.assertAllClose([10.0, 11.0], var.eval())
sess.run(distribution.unwrap(assign))
sess.run(distribution.experimental_local_results(assign))
self.assertAllClose(
[10 * 0.25 + 1. * (1 - 0.25), 11 * 0.25 + 2. * (1 - 0.25)],
var.eval())

View File

@ -45,7 +45,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
def run_step():
return control_flow_ops.group(
distribution.unwrap(
distribution.experimental_local_results(
distribution.extended.call_for_each_replica(
model_fn, args=(iterator.get_next(),))))

View File

@ -346,7 +346,7 @@ class DistributionTestBase(test.TestCase):
train_ops, value = strategy.extended.call_for_each_replica(model_fn)
self.evaluate(strategy.group(train_ops))
global_step_tensors = strategy.unwrap(value)
global_step_tensors = strategy.experimental_local_results(value)
global_step_values = self.evaluate(global_step_tensors)
self.assertEqual((1,) * len(global_step_tensors), global_step_values)
@ -365,7 +365,8 @@ class DistributionTestBase(test.TestCase):
def run_and_concatenate(strategy, i):
x, y = strategy.experimental_run(lambda z: z, i)
x, y = self.evaluate((strategy.unwrap(x), strategy.unwrap(y)))
x, y = self.evaluate((strategy.experimental_local_results(x),
strategy.experimental_local_results(y)))
return np.concatenate(x), np.concatenate(y)
x_1, y_1 = run_and_concatenate(strategy, i)
@ -424,7 +425,8 @@ class OneDeviceDistributionTestBase(test.TestCase):
self.evaluate(inputs.initialize())
outputs = self.evaluate(
list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs))))
list(map(strategy.experimental_local_results,
strategy.experimental_run(comm_fn, inputs))))
self.assertAllEqual([expected[0]], outputs[0])
self.assertAllEqual([expected[1]], outputs[1])
@ -444,7 +446,8 @@ class OneDeviceDistributionTestBase(test.TestCase):
self.evaluate(inputs.initialize())
self.assertAllEqual(
expected_grads,
self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))
self.evaluate(strategy.experimental_local_results(
strategy.experimental_run(step, inputs))))
def _test_collective_comms_gradient_tape(
self, strategy, comm_fn, inputs, expected_grads):
@ -461,7 +464,8 @@ class OneDeviceDistributionTestBase(test.TestCase):
self.evaluate(inputs.initialize())
self.assertAllEqual(
expected_grads,
self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))
self.evaluate(strategy.experimental_local_results(
strategy.experimental_run(step, inputs))))
class TwoDeviceDistributionTestBase(test.TestCase):
@ -515,7 +519,8 @@ class TwoDeviceDistributionTestBase(test.TestCase):
self.evaluate(inputs.initialize())
outputs = self.evaluate(
list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs))))
list(map(strategy.experimental_local_results,
strategy.experimental_run(comm_fn, inputs))))
self.assertAllEqual([expected[0], expected[0]], outputs[0])
self.assertAllEqual([expected[1], expected[1]], outputs[1])
@ -535,7 +540,8 @@ class TwoDeviceDistributionTestBase(test.TestCase):
self.evaluate(inputs.initialize())
self.assertAllEqual(
expected_grads,
self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))
self.evaluate(strategy.experimental_local_results(
strategy.experimental_run(step, inputs))))
def _test_collective_comms_gradient_tape(
self, strategy, comm_fn, inputs, expected_grads):
@ -552,7 +558,8 @@ class TwoDeviceDistributionTestBase(test.TestCase):
self.evaluate(inputs.initialize())
self.assertAllEqual(
expected_grads,
self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))
self.evaluate(strategy.experimental_local_results(
strategy.experimental_run(step, inputs))))
def _all_sum(value):

View File

@ -11,777 +11,17 @@
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\").\n",
"\n",
"# Pix2Pix: An example with tf.keras and eager\n",
"\n",
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n",
"\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb\"\u003e\n",
" \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n",
"\u003c/td\u003e\u003ctd\u003e\n",
"\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
"# Pix2Pix"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "ITZuApL56Mny"
"id": "c7W3j96p219v"
},
"source": [
"This notebook demonstrates image to image translation using conditional GAN's, as described in [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004). Using this technique we can colorize black and white photos, convert google maps to google earth, etc. Here, we convert building facades to real buildings. We use [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager) to achieve this.\n",
"\n",
"In example, we will use the [CMP Facade Database](http://cmp.felk.cvut.cz/~tylecr1/facade/), helpfully provided by the [Center for Machine Perception](http://cmp.felk.cvut.cz/) at the [Czech Technical University in Prague](https://www.cvut.cz/). To keep our example short, we will use a preprocessed [copy](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/) of this dataset, created by the authors of the [paper](https://arxiv.org/abs/1611.07004) above.\n",
"\n",
"Each epoch takes around 58 seconds on a single P100 GPU.\n",
"\n",
"Below is the output generated after training the model for 200 epochs.\n",
"\n",
"\n",
"![sample output_1](https://www.tensorflow.org/images/gan/pix2pix_1.png)\n",
"![sample output_2](https://www.tensorflow.org/images/gan/pix2pix_2.png)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "e1_Y75QXJS6h"
},
"source": [
"## Import TensorFlow and enable eager execution"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "YfIk2es3hJEd"
},
"outputs": [],
"source": [
"# Import TensorFlow \u003e= 1.10 and enable eager execution\n",
"import tensorflow as tf\n",
"tf.enable_eager_execution()\n",
"\n",
"import os\n",
"import time\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import PIL\n",
"from IPython.display import clear_output"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "iYn4MdZnKCey"
},
"source": [
"## Load the dataset\n",
"\n",
"You can download this dataset and similar datasets from [here](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets). As mentioned in the [paper](https://arxiv.org/abs/1611.07004) we apply random jittering and mirroring to the training dataset.\n",
"* In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`\n",
"* In random mirroring, the image is randomly flipped horizontally i.e left to right."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Kn-k8kTXuAlv"
},
"outputs": [],
"source": [
"path_to_zip = tf.keras.utils.get_file('facades.tar.gz',\n",
" cache_subdir=os.path.abspath('.'),\n",
" origin='https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz', \n",
" extract=True)\n",
"\n",
"PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "2CbTEt448b4R"
},
"outputs": [],
"source": [
"BUFFER_SIZE = 400\n",
"BATCH_SIZE = 1\n",
"IMG_WIDTH = 256\n",
"IMG_HEIGHT = 256"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "tyaP4hLJ8b4W"
},
"outputs": [],
"source": [
"def load_image(image_file, is_train):\n",
" image = tf.read_file(image_file)\n",
" image = tf.image.decode_jpeg(image)\n",
"\n",
" w = tf.shape(image)[1]\n",
"\n",
" w = w // 2\n",
" real_image = image[:, :w, :]\n",
" input_image = image[:, w:, :]\n",
"\n",
" input_image = tf.cast(input_image, tf.float32)\n",
" real_image = tf.cast(real_image, tf.float32)\n",
"\n",
" if is_train:\n",
" # random jittering\n",
" \n",
" # resizing to 286 x 286 x 3\n",
" input_image = tf.image.resize_images(input_image, [286, 286], \n",
" align_corners=True, \n",
" method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
" real_image = tf.image.resize_images(real_image, [286, 286], \n",
" align_corners=True, \n",
" method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
" \n",
" # randomly cropping to 256 x 256 x 3\n",
" stacked_image = tf.stack([input_image, real_image], axis=0)\n",
" cropped_image = tf.random_crop(stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])\n",
" input_image, real_image = cropped_image[0], cropped_image[1]\n",
"\n",
" if np.random.random() \u003e 0.5:\n",
" # random mirroring\n",
" input_image = tf.image.flip_left_right(input_image)\n",
" real_image = tf.image.flip_left_right(real_image)\n",
" else:\n",
" input_image = tf.image.resize_images(input_image, size=[IMG_HEIGHT, IMG_WIDTH], \n",
" align_corners=True, method=2)\n",
" real_image = tf.image.resize_images(real_image, size=[IMG_HEIGHT, IMG_WIDTH], \n",
" align_corners=True, method=2)\n",
" \n",
" # normalizing the images to [-1, 1]\n",
" input_image = (input_image / 127.5) - 1\n",
" real_image = (real_image / 127.5) - 1\n",
"\n",
" return input_image, real_image"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "PIGN6ouoQxt3"
},
"source": [
"## Use tf.data to create batches, map(do preprocessing) and shuffle the dataset"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "SQHmYSmk8b4b"
},
"outputs": [],
"source": [
"train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')\n",
"train_dataset = train_dataset.shuffle(BUFFER_SIZE)\n",
"train_dataset = train_dataset.map(lambda x: load_image(x, True))\n",
"train_dataset = train_dataset.batch(1)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "MS9J0yA58b4g"
},
"outputs": [],
"source": [
"test_dataset = tf.data.Dataset.list_files(PATH+'test/*.jpg')\n",
"test_dataset = test_dataset.map(lambda x: load_image(x, False))\n",
"test_dataset = test_dataset.batch(1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "THY-sZMiQ4UV"
},
"source": [
"## Write the generator and discriminator models\n",
"\n",
"* **Generator** \n",
" * The architecture of generator is a modified U-Net.\n",
" * Each block in the encoder is (Conv -\u003e Batchnorm -\u003e Leaky ReLU)\n",
" * Each block in the decoder is (Transposed Conv -\u003e Batchnorm -\u003e Dropout(applied to the first 3 blocks) -\u003e ReLU)\n",
" * There are skip connections between the encoder and decoder (as in U-Net).\n",
" \n",
"* **Discriminator**\n",
" * The Discriminator is a PatchGAN.\n",
" * Each block in the discriminator is (Conv -\u003e BatchNorm -\u003e Leaky ReLU)\n",
" * The shape of the output after the last layer is (batch_size, 30, 30, 1)\n",
" * Each 30x30 patch of the output classifies a 70x70 portion of the input image (such an architecture is called a PatchGAN).\n",
" * Discriminator receives 2 inputs.\n",
" * Input image and the target image, which it should classify as real.\n",
" * Input image and the generated image (output of generator), which it should classify as fake. \n",
" * We concatenate these 2 inputs together in the code (`tf.concat([inp, tar], axis=-1)`)\n",
"\n",
"* Shape of the input travelling through the generator and the discriminator is in the comments in the code.\n",
"\n",
"To learn more about the architecture and the hyperparameters you can refer the [paper](https://arxiv.org/abs/1611.07004).\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "tqqvWxlw8b4l"
},
"outputs": [],
"source": [
"OUTPUT_CHANNELS = 3"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "lFPI4Nu-8b4q"
},
"outputs": [],
"source": [
"class Downsample(tf.keras.Model):\n",
" \n",
" def __init__(self, filters, size, apply_batchnorm=True):\n",
" super(Downsample, self).__init__()\n",
" self.apply_batchnorm = apply_batchnorm\n",
" initializer = tf.random_normal_initializer(0., 0.02)\n",
"\n",
" self.conv1 = tf.keras.layers.Conv2D(filters, \n",
" (size, size), \n",
" strides=2, \n",
" padding='same',\n",
" kernel_initializer=initializer,\n",
" use_bias=False)\n",
" if self.apply_batchnorm:\n",
" self.batchnorm = tf.keras.layers.BatchNormalization()\n",
" \n",
" def call(self, x, training):\n",
" x = self.conv1(x)\n",
" if self.apply_batchnorm:\n",
" x = self.batchnorm(x, training=training)\n",
" x = tf.nn.leaky_relu(x)\n",
" return x \n",
"\n",
"\n",
"class Upsample(tf.keras.Model):\n",
" \n",
" def __init__(self, filters, size, apply_dropout=False):\n",
" super(Upsample, self).__init__()\n",
" self.apply_dropout = apply_dropout\n",
" initializer = tf.random_normal_initializer(0., 0.02)\n",
"\n",
" self.up_conv = tf.keras.layers.Conv2DTranspose(filters, \n",
" (size, size), \n",
" strides=2, \n",
" padding='same',\n",
" kernel_initializer=initializer,\n",
" use_bias=False)\n",
" self.batchnorm = tf.keras.layers.BatchNormalization()\n",
" if self.apply_dropout:\n",
" self.dropout = tf.keras.layers.Dropout(0.5)\n",
"\n",
" def call(self, x1, x2, training):\n",
" x = self.up_conv(x1)\n",
" x = self.batchnorm(x, training=training)\n",
" if self.apply_dropout:\n",
" x = self.dropout(x, training=training)\n",
" x = tf.nn.relu(x)\n",
" x = tf.concat([x, x2], axis=-1)\n",
" return x\n",
"\n",
"\n",
"class Generator(tf.keras.Model):\n",
" \n",
" def __init__(self):\n",
" super(Generator, self).__init__()\n",
" initializer = tf.random_normal_initializer(0., 0.02)\n",
" \n",
" self.down1 = Downsample(64, 4, apply_batchnorm=False)\n",
" self.down2 = Downsample(128, 4)\n",
" self.down3 = Downsample(256, 4)\n",
" self.down4 = Downsample(512, 4)\n",
" self.down5 = Downsample(512, 4)\n",
" self.down6 = Downsample(512, 4)\n",
" self.down7 = Downsample(512, 4)\n",
" self.down8 = Downsample(512, 4)\n",
"\n",
" self.up1 = Upsample(512, 4, apply_dropout=True)\n",
" self.up2 = Upsample(512, 4, apply_dropout=True)\n",
" self.up3 = Upsample(512, 4, apply_dropout=True)\n",
" self.up4 = Upsample(512, 4)\n",
" self.up5 = Upsample(256, 4)\n",
" self.up6 = Upsample(128, 4)\n",
" self.up7 = Upsample(64, 4)\n",
"\n",
" self.last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, \n",
" (4, 4), \n",
" strides=2, \n",
" padding='same',\n",
" kernel_initializer=initializer)\n",
" \n",
" @tf.contrib.eager.defun\n",
" def call(self, x, training):\n",
" # x shape == (bs, 256, 256, 3) \n",
" x1 = self.down1(x, training=training) # (bs, 128, 128, 64)\n",
" x2 = self.down2(x1, training=training) # (bs, 64, 64, 128)\n",
" x3 = self.down3(x2, training=training) # (bs, 32, 32, 256)\n",
" x4 = self.down4(x3, training=training) # (bs, 16, 16, 512)\n",
" x5 = self.down5(x4, training=training) # (bs, 8, 8, 512)\n",
" x6 = self.down6(x5, training=training) # (bs, 4, 4, 512)\n",
" x7 = self.down7(x6, training=training) # (bs, 2, 2, 512)\n",
" x8 = self.down8(x7, training=training) # (bs, 1, 1, 512)\n",
"\n",
" x9 = self.up1(x8, x7, training=training) # (bs, 2, 2, 1024)\n",
" x10 = self.up2(x9, x6, training=training) # (bs, 4, 4, 1024)\n",
" x11 = self.up3(x10, x5, training=training) # (bs, 8, 8, 1024)\n",
" x12 = self.up4(x11, x4, training=training) # (bs, 16, 16, 1024)\n",
" x13 = self.up5(x12, x3, training=training) # (bs, 32, 32, 512)\n",
" x14 = self.up6(x13, x2, training=training) # (bs, 64, 64, 256)\n",
" x15 = self.up7(x14, x1, training=training) # (bs, 128, 128, 128)\n",
"\n",
" x16 = self.last(x15) # (bs, 256, 256, 3)\n",
" x16 = tf.nn.tanh(x16)\n",
"\n",
" return x16"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "ll6aNeQx8b4v"
},
"outputs": [],
"source": [
"class DiscDownsample(tf.keras.Model):\n",
" \n",
" def __init__(self, filters, size, apply_batchnorm=True):\n",
" super(DiscDownsample, self).__init__()\n",
" self.apply_batchnorm = apply_batchnorm\n",
" initializer = tf.random_normal_initializer(0., 0.02)\n",
"\n",
" self.conv1 = tf.keras.layers.Conv2D(filters, \n",
" (size, size), \n",
" strides=2, \n",
" padding='same',\n",
" kernel_initializer=initializer,\n",
" use_bias=False)\n",
" if self.apply_batchnorm:\n",
" self.batchnorm = tf.keras.layers.BatchNormalization()\n",
" \n",
" def call(self, x, training):\n",
" x = self.conv1(x)\n",
" if self.apply_batchnorm:\n",
" x = self.batchnorm(x, training=training)\n",
" x = tf.nn.leaky_relu(x)\n",
" return x \n",
"\n",
"class Discriminator(tf.keras.Model):\n",
" \n",
" def __init__(self):\n",
" super(Discriminator, self).__init__()\n",
" initializer = tf.random_normal_initializer(0., 0.02)\n",
" \n",
" self.down1 = DiscDownsample(64, 4, False)\n",
" self.down2 = DiscDownsample(128, 4)\n",
" self.down3 = DiscDownsample(256, 4)\n",
" \n",
" # we are zero padding here with 1 because we need our shape to \n",
" # go from (batch_size, 32, 32, 256) to (batch_size, 31, 31, 512)\n",
" self.zero_pad1 = tf.keras.layers.ZeroPadding2D()\n",
" self.conv = tf.keras.layers.Conv2D(512, \n",
" (4, 4), \n",
" strides=1, \n",
" kernel_initializer=initializer, \n",
" use_bias=False)\n",
" self.batchnorm1 = tf.keras.layers.BatchNormalization()\n",
" \n",
" # shape change from (batch_size, 31, 31, 512) to (batch_size, 30, 30, 1)\n",
" self.zero_pad2 = tf.keras.layers.ZeroPadding2D()\n",
" self.last = tf.keras.layers.Conv2D(1, \n",
" (4, 4), \n",
" strides=1,\n",
" kernel_initializer=initializer)\n",
" \n",
" @tf.contrib.eager.defun\n",
" def call(self, inp, tar, training):\n",
" # concatenating the input and the target\n",
" x = tf.concat([inp, tar], axis=-1) # (bs, 256, 256, channels*2)\n",
" x = self.down1(x, training=training) # (bs, 128, 128, 64)\n",
" x = self.down2(x, training=training) # (bs, 64, 64, 128)\n",
" x = self.down3(x, training=training) # (bs, 32, 32, 256)\n",
"\n",
" x = self.zero_pad1(x) # (bs, 34, 34, 256)\n",
" x = self.conv(x) # (bs, 31, 31, 512)\n",
" x = self.batchnorm1(x, training=training)\n",
" x = tf.nn.leaky_relu(x)\n",
" \n",
" x = self.zero_pad2(x) # (bs, 33, 33, 512)\n",
" # don't add a sigmoid activation here since\n",
" # the loss function expects raw logits.\n",
" x = self.last(x) # (bs, 30, 30, 1)\n",
"\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "gDkA05NE6QMs"
},
"outputs": [],
"source": [
"# The call function of Generator and Discriminator have been decorated\n",
"# with tf.contrib.eager.defun()\n",
"# We get a performance speedup if defun is used (~25 seconds per epoch)\n",
"generator = Generator()\n",
"discriminator = Discriminator()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "0FMYgY_mPfTi"
},
"source": [
"## Define the loss functions and the optimizer\n",
"\n",
"* **Discriminator loss**\n",
" * The discriminator loss function takes 2 inputs; **real images, generated images**\n",
" * real_loss is a sigmoid cross entropy loss of the **real images** and an **array of ones(since these are the real images)**\n",
" * generated_loss is a sigmoid cross entropy loss of the **generated images** and an **array of zeros(since these are the fake images)**\n",
" * Then the total_loss is the sum of real_loss and the generated_loss\n",
" \n",
"* **Generator loss**\n",
" * It is a sigmoid cross entropy loss of the generated images and an **array of ones**.\n",
" * The [paper](https://arxiv.org/abs/1611.07004) also includes L1 loss which is MAE (mean absolute error) between the generated image and the target image.\n",
" * This allows the generated image to become structurally similar to the target image.\n",
" * The formula to calculate the total generator loss = gan_loss + LAMBDA * l1_loss, where LAMBDA = 100. This value was decided by the authors of the [paper](https://arxiv.org/abs/1611.07004)."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "cyhxTuvJyIHV"
},
"outputs": [],
"source": [
"LAMBDA = 100"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "wkMNfBWlT-PV"
},
"outputs": [],
"source": [
"def discriminator_loss(disc_real_output, disc_generated_output):\n",
" real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.ones_like(disc_real_output), \n",
" logits = disc_real_output)\n",
" generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.zeros_like(disc_generated_output), \n",
" logits = disc_generated_output)\n",
"\n",
" total_disc_loss = real_loss + generated_loss\n",
"\n",
" return total_disc_loss"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "90BIcCKcDMxz"
},
"outputs": [],
"source": [
"def generator_loss(disc_generated_output, gen_output, target):\n",
" gan_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.ones_like(disc_generated_output),\n",
" logits = disc_generated_output) \n",
" # mean absolute error\n",
" l1_loss = tf.reduce_mean(tf.abs(target - gen_output))\n",
"\n",
" total_gen_loss = gan_loss + (LAMBDA * l1_loss)\n",
"\n",
" return total_gen_loss"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "iWCn_PVdEJZ7"
},
"outputs": [],
"source": [
"generator_optimizer = tf.train.AdamOptimizer(2e-4, beta1=0.5)\n",
"discriminator_optimizer = tf.train.AdamOptimizer(2e-4, beta1=0.5)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "aKUZnDiqQrAh"
},
"source": [
"## Checkpoints (Object-based saving)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "WJnftd5sQsv6"
},
"outputs": [],
"source": [
"checkpoint_dir = './training_checkpoints'\n",
"checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
"checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n",
" discriminator_optimizer=discriminator_optimizer,\n",
" generator=generator,\n",
" discriminator=discriminator)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Rw1fkAczTQYh"
},
"source": [
"## Training\n",
"\n",
"* We start by iterating over the dataset\n",
"* The generator gets the input image and we get a generated output.\n",
"* The discriminator receives the input_image and the generated image as the first input. The second input is the input_image and the target_image.\n",
"* Next, we calculate the generator and the discriminator loss.\n",
"* Then, we calculate the gradients of loss with respect to both the generator and the discriminator variables(inputs) and apply those to the optimizer.\n",
"\n",
"## Generate Images\n",
"\n",
"* After training, its time to generate some images!\n",
"* We pass images from the test dataset to the generator.\n",
"* The generator will then translate the input image into the output we expect.\n",
"* Last step is to plot the predictions and **voila!**"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "NS2GWywBbAWo"
},
"outputs": [],
"source": [
"EPOCHS = 200"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "RmdVsmvhPxyy"
},
"outputs": [],
"source": [
"def generate_images(model, test_input, tar):\n",
" # the training=True is intentional here since\n",
" # we want the batch statistics while running the model\n",
" # on the test dataset. If we use training=False, we will get \n",
" # the accumulated statistics learned from the training dataset\n",
" # (which we don't want)\n",
" prediction = model(test_input, training=True)\n",
" plt.figure(figsize=(15,15))\n",
"\n",
" display_list = [test_input[0], tar[0], prediction[0]]\n",
" title = ['Input Image', 'Ground Truth', 'Predicted Image']\n",
"\n",
" for i in range(3):\n",
" plt.subplot(1, 3, i+1)\n",
" plt.title(title[i])\n",
" # getting the pixel values between [0, 1] to plot it.\n",
" plt.imshow(display_list[i] * 0.5 + 0.5)\n",
" plt.axis('off')\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "2M7LmLtGEMQJ"
},
"outputs": [],
"source": [
"def train(dataset, epochs): \n",
" for epoch in range(epochs):\n",
" start = time.time()\n",
"\n",
" for input_image, target in dataset:\n",
"\n",
" with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n",
" gen_output = generator(input_image, training=True)\n",
"\n",
" disc_real_output = discriminator(input_image, target, training=True)\n",
" disc_generated_output = discriminator(input_image, gen_output, training=True)\n",
"\n",
" gen_loss = generator_loss(disc_generated_output, gen_output, target)\n",
" disc_loss = discriminator_loss(disc_real_output, disc_generated_output)\n",
"\n",
" generator_gradients = gen_tape.gradient(gen_loss, \n",
" generator.variables)\n",
" discriminator_gradients = disc_tape.gradient(disc_loss, \n",
" discriminator.variables)\n",
"\n",
" generator_optimizer.apply_gradients(zip(generator_gradients, \n",
" generator.variables))\n",
" discriminator_optimizer.apply_gradients(zip(discriminator_gradients, \n",
" discriminator.variables))\n",
"\n",
" if epoch % 1 == 0:\n",
" clear_output(wait=True)\n",
" for inp, tar in test_dataset.take(1):\n",
" generate_images(generator, inp, tar)\n",
" \n",
" # saving (checkpoint) the model every 20 epochs\n",
" if (epoch + 1) % 20 == 0:\n",
" checkpoint.save(file_prefix = checkpoint_prefix)\n",
"\n",
" print ('Time taken for epoch {} is {} sec\\n'.format(epoch + 1,\n",
" time.time()-start))"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "a1zZmKmvOH85"
},
"outputs": [],
"source": [
"train(train_dataset, EPOCHS)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "kz80bY3aQ1VZ"
},
"source": [
"## Restore the latest checkpoint and test"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "4t4x69adQ5xb"
},
"outputs": [],
"source": [
"# restoring the latest checkpoint in checkpoint_dir\n",
"checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "1RGysMU_BZhx"
},
"source": [
"## Testing on the entire test dataset"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "KUgSnmy2nqSP"
},
"outputs": [],
"source": [
"# Run the trained model on the entire test dataset\n",
"for inp, tar in test_dataset:\n",
" generate_images(generator, inp, tar)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "3AJXOByaZVOf"
},
"outputs": [],
"source": [
""
"This notebook has been moved to [https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/pix2pix.ipynb](https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/pix2pix.ipynb)"
]
}
],

View File

@ -30,7 +30,10 @@ _ffmpeg_so = loader.load_op_library(
resource_loader.get_path_to_datafile('ffmpeg.so'))
@deprecated('2018-09-04', 'This will be deleted and should not be used.')
@deprecated('2018-09-04',
'tf.contrib.ffmpeg will be removed in 2.0, the support for video '
'and audio will continue to be provided in tensorflow-io: '
'https://github.com/tensorflow/io')
def decode_audio(contents, file_format=None, samples_per_second=None,
channel_count=None, stream=None):
"""Create an op that decodes the contents of an audio file.
@ -71,7 +74,10 @@ def decode_audio(contents, file_format=None, samples_per_second=None,
ops.NotDifferentiable('DecodeAudio')
@deprecated('2018-09-04', 'This will be deleted and should not be used.')
@deprecated('2018-09-04',
'tf.contrib.ffmpeg will be removed in 2.0, the support for video '
'and audio will continue to be provided in tensorflow-io: '
'https://github.com/tensorflow/io')
def encode_audio(audio, file_format=None, samples_per_second=None):
"""Creates an op that encodes an audio file using sampled audio from a tensor.
@ -98,13 +104,15 @@ def encode_audio(audio, file_format=None, samples_per_second=None):
ops.NotDifferentiable('EncodeAudio')
@deprecated('2018-09-04', 'This will be deleted and should not be used.')
@deprecated('2018-09-04',
'tf.contrib.ffmpeg will be removed in 2.0, the support for video '
'and audio will continue to be provided in tensorflow-io: '
'https://github.com/tensorflow/io')
def decode_video(contents):
"""Create an op that decodes the contents of a video file.
Args:
contents: The binary contents of the video file to decode. This is a
scalar.
contents: The binary contents of the video file to decode. This is a scalar.
Returns:
A rank-4 `Tensor` that has `[frames, height, width, 3]` RGB as output.

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/gradients.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/optimizer_cse.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@ -1734,6 +1735,9 @@ Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
// added above or on all control return nodes (controlled by
// `options.output_control_src` value). Nodes that previously depended on
// "callee" are changed to depend on "output_control_node".
//
// If `keep_caller_fetchable` is `true` we always add an output control node, to
// guarantee that executing a fetchable node will execute all side-effects.
std::vector<Node*> outputs(caller->num_outputs());
for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) {
Node* ret = node_map[fbody->ret_nodes[i]->id()];
@ -1754,29 +1758,53 @@ Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
}
g->RemoveNode(ret); // 'ret' is disconnected.
}
Node* output_control_node = nullptr;
bool has_control_outputs = absl::c_any_of(
caller->out_edges(), [](const Edge* e) { return e->IsControlEdge(); });
if (has_control_outputs || options.keep_caller_fetchable) {
output_control_node = no_op("output_control_node");
if (options.output_control_src == OutputControlSrc::kDataOutputs) {
for (Node* n : outputs) {
g->AddControlEdge(n, output_control_node);
}
} else {
for (Node* fbody_node : fbody->control_ret_nodes) {
Node* n = node_map[fbody_node->id()];
g->AddControlEdge(n, output_control_node);
}
}
}
for (const Edge* e : caller->out_edges()) {
if (e->IsControlEdge()) {
if (output_control_node == nullptr) {
output_control_node = no_op("output_control_node");
if (options.output_control_src ==
InlineFunctionBodyOptions::OutputControlSource::kDataOutputs) {
for (Node* n : outputs) {
g->AddControlEdge(n, output_control_node);
}
} else {
for (Node* fbody_node : fbody->control_ret_nodes) {
Node* n = node_map[fbody_node->id()];
g->AddControlEdge(n, output_control_node);
}
}
}
g->AddControlEdge(output_control_node, e->dst());
} else {
g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input());
}
}
g->RemoveNode(caller); // 'caller' is replaced with inlined nodes.
// ------------------------------------------------------------------------ //
// Add an IdentityN node in-place of caller node to keep `caller` fetchable.
if (options.keep_caller_fetchable) {
std::vector<NodeBuilder::NodeOut> output_tensors;
absl::c_transform(outputs, std::back_inserter(output_tensors),
[](Node* n) { return NodeBuilder::NodeOut(n, 0); });
Node* fetchable_node;
TF_CHECK_OK(NodeBuilder(caller->name(), "IdentityN")
.Device(caller->requested_device())
.Input(output_tensors)
.ControlInput(output_control_node)
.Finalize(g, &fetchable_node));
}
// ------------------------------------------------------------------------ //
// 'caller' is replaced with inlined function body nodes and maybe IdentityN
// to keep it fetchable.
g->RemoveNode(caller);
return Status::OK();
}

View File

@ -173,6 +173,16 @@ struct InlineFunctionBodyOptions {
// If 'true' function inlining will override explicitly specified devices
// inside function body with the caller node device.
bool override_device = false;
// If 'true' function inlining will add an IdentityN node to the graph with
// the same name as the caller node. It will have a control edge from the
// inlined 'output_control_node' and data edges from the function output nodes.
// The IdentityN node will be placed on the same device as the caller node.
// This is mostly for compatibility with Tensorflow v1 and sessions. When we
// prepare a graph for execution in GraphExecutionState::MakeForBaseGraph we
// don't know which nodes will be fetched, so we can't safely remove any of
// them. When the graph is executed as a function it has 'Retval' nodes for
// each fetched tensor, and we can safely inline function calls.
bool keep_caller_fetchable = false;
// For compatibility with Tensorflow v1 by default we will use data outputs.
// Control returns were added to Tensorflow v2 with automatic control
// dependencies tracking in Eager mode.
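For context, a minimal sketch of how a caller opts into this behavior (mirroring the ExpandInlineFunctionsAndKeepThemFetchable test added later in this commit; `flr` and `graph` stand for an existing FunctionLibraryRuntime* and Graph* and are assumptions here):

  // Keep the original call node fetchable after inlining, and derive the
  // output control node from the function's control returns.
  ExpandInlineFunctionsOptions opts;
  opts.native_options.keep_caller_fetchable = true;
  opts.native_options.output_control_src =
      InlineFunctionBodyOptions::OutputControlSource::kControlOutputs;
  ExpandInlineFunctions(flr, graph, opts);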

View File

@ -54,7 +54,9 @@ limitations under the License.
namespace tensorflow {
namespace {
using FDH = FunctionDefHelper;
using FDH = ::tensorflow::FunctionDefHelper;
using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
Status GetOpSig(const string& op, const OpDef** sig) {
return OpRegistry::Global()->LookUpOpDef(op, sig);
@ -888,17 +890,13 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithInputControlEdges) {
TEST_F(FunctionLibraryRuntimeTest,
ExpandInlineFunctionsWithOutputControlEdges) {
using test::function::NDef;
using FDH = FunctionDefHelper;
using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
// `add` node is not required to compute regular output `o`, but it must
// execute because it is in `control_ret`.
const FunctionDef func =
FDH::Create("FunctionWithControlOutputs", {"i: float"}, {"o: float"}, {},
{
{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
{{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}},
},
FDH::Create("AddAndMul", {"i: float"}, {"o: float"}, {},
{{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
{{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
/*ret_def=*/{{"o", "ret:z:0"}},
/*control_ret_def=*/{{"must_execute", "add"}});
@ -907,16 +905,16 @@ TEST_F(FunctionLibraryRuntimeTest,
// Construct a graph for the function call:
//
// a = Arg[dtype=DT_FLOAT]
// b = FunctionWithControlOutputs(a)
// b = AddAndMul(a)
// c = NoOp(^b)
// ret = RetVal(b, ^c)
const auto init_graph = [this](std::unique_ptr<Graph>* g) -> void {
g->reset(new Graph(OpRegistry::Global()));
*g = absl::make_unique<Graph>(OpRegistry::Global());
Scope s = Scope::NewRootScope();
TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
auto b = test::function::Call(&s, "b", "FunctionWithControlOutputs", {a});
auto b = test::function::Call(&s, "b", "AddAndMul", {a});
auto c = ops::NoOp(s.WithOpName("c"));
auto ret = ops::_Retval(s.WithOpName("ret"), b, 0);
s.graph()->AddControlEdge(b.node(), c.operation.node());
@ -978,6 +976,55 @@ TEST_F(FunctionLibraryRuntimeTest,
}
}
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsAndKeepThemFetchable) {
using test::function::NDef;
const FunctionDef func =
FDH::Create("AddAndMul", {"i: float"}, {"o: float"}, {},
{{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
{{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
/*ret_def=*/{{"o", "ret:z:0"}},
/*control_ret_def=*/{{"must_execute", "add"}});
Init({func});
// Construct a graph:
// a = Arg[dtype=DT_FLOAT]
// b = FunctionWithControlOutputs(a)
std::unique_ptr<Graph> g = absl::make_unique<Graph>(OpRegistry::Global());
Scope s = Scope::NewRootScope();
TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
auto b = test::function::Call(&s, "b", "AddAndMul", {a});
TF_ASSERT_OK(s.ToGraph(g.get()));
ExpandInlineFunctionsOptions opts;
opts.native_options.keep_caller_fetchable = true;
opts.native_options.output_control_src = OutputControlSrc::kControlOutputs;
const string input_node = "Func/b/input/_0";
const string output_node = "Func/b/output/_1";
const string output_control_node = "Func/b/output_control_node/_2";
ExpandInlineFunctions(flr0_, g.get(), opts);
{
GraphDef expected = test::function::GDef(
{NDef("a", "_Arg", {}, {{"T", DT_FLOAT}, {"index", 0}}),
NDef(input_node, "Identity", {"a"}, {{"T", DT_FLOAT}}),
NDef("b/add", "Add", {input_node, input_node}, {{"T", DT_FLOAT}}),
NDef("b/ret", "Mul", {input_node, input_node}, {{"T", DT_FLOAT}}),
NDef(output_node, "Identity", {"b/ret"}, {{"T", DT_FLOAT}}),
NDef(output_control_node, "NoOp", {"^b/add"}, {}),
NDef("b", "IdentityN", {output_node, "^" + output_control_node},
{{"T", DataTypeSlice{DT_FLOAT}}})},
{func});
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
}
TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
auto T = DT_INT32;
FunctionDef stateful_func = FDH::Define(

View File

@ -147,11 +147,21 @@ string CondBuilder::NewName(const string& infix) {
Status CondBuilder::AddInput(Node* src, int src_output) {
Node* input;
NodeDebugInfo debug_info(*src);
// Colocate the Switch node with the `src` node.
//
// This is to avoid unnecessary Host<->Device copies between src and the
// Switch node. This aligns with the implementation of legacy tf.cond in
// control_flow_ops.py. The legacy impl colocates the Switch with the
// input tensor which resets the device stack and forces the Switch to have
// the same device as the input node (if set) and sets the colocation _class
// attr. It also ignores the existing colocation constraints on the input node
// using colocate_with(ignore_existing=True).
TF_RETURN_IF_ERROR(NodeBuilder(NewName(src->name()), "Switch",
graph_->op_registry(), &debug_info)
.Input(src, src_output)
.Input(pred_)
.Device(if_op_->requested_device())
.Device(src->requested_device())
.Attr("_class", {src->name()})
.Finalize(graph_, &input));
then_call_builder_.Input(input, kThenBranch);
else_call_builder_.Input(input, kElseBranch);

View File

@ -102,7 +102,7 @@ Status Placer::Run() {
}
if (VLOG_IS_ON(3)) {
DumpGraphToFile("placer_input", *graph_, nullptr, "/tmp");
DumpGraphToFile("placer_input", *graph_, nullptr);
for (const Node* node : graph_->op_nodes()) {
VLOG(3) << " " << node->name() << ": requested: '"
<< node->requested_device() << "' assigned: '"
@ -226,7 +226,7 @@ Status Placer::Run() {
}
if (VLOG_IS_ON(3)) {
DumpGraphToFile("placer_output", *graph_, nullptr, "/tmp");
DumpGraphToFile("placer_output", *graph_, nullptr);
}
return Status::OK();
}

View File

@ -686,7 +686,7 @@ class SymbolicShapeRefiner {
// Perform inference on function body.
GraphProperties gp(grappler_function_item);
TF_RETURN_IF_ERROR(gp.InferStatically(true));
TF_RETURN_IF_ERROR(gp.InferStatically(true, aggressive_shape_inference_));
// Add return nodes for output shapes.
int output = 0;

View File

@ -1204,6 +1204,61 @@ TEST_F(GraphPropertiesTest, FunctionReturnTensorValue) {
properties.GetInputProperties("MyFunc")[0].value());
}
TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValue) {
FunctionDefLibrary library;
// Function that adds two input values.
*library.add_function() = FunctionDefHelper::Create(
"MyFunc", // Name
{"x: int32", "y: int32"}, // Inputs
{"out: int32"}, // Outputs
{}, // Attrs
{{{"a"}, "Add", {"x", "y"}, {{"T", DataType::DT_INT32}}}}, // Nodes
{{"out", "a:z:0"}}); // Returns
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
TF_CHECK_OK(s.graph()->AddFunctionLibrary(library));
Output shape = ops::Const(s.WithOpName("shape"), {5, 7}, {2});
auto _shape = tensorflow::ops::AsNodeOut(s, shape);
auto builder =
tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
tensorflow::Node* func_op;
TF_CHECK_OK(
builder.Input(_shape).Input(_shape).Finalize(s.graph(), &func_op));
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
{
GraphProperties properties(item);
// Without aggressive_shape_inference, the internal function does not
// evaluate output value.
TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/true,
/*aggressive_shape_inference=*/false));
const auto out_props = properties.GetOutputProperties("MyFunc");
const OpInfo::TensorProperties out_prop0 = out_props[0];
EXPECT_EQ("int32: [2]", PropToString(out_prop0));
EXPECT_FALSE(out_prop0.has_value());
}
{
GraphProperties properties(item);
// With aggressive_shape_inference, output value is evaluated.
TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/true,
/*aggressive_shape_inference=*/true));
const auto out_props = properties.GetOutputProperties("MyFunc");
const OpInfo::TensorProperties out_prop0 = out_props[0];
EXPECT_EQ("int32: [2]", PropToString(out_prop0));
EXPECT_TRUE(out_prop0.has_value());
ExpectTensorValues({10, 14}, out_prop0.value());
ExpectTensorValues({5, 7},
properties.GetInputProperties("MyFunc")[0].value());
ExpectTensorValues({5, 7},
properties.GetInputProperties("MyFunc")[1].value());
}
}
TEST_F(GraphPropertiesTest, FunctionWithScalarInput) {
// Create graph with a function that takes a scalar value so that we use
// Placeholder with scalar as for input to the function shape inference.

View File

@ -39,6 +39,7 @@ constexpr char kDepthwiseConv2dNativeBackpropInput[] =
"DepthwiseConv2dNativeBackpropInput";
constexpr char kMatMul[] = "MatMul";
constexpr char kSparseMatMul[] = "SparseMatMul";
constexpr char kSparseTensorDenseMatMul[] = "SparseTensorDenseMatMul";
constexpr char kPlaceholder[] = "Placeholder";
constexpr char kIdentity[] = "Identity";
constexpr char kIdentityN[] = "IdentityN";
@ -243,6 +244,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)},
{kMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
{kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
{kSparseTensorDenseMatMul,
wrap(&OpLevelCostEstimator::PredictSparseTensorDenseMatMul)},
{kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)},
{kQuantizedMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
{kQuantizedMatMulV2, wrap(&OpLevelCostEstimator::PredictMatMul)},
@ -1228,6 +1231,49 @@ Costs OpLevelCostEstimator::PredictMatMul(const OpContext& op_context) const {
return costs;
}
Costs OpLevelCostEstimator::PredictSparseTensorDenseMatMul(
const OpContext& op_context) const {
const auto& op_info = op_context.op_info;
bool found_unknown_shapes = false;
// input[0]: indices in sparse matrix a
// input[1]: values in sparse matrix a
// input[2]: shape of matrix a
// input[3]: matrix b
// See
// https://github.com/tensorflow/tensorflow/blob/9a43dfeac5/tensorflow/core/ops/sparse_ops.cc#L85
int64 num_elems_in_a =
CalculateTensorElementCount(op_info.inputs(1), &found_unknown_shapes);
auto b_matrix = op_info.inputs(3);
auto b_matrix_shape =
MaybeGetMinimumShape(b_matrix.shape(), 2, &found_unknown_shapes);
int64 n_dim = b_matrix_shape.dim(1).size();
// Each element in A is multiplied and added with an element from each column
// in b.
const int64 op_count = kOpsPerMac * num_elems_in_a * n_dim;
int64 a_indices_input_size =
CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
int64 a_values_input_size =
CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
int64 a_shape_input_size =
CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
int64 b_input_size =
num_elems_in_a * n_dim * DataTypeSize(BaseType(b_matrix.dtype()));
double input_size = a_indices_input_size + a_values_input_size +
a_shape_input_size + b_input_size;
double output_size = CalculateOutputSize(op_info, &found_unknown_shapes);
auto costs =
PredictOpCountBasedCost(op_count, input_size, output_size, op_info);
costs.inaccurate = found_unknown_shapes;
costs.num_ops_with_unknown_shapes = found_unknown_shapes;
costs.max_memory = output_size;
return costs;
}
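As a worked example of the count above, take the known-shape case from the test added below: 10 nonzero values in a and b of shape 1000 x 100, so n_dim = 100. Assuming kOpsPerMac counts a multiply-add as 2 ops (an assumption; the constant is defined elsewhere in this file), the estimate is

  op_count = kOpsPerMac * num_elems_in_a * n_dim
           = 2 * 10 * 100
           = 2000 ops

Neither op_count nor b_input_size depends on the number of rows of b, which is why the second known-shape test (b of shape 100000 x 100) expects exactly the same compute and memory times.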
Costs OpLevelCostEstimator::PredictNoOp(const OpContext& op_context) const {
const auto& op_info = op_context.op_info;
VLOG(1) << "Op:" << op_info.op() << " Execution Time 0 (ns)";

View File

@ -132,6 +132,7 @@ class OpLevelCostEstimator {
Costs PredictConv2DBackpropFilter(const OpContext& op_context) const;
Costs PredictFusedConv2DBiasActivation(const OpContext& op_context) const;
Costs PredictMatMul(const OpContext& op_context) const;
Costs PredictSparseTensorDenseMatMul(const OpContext& op_context) const;
Costs PredictNoOp(const OpContext& op_context) const;
Costs PredictIdentity(const OpContext& op_context) const;
Costs PredictVariable(const OpContext& op_context) const;

View File

@ -119,6 +119,22 @@ OpContext DescribeBatchMatMul(const std::vector<int>& dims_a,
return op_context;
}
// Returns an OpContext for a SparseTensorDenseMatMul op.
OpContext DescribeSparseTensorDenseMatMul(const int nnz_a,
const std::vector<int>& dims_b,
const std::vector<int>& dims_out) {
OpContext op_context;
SetCpuDevice(&op_context.op_info);
op_context.op_info.set_op("SparseTensorDenseMatMul");
DescribeArbitraryRankInput({nnz_a, 2}, DT_INT64, &op_context.op_info);
DescribeArbitraryRankInput({nnz_a}, DT_FLOAT, &op_context.op_info);
DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
DescribeArbitraryRankOutput(dims_out, DT_FLOAT, &op_context.op_info);
return op_context;
}
// Wrangles the minimum number of proto fields to set up a 1D Tensor for cost
// estimation purposes.
void DescribeTensor1D(int dim0, OpInfo::TensorProperties* tensor) {
@ -854,6 +870,58 @@ TEST_F(OpLevelCostEstimatorTest, BatchMatMul) {
EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
}
TEST_F(OpLevelCostEstimatorTest, SparseTensorDenseMatMul) {
// Unknown shape cases
{
auto cost =
PredictCosts(DescribeSparseTensorDenseMatMul(-1, {1, 1}, {1, 1}));
EXPECT_EQ(1, cost.num_ops_total);
EXPECT_TRUE(cost.inaccurate);
EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
}
{
auto cost =
PredictCosts(DescribeSparseTensorDenseMatMul(1, {-1, 1}, {1, 1}));
EXPECT_EQ(1, cost.num_ops_total);
EXPECT_TRUE(cost.inaccurate);
EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
}
{
auto cost =
PredictCosts(DescribeSparseTensorDenseMatMul(1, {1, -1}, {1, -1}));
EXPECT_EQ(1, cost.num_ops_total);
EXPECT_TRUE(cost.inaccurate);
EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
}
{
auto cost =
PredictCosts(DescribeSparseTensorDenseMatMul(1, {1, 1}, {-1, 1}));
EXPECT_EQ(1, cost.num_ops_total);
EXPECT_TRUE(cost.inaccurate);
EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
}
// Known shape cases
{
auto cost = PredictCosts(
DescribeSparseTensorDenseMatMul(10, {1000, 100}, {50, 100}));
EXPECT_EQ(1, cost.num_ops_total);
EXPECT_FALSE(cost.inaccurate);
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
EXPECT_EQ(Costs::Duration(200), cost.compute_time);
EXPECT_EQ(Costs::Duration(2422), cost.memory_time);
}
{
// Same cost as above case because cost does not depend on k_dim
auto cost = PredictCosts(
DescribeSparseTensorDenseMatMul(10, {100000, 100}, {50, 100}));
EXPECT_EQ(1, cost.num_ops_total);
EXPECT_FALSE(cost.inaccurate);
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
EXPECT_EQ(Costs::Duration(200), cost.compute_time);
EXPECT_EQ(Costs::Duration(2422), cost.memory_time);
}
}
void ExpectTensorShape(const std::vector<int64>& expected,
const TensorShapeProto& tensor_shape_proto) {
TensorShape tensor_shape_expected(expected);
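A back-of-the-envelope check of the known-shape expectations above, as a Python sketch. It assumes the test's synthetic CPU device works out to roughly 10 Gops/s of compute and 10 GB/s of memory bandwidth, which is what the expected durations imply; the second known-shape case expects the same numbers because neither the op count nor the charged input bytes depend on the k dimension of b.

# DescribeSparseTensorDenseMatMul(10, {1000, 100}, {50, 100})
nnz_a, n_dim = 10, 100
op_count = 2 * nnz_a * n_dim                 # 2,000 multiply-adds
a_indices = nnz_a * 2 * 8                    # int64 indices:   160 bytes
a_values = nnz_a * 4                         # float values:     40 bytes
a_shape = 2 * 8                              # int64 shape:      16 bytes
b_input = nnz_a * n_dim * 4                  # touched B:     4,000 bytes
output = 50 * 100 * 4                        # output:       20,000 bytes
compute_ns = op_count / 10.0                 # ~10 ops/ns   -> 200 ns
memory_ns = (a_indices + a_values + a_shape + b_input + output) / 10.0
print(compute_ns, memory_ns)                 # 200.0, 2421.6 (~2422 ns rounded)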

View File

@ -38,7 +38,7 @@ namespace grappler {
// b = Placeholder(..)
// c = AddN([a, a, b])
//
// GraphView edges: [a:0 -> c:0, a:0 -> c:1, b:0 -> c:3]
// GraphView edges: [a:0 -> c:0, a:0 -> c:1, b:0 -> c:2]
// GraphTopologyView edges: [a -> c, b -> c]
//
// GraphView is used for exploring single node fanins and fanouts, and

View File

@ -499,12 +499,12 @@ inline int64 ConvolveScratchSize() {
// convolution on the stream) and parameters, by running all possible
// algorithms and measuring execution time.
// TODO(ezhulenev): Move it to conv_ops_gpu.h and share with conv_ops.cc.
template <typename T, typename ConvLaunch>
Status FindBestConvolveAlgorithm(
const FusedConvParameters& params, const ConvLaunch launch,
OpKernelContext* context, se::Stream* stream,
se::dnn::AlgorithmConfig* algorithm_config,
std::vector<tensorflow::AutotuneResult>* results) {
template <typename T, typename ConvLaunch, typename LogFunc>
Status FindBestConvolveAlgorithm(const FusedConvParameters& params,
const ConvLaunch launch,
OpKernelContext* context, se::Stream* stream,
const LogFunc& log,
se::dnn::AlgorithmConfig* algorithm_config) {
// Check if we already have an algorithm selected for the given parameters.
if (AutoTuneFusedConv::GetInstance()->Find(params, algorithm_config)) {
return Status::OK();
@ -521,6 +521,7 @@ Status FindBestConvolveAlgorithm(
"see if a warning log message was printed above.");
}
std::vector<tensorflow::AutotuneResult> results;
for (auto profile_algorithm : algorithms) {
DnnScratchAllocator scratch_allocator(ConvolveScratchSize(), context);
se::dnn::ProfileResult profile_result;
@ -530,8 +531,8 @@ Status FindBestConvolveAlgorithm(
&profile_result);
if (cudnn_launch_status && profile_result.is_valid()) {
results->emplace_back();
auto& result = results->back();
results.emplace_back();
auto& result = results.back();
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
result.mutable_conv()->set_tensor_ops_enabled(
profile_algorithm.tensor_ops_enabled());
@ -542,7 +543,9 @@ Status FindBestConvolveAlgorithm(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
}
}
TF_RETURN_IF_ERROR(BestCudnnConvAlgorithm(*results, algorithm_config));
// Only log on an AutoTuneFusedConv cache miss.
log(results);
TF_RETURN_IF_ERROR(BestCudnnConvAlgorithm(results, algorithm_config));
AutoTuneFusedConv::GetInstance()->Insert(params, *algorithm_config);
return Status::OK();
}
@ -789,13 +792,14 @@ struct LaunchFusedConv2DOp<GPUDevice, T> {
se::dnn::AlgorithmConfig algorithm_config;
if (cudnn_use_autotune) {
std::vector<tensorflow::AutotuneResult> results;
auto status =
FindBestConvolveAlgorithm<T>(conv_parameters, launch, context, stream,
&algorithm_config, &results);
LogFusedConvAutotuneResults(context->op_kernel().def(), input,
transformed_filter, transformed_output, bias,
nullptr, stream->parent(), results);
auto status = FindBestConvolveAlgorithm<T>(
conv_parameters, launch, context, stream,
[&](absl::Span<const tensorflow::AutotuneResult> results) {
LogFusedConvAutotuneResults(
context->op_kernel().def(), input, transformed_filter,
transformed_output, bias, nullptr, stream->parent(), results);
},
&algorithm_config);
OP_REQUIRES_OK(context, status);
}
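The refactoring above moves result collection and logging inside the autotuner, so profiling results are gathered and reported only when the parameter cache misses. A minimal Python sketch of that control flow (hypothetical helper names, not the kernel's API):

# Sketch: profile and log only on a cache miss, then memoize the best config.
def find_best_algorithm(params, cache, candidates, profile, log, pick_best):
    if params in cache:          # cache hit: skip profiling and logging
        return cache[params]
    results = [profile(algo) for algo in candidates(params)]
    log(results)                 # only reached on an autotune cache miss
    best = pick_best(results)
    cache[params] = best
    return best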

View File

@ -24,7 +24,8 @@ REGISTER5(BinaryOp, CPU, "TruncateDiv", functor::safe_div, uint8, uint16, int16,
int32, int64);
REGISTER6(BinaryOp, CPU, "RealDiv", functor::div, float, Eigen::half, double,
bfloat16, complex64, complex128);
REGISTER2(BinaryOp, CPU, "DivNoNan", functor::div_no_nan, float, double);
REGISTER5(BinaryOp, CPU, "DivNoNan", functor::div_no_nan, Eigen::half, float,
double, complex64, complex128);
#if GOOGLE_CUDA
REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,

View File

@ -19,7 +19,7 @@ limitations under the License.
namespace tensorflow {
namespace functor {
DEFINE_UNARY2(tan, float, double);
DEFINE_UNARY3(tan, Eigen::half, float, double);
} // namespace functor
} // namespace tensorflow

View File

@ -19,7 +19,8 @@ namespace tensorflow {
REGISTER6(BinaryOp, CPU, "Mul", functor::mul, float, Eigen::half, double, uint8,
int32, bfloat16);
REGISTER2(BinaryOp, CPU, "MulNoNan", functor::mul_no_nan, float, double);
REGISTER5(BinaryOp, CPU, "MulNoNan", functor::mul_no_nan, Eigen::half, float,
double, complex64, complex128);
#if defined(__ANDROID_TYPES_SLIM__)
// We only register the first type when we have multi-argument calls in the

View File

@ -16,11 +16,11 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
REGISTER4(UnaryOp, CPU, "Tan", functor::tan, float, double, complex64,
complex128);
REGISTER5(UnaryOp, CPU, "Tan", functor::tan, Eigen::half, float, double,
complex64, complex128);
#if GOOGLE_CUDA
REGISTER2(UnaryOp, GPU, "Tan", functor::tan, float, double);
REGISTER3(UnaryOp, GPU, "Tan", functor::tan, Eigen::half, float, double);
#endif
#ifdef TENSORFLOW_USE_SYCL

View File

@ -540,12 +540,32 @@ tf_kernel_library(
name = "tensor_dataset_op",
srcs = ["tensor_dataset_op.cc"],
deps = [
":dataset_utils",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
],
)
tf_cc_test(
name = "tensor_dataset_op_test",
size = "small",
srcs = ["tensor_dataset_op_test.cc"],
deps = [
":dataset_test_base",
":dataset_utils",
":iterator_ops",
":tensor_dataset_op",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_kernel_library(
name = "tensor_slice_dataset_op",
srcs = ["tensor_slice_dataset_op.cc"],

View File

@ -446,18 +446,26 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel {
return s;
}
// Select the fastest input to use based on the histograms of timings
// of the completed iterations. The input with the best 90th percentile
// iteration time is selected.
void SelectFastestInputIndex() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
fastest_index_ = 0;
VLOG(2) << "90.0 percentile iteration time:";
double best_percentile = histograms_[0].Percentile(kPercentile);
VLOG(2) << "Branch 0: " << best_percentile;
for (size_t i = 1, num_inputs = histograms_.size(); i < num_inputs;
++i) {
double percentile = histograms_[i].Percentile(kPercentile);
VLOG(2) << "Branch " << i << ": " << percentile;
if (percentile <= best_percentile) {
best_percentile = percentile;
fastest_index_ = i;
}
}
VLOG(1) << "Selecting index " << fastest_index_
<< " as the fastest index.";
}
Status MakeCurrentIterator(IteratorContext* ctx, int64 branch_index,
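A compact sketch of the selection logic above (Python with numpy; illustrative only). Note the kernel compares with `<=`, so later branches win ties, whereas `argmin` below keeps the first:

import numpy as np

def select_fastest_index(timings_per_branch, percentile=90.0):
    """Pick the branch with the lowest 90th-percentile iteration time."""
    scores = [np.percentile(t, percentile) for t in timings_per_branch]
    return int(np.argmin(scores))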

View File

@ -325,15 +325,20 @@ class ChooseFastestDatasetOp : public DatasetOpKernel {
void SelectFastestInputIndex() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
fastest_index_ = 0;
VLOG(2) << "90.0 percentile iteration time:";
double best_percentile = histograms_[0].Percentile(kPercentile);
VLOG(2) << "Branch 0: " << best_percentile;
for (size_t i = 1, num_inputs = histograms_.size(); i < num_inputs;
++i) {
double percentile = histograms_[i].Percentile(kPercentile);
VLOG(2) << "Branch " << i << ": " << percentile;
if (percentile <= best_percentile) {
best_percentile = percentile;
fastest_index_ = i;
}
}
VLOG(1) << "Selecting index " << fastest_index_
<< " as the fastest index.";
fastest_input_impl_ = std::move(input_impls_[fastest_index_]);
input_impls_.clear(); // Delete the unused iterators.

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
namespace tensorflow {
namespace data {
@ -26,15 +27,20 @@ namespace {
class TensorDatasetOp : public DatasetOpKernel {
public:
explicit TensorDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
explicit TensorDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutput_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
OpInputList inputs;
OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs));
// TODO(mrry): Validate that the shapes of the "components" tensors match
// the "shapes" attr.;
std::vector<Tensor> components(inputs.begin(), inputs.end());
*output = new Dataset(ctx, std::move(components));
OP_REQUIRES_OK(ctx,
VerifyTypesMatch((*output)->output_dtypes(), output_types_));
OP_REQUIRES_OK(ctx, VerifyShapesCompatible((*output)->output_shapes(),
output_shapes_));
}
private:
@ -137,6 +143,9 @@ class TensorDatasetOp : public DatasetOpKernel {
DataTypeVector dtypes_;
std::vector<PartialTensorShape> shapes_;
};
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
};
REGISTER_KERNEL_BUILDER(Name("TensorDataset").Device(DEVICE_CPU),

View File

@ -0,0 +1,535 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/kernels/data/dataset_test_base.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
namespace {
constexpr char kNodeName[] = "tensor_dataset";
constexpr char kOpName[] = "TensorDataset";
class TensorDatasetOpTest : public DatasetOpsTestBase {
protected:
// Creates a new TensorDataset op kernel.
Status CreateTensorDatasetKernel(
DataTypeVector dtypes, std::vector<PartialTensorShape> shapes,
std::unique_ptr<OpKernel> *tensor_dataset_kernel) {
std::vector<string> components;
components.reserve(dtypes.size());
for (int i = 0; i < dtypes.size(); i++) {
components.emplace_back(strings::StrCat("component_", i));
}
node_def_ = test::function::NDef(
kNodeName, kOpName, components,
{{"Toutput_types", dtypes}, {"output_shapes", shapes}});
TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, tensor_dataset_kernel));
return Status::OK();
}
// Creates a new TensorDataset op kernel context.
Status CreateTensorDatasetContext(OpKernel *const tensor_dataset_kernel,
gtl::InlinedVector<TensorValue, 4> *inputs,
std::unique_ptr<OpKernelContext> *context) {
TF_RETURN_IF_ERROR(CheckOpKernelInput(*tensor_dataset_kernel, *inputs));
TF_RETURN_IF_ERROR(
CreateOpKernelContext(tensor_dataset_kernel, inputs, context));
return Status::OK();
}
private:
NodeDef node_def_;
};
struct TestCase {
std::vector<Tensor> components;
std::vector<Tensor> expected_outputs;
DataTypeVector expected_output_dtypes;
std::vector<PartialTensorShape> expected_output_shapes;
int64 expected_cardinality;
std::vector<int> breakpoints;
};
// Test case 1: test a dataset that represents a single tuple of plain tensors.
TestCase PlainTensorsTestCase() {
return {
/*components*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({1, 3}), {1, 2, 3}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({}), {37.0}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1, 2}),
{"a", "b"})},
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({1, 3}), {1, 2, 3}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({}), {37.0}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1, 2}),
{"a", "b"})},
/*expected_output_dtypes*/
{DT_INT64, DT_INT64, DT_DOUBLE, DT_STRING},
/*expected_output_shapes*/
{PartialTensorShape({}), PartialTensorShape({1, 3}),
PartialTensorShape({}), PartialTensorShape({1, 2})},
/*expected_cardinality*/ 1,
/*breakpoints*/ {0, 1, 2}};
}
// Test case 2: test a dataset that represents a tuple of nested tensors.
TestCase NestedTensorsTestCase() {
return {
/*components*/
{DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({}), {DatasetOpsTestBase::CreateTensor<double>(
TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({}), {DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"a", "b"})}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({1, 3}), {1, 2, 3})},
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({}), {DatasetOpsTestBase::CreateTensor<double>(
TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
DatasetOpsTestBase::CreateTensor<Variant>(
TensorShape({}), {DatasetOpsTestBase::CreateTensor<string>(
TensorShape({1, 2}), {"a", "b"})}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({1, 3}), {1, 2, 3})},
/*expected_output_dtypes*/
{DT_VARIANT, DT_VARIANT, DT_INT64},
/*expected_output_shapes*/
{PartialTensorShape({}), PartialTensorShape({}),
PartialTensorShape({1, 3})},
/*expected_cardinality*/ 1,
/*breakpoints*/ {0, 1, 2}};
}
class ParametrizedTensorDatasetOpTest
: public TensorDatasetOpTest,
public ::testing::WithParamInterface<TestCase> {};
TEST_P(ParametrizedTensorDatasetOpTest, GetNext) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = GetParam();
std::vector<Tensor> components = test_case.components;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.push_back(&component);
}
std::unique_ptr<OpKernel> tensor_dataset_kernel;
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&tensor_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_dataset_context;
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
&tensor_dataset_context));
DatasetBase *tensor_dataset;
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
tensor_dataset_context.get(), &tensor_dataset));
core::ScopedUnref scoped_unref(tensor_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(
CreateIteratorContext(tensor_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_dataset->MakeIterator(iterator_context.get(), "Iterator",
&iterator));
bool end_of_sequence = false;
std::vector<Tensor> out_tensors;
while (!end_of_sequence) {
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
&end_of_sequence));
}
EXPECT_EQ(out_tensors.size(), test_case.expected_outputs.size());
for (int i = 0; i < out_tensors.size(); ++i) {
if (out_tensors[i].dtype() == DT_VARIANT) {
// `ExpectEqual()` does not yet support variant tensors, so we manually
// cast the variant to a numeric/string tensor.
const Tensor *output = out_tensors[i].scalar<Variant>()().get<Tensor>();
const Tensor *expected_output =
test_case.expected_outputs[i].scalar<Variant>()().get<Tensor>();
TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
} else {
TF_EXPECT_OK(ExpectEqual(out_tensors[i], test_case.expected_outputs[i]));
}
}
}
TEST_F(TensorDatasetOpTest, DatasetTypeString) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = PlainTensorsTestCase();
std::vector<Tensor> components = test_case.components;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.push_back(&component);
}
std::unique_ptr<OpKernel> tensor_dataset_kernel;
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&tensor_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_dataset_context;
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
&tensor_dataset_context));
DatasetBase *tensor_dataset;
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
tensor_dataset_context.get(), &tensor_dataset));
core::ScopedUnref scoped_unref(tensor_dataset);
EXPECT_EQ(tensor_dataset->type_string(), kOpName);
}
TEST_F(TensorDatasetOpTest, DatasetNodeName) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = PlainTensorsTestCase();
std::vector<Tensor> components = test_case.components;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.push_back(&component);
}
std::unique_ptr<OpKernel> tensor_dataset_kernel;
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&tensor_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_dataset_context;
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
&tensor_dataset_context));
DatasetBase *tensor_dataset;
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
tensor_dataset_context.get(), &tensor_dataset));
core::ScopedUnref scoped_unref(tensor_dataset);
EXPECT_EQ(tensor_dataset->node_name(), kNodeName);
}
TEST_F(TensorDatasetOpTest, DatasetOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = PlainTensorsTestCase();
std::vector<Tensor> components = test_case.components;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.push_back(&component);
}
std::unique_ptr<OpKernel> tensor_dataset_kernel;
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&tensor_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_dataset_context;
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
&tensor_dataset_context));
DatasetBase *tensor_dataset;
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
tensor_dataset_context.get(), &tensor_dataset));
core::ScopedUnref scoped_unref(tensor_dataset);
EXPECT_EQ(tensor_dataset->output_dtypes(), test_case.expected_output_dtypes);
}
TEST_F(TensorDatasetOpTest, DatasetOutputShapes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = PlainTensorsTestCase();
std::vector<Tensor> components = test_case.components;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.push_back(&component);
}
std::unique_ptr<OpKernel> tensor_dataset_kernel;
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&tensor_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_dataset_context;
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
&tensor_dataset_context));
DatasetBase *tensor_dataset;
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
tensor_dataset_context.get(), &tensor_dataset));
core::ScopedUnref scoped_unref(tensor_dataset);
EXPECT_EQ(tensor_dataset->output_shapes().size(),
test_case.expected_output_shapes.size());
for (int i = 0; i < test_case.expected_output_shapes.size(); i++) {
EXPECT_TRUE(test_case.expected_output_shapes[i].IsIdenticalTo(
tensor_dataset->output_shapes()[i]));
}
}
TEST_F(TensorDatasetOpTest, Cardinality) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = PlainTensorsTestCase();
std::vector<Tensor> components = test_case.components;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.push_back(&component);
}
std::unique_ptr<OpKernel> tensor_dataset_kernel;
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&tensor_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_dataset_context;
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
&tensor_dataset_context));
DatasetBase *tensor_dataset;
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
tensor_dataset_context.get(), &tensor_dataset));
core::ScopedUnref scoped_unref(tensor_dataset);
EXPECT_EQ(tensor_dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_P(ParametrizedTensorDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = GetParam();
std::vector<Tensor> components = test_case.components;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.push_back(&component);
}
std::unique_ptr<OpKernel> tensor_dataset_kernel;
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&tensor_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_dataset_context;
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
&tensor_dataset_context));
DatasetBase *tensor_dataset;
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
tensor_dataset_context.get(), &tensor_dataset));
core::ScopedUnref scoped_unref(tensor_dataset);
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(tensor_dataset->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParametrizedTensorDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = GetParam();
std::vector<Tensor> components = test_case.components;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.push_back(&component);
}
std::unique_ptr<OpKernel> tensor_dataset_kernel;
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&tensor_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_dataset_context;
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
&tensor_dataset_context));
DatasetBase *tensor_dataset;
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
tensor_dataset_context.get(), &tensor_dataset));
core::ScopedUnref scoped_unref(tensor_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(
CreateIteratorContext(tensor_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_dataset->MakeIterator(iterator_context.get(), "Iterator",
&iterator));
EXPECT_EQ(iterator->output_dtypes(), test_case.expected_output_dtypes);
}
TEST_P(ParametrizedTensorDatasetOpTest, IteratorOutputShapes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = GetParam();
std::vector<Tensor> components = test_case.components;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.push_back(&component);
}
std::unique_ptr<OpKernel> tensor_dataset_kernel;
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&tensor_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_dataset_context;
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
&tensor_dataset_context));
DatasetBase *tensor_dataset;
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
tensor_dataset_context.get(), &tensor_dataset));
core::ScopedUnref scoped_unref(tensor_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(
CreateIteratorContext(tensor_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_dataset->MakeIterator(iterator_context.get(), "Iterator",
&iterator));
EXPECT_EQ(iterator->output_shapes().size(),
test_case.expected_output_shapes.size());
for (int i = 0; i < test_case.expected_output_shapes.size(); ++i) {
EXPECT_TRUE(test_case.expected_output_shapes[i].IsIdenticalTo(
iterator->output_shapes()[i]));
}
}
TEST_F(TensorDatasetOpTest, IteratorOutputPrefix) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = PlainTensorsTestCase();
std::vector<Tensor> components = test_case.components;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.push_back(&component);
}
std::unique_ptr<OpKernel> tensor_dataset_kernel;
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&tensor_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_dataset_context;
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
&tensor_dataset_context));
DatasetBase *tensor_dataset;
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
tensor_dataset_context.get(), &tensor_dataset));
core::ScopedUnref scoped_unref(tensor_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(
CreateIteratorContext(tensor_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(tensor_dataset->MakeIterator(iterator_context.get(), "Iterator",
&iterator));
EXPECT_EQ(iterator->prefix(), "Iterator::FromTensor");
}
TEST_P(ParametrizedTensorDatasetOpTest, Roundtrip) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = GetParam();
std::vector<Tensor> components = test_case.components;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.push_back(&component);
}
std::unique_ptr<OpKernel> tensor_dataset_kernel;
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&tensor_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_dataset_context;
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
&tensor_dataset_context));
DatasetBase *tensor_dataset;
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
tensor_dataset_context.get(), &tensor_dataset));
core::ScopedUnref scoped_unref(tensor_dataset);
std::unique_ptr<IteratorContext> iterator_ctx;
TF_ASSERT_OK(
CreateIteratorContext(tensor_dataset_context.get(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(
tensor_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
bool end_of_sequence = false;
std::vector<Tensor> out_tensors;
int cur_iteration = 0;
const std::vector<int> &breakpoints = test_case.breakpoints;
for (int breakpoint : breakpoints) {
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_EXPECT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_EXPECT_OK(iterator->Restore(iterator_ctx.get(), &reader));
while (cur_iteration <= breakpoint) {
TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors,
&end_of_sequence));
if (!end_of_sequence) {
EXPECT_EQ(out_tensors.size(), test_case.expected_outputs.size());
for (int i = 0; i < out_tensors.size(); ++i) {
if (out_tensors[i].dtype() == DT_VARIANT) {
// `ExpectEqual()` does not yet support variant tensors, so we manually
// cast the variant to a numeric/string tensor.
const Tensor *output =
out_tensors[i].scalar<Variant>()().get<Tensor>();
const Tensor *expected_output =
test_case.expected_outputs[i].scalar<Variant>()().get<Tensor>();
TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
} else {
TF_EXPECT_OK(
ExpectEqual(out_tensors[i], test_case.expected_outputs[i]));
}
}
}
cur_iteration++;
}
if (breakpoint >= test_case.expected_cardinality) {
EXPECT_TRUE(end_of_sequence);
} else {
EXPECT_FALSE(end_of_sequence);
}
}
}
INSTANTIATE_TEST_CASE_P(
TensorDatasetOpTest, ParametrizedTensorDatasetOpTest,
::testing::ValuesIn(std::vector<TestCase>({PlainTensorsTestCase(),
NestedTensorsTestCase()})));
} // namespace
} // namespace data
} // namespace tensorflow

View File

@ -82,10 +82,8 @@ Status ComputeSpansCore(OpKernelContext* context, const Kernel& kernel,
const float col_f = x + 0.5f;
const float sample_f = col_f * inv_scale + inv_translate;
// Don't sample when the sampling *kernel* is completely outside the
// source image.
if (sample_f < 0 - kernel.Radius() * kernel_scale ||
sample_f > input_size + kernel.Radius() * kernel_scale) {
// Don't sample when the sampling location is outside the source image.
if (sample_f < 0 || sample_f > input_size) {
// Add an empty span.
starts_vec(x) = 0;
continue;
@ -169,11 +167,15 @@ Status ComputeGradSpansCore(OpKernelContext* context, const Spans& spans,
auto grad_weights_vec = grad_spans->weights.vec<float>();
grad_weights_vec.setZero();
for (int input_index = 0; input_index < forward_input_size; ++input_index) {
const int start_span = grad_components[input_index].front().index;
grad_starts_vec(input_index) = start_span;
for (const GradComponent& gc : grad_components[input_index]) {
grad_weights_vec(input_index * grad_spans->span_size + gc.index -
start_span) += gc.weight;
if (!grad_components[input_index].empty()) {
const int start_span = grad_components[input_index].front().index;
grad_starts_vec(input_index) = start_span;
for (const GradComponent& gc : grad_components[input_index]) {
grad_weights_vec(input_index * grad_spans->span_size + gc.index -
start_span) += gc.weight;
}
} else {
grad_starts_vec(input_index) = 0;
}
}
return Status::OK();

View File

@ -120,7 +120,8 @@ void Sample(const DynamicKernel& kernel, const bool antialias,
1;
std::fill(dest, dest + channels, 0.0f);
if (y_span_end <= y_span_start || x_span_end <= x_span_start) {
if (sample_f.x() < 0.0f || sample_f.y() < 0.0f || sample_f.x() > in_width ||
sample_f.y() > in_height) {
return;
}
const Vector2f one_over_kernel_scale(1.0f / kernel_scale.x(),
@ -170,6 +171,8 @@ void ScaleAndTranslateBaseline(const DynamicKernel& kernel,
const int64 out_height = output.dimension(1);
const int64 out_width = output.dimension(2);
const int64 in_height = images.dimension(1);
const int64 in_width = images.dimension(2);
for (int b = 0; b < batch; ++b) {
for (int64 y = 0; y < out_height; ++y) {
@ -178,8 +181,13 @@ void ScaleAndTranslateBaseline(const DynamicKernel& kernel,
for (int64 x = 0; x < out_width; ++x) {
const float out_x_f = static_cast<float>(x) + 0.5;
const float in_x_f = out_x_f * scale.x() + translate.x();
Sample(kernel, antialias, images, b, scale, Vector2f(in_x_f, in_y_f),
&output(b, y, x, 0));
if (in_x_f < 0.0f || in_y_f < 0.0f || in_x_f > in_width ||
in_y_f > in_height) {
std::fill(&output(b, y, x, 0), &output(b, y, x + 1, 0), 0.0f);
} else {
Sample(kernel, antialias, images, b, scale, Vector2f(in_x_f, in_y_f),
&output(b, y, x, 0));
}
}
}
}

View File

@ -216,7 +216,7 @@ class StageOp : public OpKernel {
};
REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_CPU), StageOp);
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_GPU), StageOp);
#endif
#ifdef TENSORFLOW_USE_SYCL
@ -249,7 +249,7 @@ class UnstageOp : public OpKernel {
};
REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_CPU), UnstageOp);
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_GPU), UnstageOp);
#endif
#ifdef TENSORFLOW_USE_SYCL
@ -284,7 +284,7 @@ class StagePeekOp : public OpKernel {
};
REGISTER_KERNEL_BUILDER(Name("StagePeek").Device(DEVICE_CPU), StagePeekOp);
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(
Name("StagePeek").HostMemory("index").Device(DEVICE_GPU), StagePeekOp);
#endif
@ -314,7 +314,7 @@ class StageSizeOp : public OpKernel {
};
REGISTER_KERNEL_BUILDER(Name("StageSize").Device(DEVICE_CPU), StageSizeOp);
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(Name("StageSize").HostMemory("size").Device(DEVICE_GPU),
StageSizeOp);
#endif
@ -339,7 +339,7 @@ class StageClearOp : public OpKernel {
};
REGISTER_KERNEL_BUILDER(Name("StageClear").Device(DEVICE_CPU), StageClearOp);
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(Name("StageClear").Device(DEVICE_GPU), StageClearOp);
#endif
#ifdef TENSORFLOW_USE_SYCL

View File

@ -134,7 +134,7 @@ TF_CALL_half(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
.Device(DEVICE_GPU)
.TypeConstraint<Eigen::half>("T"),
@ -147,7 +147,7 @@ REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
.Device(DEVICE_GPU)
.TypeConstraint<double>("T"),
SoftmaxXentWithLogitsOp<GPUDevice, double>);
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@ -54,4 +54,4 @@ template struct functor::XentFunctor<GPUDevice, double>;
} // end namespace tensorflow
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@ -21990,6 +21990,34 @@ op {
}
}
}
op {
name: "DivNoNan"
input_arg {
name: "x"
type_attr: "T"
}
input_arg {
name: "y"
type_attr: "T"
}
output_arg {
name: "z"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
type: DT_COMPLEX128
}
}
}
}
op {
name: "DrawBoundingBoxes"
input_arg {
@ -40851,6 +40879,35 @@ op {
}
is_commutative: true
}
op {
name: "MulNoNan"
input_arg {
name: "x"
type_attr: "T"
}
input_arg {
name: "y"
type_attr: "T"
}
output_arg {
name: "z"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
type: DT_COMPLEX128
}
}
}
is_commutative: true
}
op {
name: "MultiDeviceIterator"
output_arg {

View File

@ -435,7 +435,7 @@ REGISTER_OP("MulNoNan")
.Input("x: T")
.Input("y: T")
.Output("z: T")
.Attr("T: {float, double}")
.Attr("T: {half, float, double, complex64, complex128}")
.SetIsCommutative()
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
@ -460,7 +460,7 @@ REGISTER_OP("DivNoNan")
.Input("x: T")
.Input("y: T")
.Output("z: T")
.Attr("T: {float, double}")
.Attr("T: {half, float, double, complex64, complex128}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
REGISTER_OP("FloorDiv")

View File

@ -10019,8 +10019,11 @@ op {
type: "type"
allowed_values {
list {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
type: DT_COMPLEX128
}
}
}
@ -20017,8 +20020,11 @@ op {
type: "type"
allowed_values {
list {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
type: DT_COMPLEX128
}
}
}

View File

@ -284,7 +284,7 @@ Status OAuthClient::ParseOAuthResponse(StringPiece response,
return errors::FailedPrecondition("Unexpected Oauth token type: " +
token_type);
}
int64 expires_in;
int64 expires_in = 0;
TF_RETURN_IF_ERROR(ReadJsonInt(root, "expires_in", &expires_in));
*expiration_timestamp_sec = request_timestamp_sec + expires_in;
TF_RETURN_IF_ERROR(ReadJsonString(root, "access_token", token));

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <setjmp.h>
#include <stdio.h>
#include <string.h>
#include <cmath>
#include <fstream>
#include <vector>
@ -228,7 +229,9 @@ void DecodeLocation(const float* encoded_location, const float* box_priors,
}
}
float DecodeScore(float encoded_score) { return 1 / (1 + exp(-encoded_score)); }
float DecodeScore(float encoded_score) {
return 1 / (1 + std::exp(-encoded_score));
}
void DrawBox(const int image_width, const int image_height, int left, int top,
int right, int bottom, tensorflow::TTypes<uint8>::Flat* image) {

View File

@ -72,7 +72,7 @@ class TextEmbeddingModel(tf.train.Checkpoint):
normalized_sentences = tf.strings.regex_replace(
input=sentences, pattern=r"\pP", rewrite="")
normalized_sentences = tf.reshape(normalized_sentences, [-1])
sparse_tokens = tf.string_split(normalized_sentences, " ")
sparse_tokens = tf.strings.split(normalized_sentences, " ")
# Deal with a corner case: there is one empty sentence.
sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant(""))

View File

@ -50,7 +50,7 @@ class TextRnnModel(tf.train.Checkpoint):
# splitting on spaces.
normalized_sentences = tf.strings.regex_replace(
input=sentences, pattern=r"\pP", rewrite="")
sparse_tokens = tf.string_split(normalized_sentences, " ")
sparse_tokens = tf.strings.split(normalized_sentences, " ")
# Deal with a corner case: there is one empty sentence.
sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant(""))

View File

@ -53,7 +53,7 @@ type Graph struct {
c *C.TF_Graph
}
// Graph execution options
// The GraphImportOptions struct holds parameters for the ImportWithOptions function.
type GraphImportOptions struct {
// Node prefix
Prefix string
@ -170,7 +170,7 @@ func (g *Graph) Operation(name string) *Operation {
// Operations returns a list of all operations in the graph
func (g *Graph) Operations() []Operation {
var pos C.size_t = 0
var pos C.size_t
ops := []Operation{}
for {
cop := C.TF_GraphNextOperation(g.c, &pos)

View File

@ -5342,7 +5342,6 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":distribute",
":framework",
":framework_for_generated_wrappers",
":platform",
@ -5371,6 +5370,7 @@ py_library(
":summary_ops_gen",
":summary_ops_v2",
":util",
"//tensorflow/python/distribute:summary_op_util",
"//tensorflow/python/eager:context",
"@six_archive//:six",
],

View File

@ -88,20 +88,6 @@ from tensorflow.python.util.tf_export import tf_export
# TODO(mdan): Add a test specific to this converter.
# TODO(mdan): Remove when updating the API.
@tf_export('autograph.experimental.Verbosity')
class Verbosity(enum.IntEnum):
"""Represents conversion verbosity levels.
Attributes:
BRIEF: No logging, minimal error messages.
VERBOSE: Detailed logging of generated code, detailed error messages.
"""
BRIEF = 0
VERBOSE = 1
@tf_export('autograph.experimental.Feature')
class Feature(enum.Enum):
"""Represents conversion options that can be toggled on or off.

View File

@ -377,16 +377,12 @@ def _is_not_callable(obj):
return False
# TODO(mdan): Remove obsolete args.
@tf_export('autograph.to_graph')
def to_graph(entity,
recursive=True,
arg_values=None,
arg_types=None,
experimental_optional_features=converter.Feature.ALL,
experimental_strip_decorators=None,
experimental_verbose=converter.Verbosity.BRIEF,
experimental_partial_types=None):
experimental_optional_features=converter.Feature.ALL):
"""Converts a Python entity into a TensorFlow graph.
Also see: `tf.autograph.to_code`, `tf.function`.
@ -442,9 +438,6 @@ def to_graph(entity,
experimental_optional_features: `None`, a tuple of, or a single
`tf.autograph.experimental.Feature` value. Controls the use of
optional features in the conversion process.
experimental_strip_decorators: Deprecated, unused.
experimental_verbose: Deprecated, unused.
experimental_partial_types: Deprecated, unused.
Returns:
Same as `entity`, the converted Python function or class.
@ -452,10 +445,6 @@ def to_graph(entity,
Raises:
ValueError: If the entity could not be converted.
"""
del experimental_strip_decorators
del experimental_verbose
del experimental_partial_types
try:
program_ctx = converter.ProgramContext(
options=converter.ConversionOptions(
@ -520,8 +509,7 @@ def to_code(entity,
arg_values=None,
arg_types=None,
indentation=' ',
experimental_optional_features=converter.Feature.ALL,
experimental_partial_types=None):
experimental_optional_features=converter.Feature.ALL):
"""Similar to `to_graph`, but returns Python source code as a string.
Also see: `tf.autograph.to_graph`.
@ -544,13 +532,10 @@ def to_code(entity,
experimental_optional_features: `None`, a tuple of, or a single
`tf.autograph.experimental.Feature` value. Controls the use of
optional features in the conversion process.
experimental_partial_types: Deprecated, unused.
Returns:
The converted code as string.
"""
del experimental_partial_types
program_ctx = converter.ProgramContext(
options=converter.ConversionOptions(
recursive=recursive,
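After the cleanup above, `tf.autograph.to_graph` and `to_code` no longer accept the removed `experimental_*` arguments. A minimal illustrative call with the surviving signature:

import tensorflow as tf

def f(x):
    if x > 0:
        x = x * x
    return x

converted = tf.autograph.to_graph(f, recursive=True)
print(tf.autograph.to_code(f))  # the generated source, as a string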

View File

@ -27,7 +27,7 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 3, 14)
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 3, 15)
@tf_export("compat.forward_compatible")

View File

@ -19,70 +19,19 @@ from __future__ import print_function
from tensorflow.python.data.experimental.ops import random_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.ops import gen_stateless_random_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
class _ParallelInterleaveDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that maps a function over its input and flattens the result."""
def __init__(self, input_dataset, map_func, cycle_length, block_length,
sloppy, buffer_output_elements, prefetch_input_elements):
"""See `tf.data.experimental.parallel_interleave()` for details."""
self._input_dataset = input_dataset
self._map_func = dataset_ops.StructuredFunctionWrapper(
map_func, self._transformation_name(), dataset=input_dataset)
if not isinstance(self._map_func.output_structure,
dataset_ops.DatasetStructure):
raise TypeError("`map_func` must return a `Dataset` object.")
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
self._cycle_length = ops.convert_to_tensor(
cycle_length, dtype=dtypes.int64, name="cycle_length")
self._block_length = ops.convert_to_tensor(
block_length, dtype=dtypes.int64, name="block_length")
self._sloppy = ops.convert_to_tensor(
sloppy, dtype=dtypes.bool, name="sloppy")
self._buffer_output_elements = convert.optional_param_to_tensor(
"buffer_output_elements",
buffer_output_elements,
argument_default=2 * block_length)
self._prefetch_input_elements = convert.optional_param_to_tensor(
"prefetch_input_elements",
prefetch_input_elements,
argument_default=2 * cycle_length)
variant_tensor = ged_ops.experimental_parallel_interleave_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs,
self._cycle_length,
self._block_length,
self._sloppy,
self._buffer_output_elements,
self._prefetch_input_elements,
f=self._map_func.function,
**dataset_ops.flat_structure(self))
super(_ParallelInterleaveDataset, self).__init__(input_dataset,
variant_tensor)
def _functions(self):
return [self._map_func]
@property
def _element_structure(self):
return self._structure
def _transformation_name(self):
return "tf.data.experimental.parallel_interleave()"
@deprecation.deprecated(
None,
"Use `tf.data.Dataset.interleave(map_func, cycle_length, block_length, "
@ -139,7 +88,7 @@ def parallel_interleave(map_func,
`tf.data.Dataset.apply`.
"""
def _apply_fn(dataset):
return _ParallelInterleaveDataset(
return readers.ParallelInterleaveDataset(
dataset, map_func, cycle_length, block_length, sloppy,
buffer_output_elements, prefetch_input_elements)
@ -176,10 +125,11 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
def _as_variant_tensor(self):
# pylint: disable=protected-access
return (ged_ops.experimental_directed_interleave_dataset(
self._selector_input._variant_tensor,
[data_input._variant_tensor for data_input in self._data_inputs],
**dataset_ops.flat_structure(self)))
return (
gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
self._selector_input._variant_tensor,
[data_input._variant_tensor for data_input in self._data_inputs],
**dataset_ops.flat_structure(self)))
# pylint: enable=protected-access
def _inputs(self):

View File

@ -25,6 +25,7 @@ import numpy as np
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import error_ops
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops import parsing_ops
from tensorflow.python.data.experimental.ops import shuffle_ops
@ -493,15 +494,9 @@ def make_csv_dataset_v2(
return features
# Read files sequentially (if num_parallel_reads=1) or in parallel
dataset = dataset.interleave(
filename_to_dataset,
cycle_length=num_parallel_reads,
num_parallel_calls=num_parallel_reads)
if sloppy:
options = dataset_ops.Options()
options.experimental_deterministic = False
dataset = dataset.with_options(options)
dataset = dataset.apply(
interleave_ops.parallel_interleave(
filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy))
dataset = _maybe_shuffle_and_repeat(
dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
@ -833,15 +828,11 @@ def make_batched_features_dataset_v2(file_pattern,
reader_args = []
# Read files sequentially (if reader_num_threads=1) or in parallel
dataset = dataset.interleave(
lambda filename: reader(filename, *reader_args),
cycle_length=reader_num_threads,
num_parallel_calls=reader_num_threads)
if sloppy_ordering:
options = dataset_ops.Options()
options.experimental_deterministic = False
dataset = dataset.with_options(options)
dataset = dataset.apply(
interleave_ops.parallel_interleave(
lambda filename: reader(filename, *reader_args),
cycle_length=reader_num_threads,
sloppy=sloppy_ordering))
# Extract values if the `Example` tensors are stored as key-value tuples.
if dataset_ops.get_legacy_output_types(dataset) == (

View File

@ -1188,7 +1188,9 @@ class DatasetV2(object):
"""
dataset = transformation_func(self)
if not isinstance(dataset, DatasetV2):
raise TypeError("`transformation_func` must return a Dataset.")
raise TypeError(
"`transformation_func` must return a Dataset. Got {}.".format(
dataset))
dataset._input_datasets = [self] # pylint: disable=protected-access
return dataset
@ -2121,7 +2123,9 @@ class SparseTensorSliceDataset(DatasetSource):
def __init__(self, sparse_tensor):
"""See `Dataset.from_sparse_tensor_slices()` for details."""
if not isinstance(sparse_tensor, sparse_tensor_lib.SparseTensor):
raise TypeError("`sparse_tensor` must be a `tf.SparseTensor` object.")
raise TypeError(
"`sparse_tensor` must be a `tf.SparseTensor` object. Was {}.".format(
sparse_tensor))
self._sparse_tensor = sparse_tensor
indices_shape = self._sparse_tensor.indices.get_shape()
@ -2885,7 +2889,11 @@ def _default_padding(input_dataset):
if t.base_dtype == dtypes.string:
return ""
elif t.base_dtype == dtypes.variant:
raise TypeError("Unable to create padding for field of type 'variant'")
error_msg = ("Unable to create padding for field of type 'variant' "
"because t.base_type == dtypes.variant == "
"{}.".format(
t.base_dtype))
raise TypeError(error_msg)
else:
return np.zeros_like(t.as_numpy_dtype())
@ -3066,7 +3074,9 @@ class FlatMapDataset(UnaryDataset):
self._map_func = StructuredFunctionWrapper(
map_func, self._transformation_name(), dataset=input_dataset)
if not isinstance(self._map_func.output_structure, DatasetStructure):
raise TypeError("`map_func` must return a `Dataset` object.")
raise TypeError(
"`map_func` must return a `Dataset` object. Got {}".format(
type(self._map_func.output_structure)))
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
variant_tensor = gen_dataset_ops.flat_map_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
@ -3096,7 +3106,9 @@ class InterleaveDataset(UnaryDataset):
self._map_func = StructuredFunctionWrapper(
map_func, self._transformation_name(), dataset=input_dataset)
if not isinstance(self._map_func.output_structure, DatasetStructure):
raise TypeError("`map_func` must return a `Dataset` object.")
raise TypeError(
"`map_func` must return a `Dataset` object. Got {}".format(
type(self._map_func.output_structure)))
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
self._cycle_length = ops.convert_to_tensor(
cycle_length, dtype=dtypes.int64, name="cycle_length")
@ -3124,8 +3136,7 @@ class InterleaveDataset(UnaryDataset):
class ParallelInterleaveDataset(UnaryDataset):
"""A `Dataset` that maps a function over its input and interleaves the result.
"""
"""A `Dataset` that maps a function over its input and interleaves the result."""
def __init__(self, input_dataset, map_func, cycle_length, block_length,
num_parallel_calls):
@ -3134,7 +3145,9 @@ class ParallelInterleaveDataset(UnaryDataset):
self._map_func = StructuredFunctionWrapper(
map_func, self._transformation_name(), dataset=input_dataset)
if not isinstance(self._map_func.output_structure, DatasetStructure):
raise TypeError("`map_func` must return a `Dataset` object.")
raise TypeError(
"`map_func` must return a `Dataset` object. Got {}".format(
type(self._map_func.output_structure)))
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
self._cycle_length = ops.convert_to_tensor(
cycle_length, dtype=dtypes.int64, name="cycle_length")
@ -3177,7 +3190,10 @@ class FilterDataset(UnaryUnchangedStructureDataset):
use_legacy_function=use_legacy_function)
if not wrapped_func.output_structure.is_compatible_with(
structure_lib.TensorStructure(dtypes.bool, [])):
raise ValueError("`predicate` must return a scalar boolean tensor.")
error_msg = ("`predicate` return type must be convertible to a scalar "
"boolean tensor. Was {}.").format(
wrapped_func.output_structure)
raise ValueError(error_msg)
self._predicate = wrapped_func
variant_tensor = gen_dataset_ops.filter_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access

View File

@ -26,6 +26,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util.tf_export import tf_export
@ -118,6 +119,57 @@ class _TFRecordDataset(dataset_ops.DatasetSource):
return structure.TensorStructure(dtypes.string, [])
class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that maps a function over its input and flattens the result."""
def __init__(self, input_dataset, map_func, cycle_length, block_length,
sloppy, buffer_output_elements, prefetch_input_elements):
"""See `tf.data.experimental.parallel_interleave()` for details."""
self._input_dataset = input_dataset
self._map_func = dataset_ops.StructuredFunctionWrapper(
map_func, self._transformation_name(), dataset=input_dataset)
if not isinstance(self._map_func.output_structure,
dataset_ops.DatasetStructure):
raise TypeError("`map_func` must return a `Dataset` object.")
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
self._cycle_length = ops.convert_to_tensor(
cycle_length, dtype=dtypes.int64, name="cycle_length")
self._block_length = ops.convert_to_tensor(
block_length, dtype=dtypes.int64, name="block_length")
self._sloppy = ops.convert_to_tensor(
sloppy, dtype=dtypes.bool, name="sloppy")
self._buffer_output_elements = convert.optional_param_to_tensor(
"buffer_output_elements",
buffer_output_elements,
argument_default=2 * block_length)
self._prefetch_input_elements = convert.optional_param_to_tensor(
"prefetch_input_elements",
prefetch_input_elements,
argument_default=2 * cycle_length)
variant_tensor = ged_ops.experimental_parallel_interleave_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs,
self._cycle_length,
self._block_length,
self._sloppy,
self._buffer_output_elements,
self._prefetch_input_elements,
f=self._map_func.function,
**dataset_ops.flat_structure(self))
super(ParallelInterleaveDataset, self).__init__(input_dataset,
variant_tensor)
def _functions(self):
return [self._map_func]
@property
def _element_structure(self):
return self._structure
def _transformation_name(self):
return "tf.data.experimental.parallel_interleave()"
@tf_export("data.TFRecordDataset", v1=[])
class TFRecordDatasetV2(dataset_ops.DatasetV2):
"""A `Dataset` comprising records from one or more TFRecord files."""
@ -169,10 +221,10 @@ class TFRecordDatasetV2(dataset_ops.DatasetV2):
if num_parallel_reads is None:
self._impl = filenames.flat_map(read_one_file)
else:
self._impl = filenames.interleave(
read_one_file,
cycle_length=num_parallel_reads,
num_parallel_calls=num_parallel_reads)
self._impl = ParallelInterleaveDataset(
filenames, read_one_file, cycle_length=num_parallel_reads,
block_length=1, sloppy=False, buffer_output_elements=None,
prefetch_input_elements=None)
variant_tensor = self._impl._variant_tensor # pylint: disable=protected-access
super(TFRecordDatasetV2, self).__init__(variant_tensor)
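A minimal sketch of the user-visible effect of this change (the shard names are hypothetical): when `num_parallel_reads` is set, `tf.data.TFRecordDataset` now reads through `ParallelInterleaveDataset` with `sloppy=False` and `block_length=1`, so files are read in parallel while the record order stays deterministic.

import tensorflow as tf

# Hypothetical shard names; both files are read in parallel, but the
# interleaving of their records is deterministic (sloppy=False).
dataset = tf.data.TFRecordDataset(
    ["train-00000-of-00002.tfrecord", "train-00001-of-00002.tfrecord"],
    num_parallel_reads=2)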

View File

@ -469,6 +469,16 @@ py_test(
],
)
py_library(
name = "summary_op_util",
srcs = ["summary_op_util.py"],
deps = [
":distribute_lib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:tensor_util",
],
)
py_library(
name = "values",
srcs = ["values.py"],

View File

@ -512,23 +512,50 @@ class DistributionStrategy(object):
_require_cross_replica_or_default_context_extended(self._extended)
return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access
@doc_controls.do_not_generate_docs # DEPRECATED, -> `DistributedValues`
@doc_controls.do_not_generate_docs # DEPRECATED
def unwrap(self, value):
"""Returns the list of all per-replica values contained in `value`.
"""Returns the list of all local per-replica values contained in `value`.
DEPRECATED: Please use `experimental_local_results` instead.
Note: This only returns values on the workers initiated by this client.
When using a `Strategy` like
`tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
will be its own client, and this function will only return values
computed on that worker.
Args:
value: A value returned by `extended.call_for_each_replica()` or a
variable created in `scope`.
value: A value returned by `experimental_run()`,
`extended.call_for_each_replica()`, or a variable created in `scope`.
Returns:
A tuple of values contained in `value`. If `value` represents a single
value, this returns `(value,).`
"""
return self._extended._unwrap(value) # pylint: disable=protected-access
return self._extended._local_results(value) # pylint: disable=protected-access
@doc_controls.do_not_generate_docs # DEPRECATED, -> `DistributedValues`
def experimental_local_results(self, value):
"""Returns the list of all local per-replica values contained in `value`.
Note: This only returns values on the workers initiated by this client.
When using a `Strategy` like
`tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
will be its own client, and this function will only return values
computed on that worker.
Args:
value: A value returned by `experimental_run()`, `experimental_run_v2()`,
`extended.call_for_each_replica()`, or a variable created in `scope`.
Returns:
A tuple of values contained in `value`. If `value` represents a single
value, this returns `(value,).`
"""
return self._extended._local_results(value) # pylint: disable=protected-access
@doc_controls.do_not_generate_docs # DEPRECATED: TF v1.x only
def group(self, value, name=None):
"""Shortcut for `tf.group(self.unwrap(value))`."""
"""Shortcut for `tf.group(self.experimental_local_results(value))`."""
return self._extended._group(value, name) # pylint: disable=protected-access
@property
@ -1067,7 +1094,7 @@ class DistributionStrategyExtended(object):
# Called once in "cross-replica" context.
def merge_fn(distribution, three_plus_replica_id):
# sum the values across replicas
return sum(distribution.unwrap(three_plus_replica_id))
return sum(distribution.experimental_local_results(three_plus_replica_id))
# Called once per replica in `distribution`, in a "replica" context.
def fn(three):
@ -1082,7 +1109,8 @@ class DistributionStrategyExtended(object):
...
merged_results = distribution.call_for_each_replica(fn, args=[3])
# merged_results has the values from every replica execution of `fn`.
print(distribution.unwrap(merged_results)) # Prints a list
# This statement prints a list:
print(distribution.experimental_local_results(merged_results))
```
Args:
@ -1104,8 +1132,9 @@ class DistributionStrategyExtended(object):
def _reduce(self, reduce_op, value):
# Default implementation until we have an implementation for each strategy.
return self._unwrap(self._reduce_to(
reduce_op, value, device_util.current() or "/device:CPU:0"))[0]
return self._local_results(
self._reduce_to(reduce_op, value,
device_util.current() or "/device:CPU:0"))[0]
def reduce_to(self, reduce_op, value, destinations):
"""Combine (via e.g. sum or mean) values across replicas.
@ -1224,7 +1253,7 @@ class DistributionStrategyExtended(object):
def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
raise NotImplementedError("must be implemented in descendants")
def _unwrap(self, distributed_value):
def _local_results(self, distributed_value):
raise NotImplementedError("must be implemented in descendants")
def value_container(self, value):
@ -1238,13 +1267,14 @@ class DistributionStrategyExtended(object):
A container that `value` belongs to.
If value does not belong to any container (including the case of
container having been destroyed), returns the value itself.
`value in unwrap(value_container(value))` will always be true.
`value in experimental_local_results(value_container(value))` will
always be true.
"""
raise NotImplementedError("must be implemented in descendants")
def _group(self, value, name=None):
"""Shortcut for `tf.group(distribution.unwrap(value))`."""
value = nest.flatten(self._unwrap(value))
"""Implementation of `group`."""
value = nest.flatten(self._local_results(value))
if len(value) != 1 or name is not None:
return control_flow_ops.group(value, name=name)
@ -1590,12 +1620,12 @@ class _DefaultDistributionExtended(DistributionStrategyExtended):
if should_group:
return result
else:
return nest.map_structure(self._unwrap, result)
return nest.map_structure(self._local_results, result)
def read_var(self, replica_local_var):
return array_ops.identity(replica_local_var)
def _unwrap(self, distributed_value):
def _local_results(self, distributed_value):
return (distributed_value,)
def value_container(self, value):

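A hedged usage sketch of the rename (the strategy construction and step function are illustrative, not taken from this diff): `experimental_local_results` is the replacement for the deprecated `unwrap`.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def fn(x):
  return x + 1.0

per_replica = strategy.extended.call_for_each_replica(
    fn, args=(tf.constant(3.0),))
# Preferred over the deprecated `strategy.unwrap(per_replica)`:
local_values = strategy.experimental_local_results(per_replica)
# `local_values` is a tuple with one entry per local replica.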
View File

@ -648,6 +648,7 @@ class MultiStepContext(object):
def merge_fn(distribution, value):
# NOTE(priyag): For non tensor outputs, we simply return all the values
# in a list as reduction doesn't make sense on non tensors.
self._non_tensor_outputs[name] = distribution.unwrap(value)
self._non_tensor_outputs[name] = (
distribution.experimental_local_results(value))
distribution_strategy_context.get_replica_context().merge_call(
merge_fn, args=(output,))

View File

@ -602,7 +602,7 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
fn_result = fn(ctx, iterator.get_next())
for (name, output) in ctx.last_step_outputs.items():
# Convert all outputs to tensors, potentially from `DistributedValues`.
ctx.last_step_outputs[name] = self._unwrap(output)
ctx.last_step_outputs[name] = self._local_results(output)
flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
with ops.control_dependencies([fn_result]):
return [i + 1] + flat_last_step_outputs
@ -741,7 +741,7 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
assert isinstance(replica_local_var, values.Mirrored)
return array_ops.identity(replica_local_var.get())
def _unwrap(self, val):
def _local_results(self, val):
if isinstance(val, values.DistributedValues):
return val.values
return (val,)

View File

@ -159,13 +159,13 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended):
if group:
return result
else:
return nest.map_structure(self._unwrap, result)
return nest.map_structure(self._local_results, result)
def read_var(self, replica_local_var):
"""Read the aggregate value of a replica-local variable."""
return array_ops.identity(replica_local_var)
def _unwrap(self, value):
def _local_results(self, value):
return (value,)
def value_container(self, value):

View File

@ -426,7 +426,7 @@ class ParameterServerStrategyExtended(
if group:
return result
else:
return nest.map_structure(self._unwrap, result)
return nest.map_structure(self._local_results, result)
# TODO(yuefengz): does it need to call _select_single_value?
def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
@ -436,9 +436,9 @@ class ParameterServerStrategyExtended(
if group:
return result
else:
return nest.map_structure(self._unwrap, result)
return nest.map_structure(self._local_results, result)
def _unwrap(self, val):
def _local_results(self, val):
if isinstance(val, values.DistributedValues):
return val.values
return (val,)

View File

@ -0,0 +1,48 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#==============================================================================
"""Contains utility functions used by summary ops in distribution strategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
def skip_summary():
"""Determines if summary should be skipped.
If using multiple replicas in a distributed strategy, skip summaries on all
replicas except the first one (replica_id=0).
Returns:
True if the summary is skipped; False otherwise.
"""
# TODO(priyag): Add a new optional argument that will provide multiple
# alternatives to override default behavior. (e.g. run on last replica,
# compute sum or mean across replicas).
replica_context = distribution_strategy_context.get_replica_context()
if not replica_context:
return False
# TODO(b/118385803): when replica_id of _TPUReplicaContext is properly
# initialized, remember to change here as well.
replica_id = replica_context.replica_id_in_sync_group
if isinstance(replica_id, ops.Tensor):
replica_id = tensor_util.constant_value(replica_id)
return replica_id and replica_id > 0
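A hedged sketch of how a summary-writing helper might use `skip_summary()` so that only replica 0 emits summaries. The import path assumes the new module lands in `tensorflow/python/distribute` (per the py_library target added earlier); the summary call itself is illustrative.

import tensorflow as tf
from tensorflow.python.distribute import summary_op_util  # assumed path

def scalar_summary_on_first_replica(name, tensor):
  # In a replica context with replica_id > 0, write nothing.
  if summary_op_util.skip_summary():
    return None
  return tf.summary.scalar(name, tensor)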

View File

@ -500,7 +500,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
assert isinstance(var, values.TPUMirroredVariable)
return var.read_value()
def _unwrap(self, val):
def _local_results(self, val):
if isinstance(val, values.DistributedValues):
# Return in a deterministic order.
return tuple(val.get(device=d) for d in sorted(val.devices))
@ -589,7 +589,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
if group:
return result
else:
return nest.map_structure(self._unwrap, result)
return nest.map_structure(self._local_results, result)
def _configure(self,
session_config=None,

View File

@ -272,7 +272,7 @@ class DistributedValues(object):
def device_map(self):
return self._device_map
# TODO(josh11b): Replace unwrap with this?
# TODO(josh11b): Replace experimental_local_results with this?
@property
def values(self):
return self._values
@ -622,8 +622,9 @@ def validate_colocate(v, extended):
def _apply_aggregation(strategy, value, aggregation, destinations):
if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
return strategy.extended.broadcast_to(strategy.unwrap(value)[0],
destinations=destinations)
return strategy.extended.broadcast_to(
strategy.experimental_local_results(value)[0],
destinations=destinations)
reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
return strategy.extended.reduce_to(reduce_op, value, destinations)
@ -824,7 +825,7 @@ class TPUMirroredVariable(trackable.Trackable):
def device_map(self):
return self._device_map
# TODO(josh11b): Replace unwrap with this?
# TODO(josh11b): Replace experimental_local_results with this?
@property
def values(self):
return self._values
@ -1410,7 +1411,7 @@ def update_regroup(extended, device_map, updates, group):
# so we can avoid all these nest operations.
regrouped = regroup(device_map, updates, Mirrored)
if not group:
return nest.map_structure(extended._unwrap, regrouped) # pylint: disable=protected-access
return nest.map_structure(extended._local_results, regrouped) # pylint: disable=protected-access
grouped_flat = []
for u in nest.flatten(regrouped):
if isinstance(u, DistributedValues):

View File

@ -735,7 +735,7 @@ class ConcreteFunction(object):
# In case of eager execution, function definition gets added to context
# during construction itself.
# TODO(allel/shivaniagrawal): rename this to register to reflect the
# TODO(allenl/shivaniagrawal): rename this to register to reflect the
# method's functionality better. Remove register_gradient_functions argument
# and figure out if these need to be registered.

View File

@ -73,6 +73,10 @@ def function_def_to_graph(fdef, input_shapes=None):
func_graph.outputs = [
func_graph.get_tensor_by_name(name) for name in output_tensor_names
]
func_graph.control_outputs = [
func_graph.get_operation_by_name(fdef.control_ret[ret_name])
for ret_name in fdef.signature.control_output
]
return func_graph

Some files were not shown because too many files have changed in this diff.