Merge branch 'master' into google_upstream_argmax_op
This commit is contained in:
commit
23d21d078c
28
configure.py
28
configure.py
@ -50,6 +50,7 @@ _DEFAULT_PROMPT_ASK_ATTEMPTS = 10
|
||||
_TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
|
||||
_TF_WORKSPACE_ROOT = ''
|
||||
_TF_BAZELRC = ''
|
||||
_TF_CURRENT_BAZEL_VERSION = None
|
||||
|
||||
NCCL_LIB_PATHS = [
|
||||
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
|
||||
@ -337,8 +338,8 @@ def get_var(environ_cp,
|
||||
'Environment variable %s must be set as a boolean indicator.\n'
|
||||
'The following are accepted as TRUE : %s.\n'
|
||||
'The following are accepted as FALSE: %s.\n'
|
||||
'Current value is %s.' % (var_name, ', '.join(true_strings),
|
||||
', '.join(false_strings), var))
|
||||
'Current value is %s.' %
|
||||
(var_name, ', '.join(true_strings), ', '.join(false_strings), var))
|
||||
|
||||
while var is None:
|
||||
user_input_origin = get_input(question)
|
||||
@ -771,11 +772,12 @@ def check_ndk_level(android_ndk_home_path):
|
||||
else:
|
||||
raise Exception('Unable to parse NDK revision.')
|
||||
if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS:
|
||||
print('WARNING: The API level of the NDK in %s is %s, which is not '
|
||||
'supported by Bazel (officially supported versions: %s). Please use '
|
||||
'another version. Compiling Android targets may result in confusing '
|
||||
'errors.\n' % (android_ndk_home_path, ndk_api_level,
|
||||
_SUPPORTED_ANDROID_NDK_VERSIONS))
|
||||
print(
|
||||
'WARNING: The API level of the NDK in %s is %s, which is not '
|
||||
'supported by Bazel (officially supported versions: %s). Please use '
|
||||
'another version. Compiling Android targets may result in confusing '
|
||||
'errors.\n' %
|
||||
(android_ndk_home_path, ndk_api_level, _SUPPORTED_ANDROID_NDK_VERSIONS))
|
||||
return ndk_api_level
|
||||
|
||||
|
||||
@ -1230,8 +1232,8 @@ def set_tf_nccl_install_path(environ_cp):
|
||||
# Reset and Retry
|
||||
print(
|
||||
'Invalid path to NCCL %s toolkit, %s or %s not found. Please use the '
|
||||
'O/S agnostic package of NCCL 2' % (tf_nccl_version, nccl_lib_path,
|
||||
nccl_hdr_path))
|
||||
'O/S agnostic package of NCCL 2' %
|
||||
(tf_nccl_version, nccl_lib_path, nccl_hdr_path))
|
||||
|
||||
environ_cp['TF_NCCL_VERSION'] = ''
|
||||
else:
|
||||
@ -1498,6 +1500,7 @@ def set_other_mpi_vars(environ_cp):
|
||||
'Cannot find the MPI library file in %s/lib or %s/lib64 or %s/lib32' %
|
||||
(mpi_home, mpi_home, mpi_home))
|
||||
|
||||
|
||||
def system_specific_test_config(env):
|
||||
"""Add default test flags required for TF tests to bazelrc."""
|
||||
write_to_bazelrc('test --flaky_test_attempts=3')
|
||||
@ -1593,11 +1596,15 @@ def configure_apple_bazel_rules():
|
||||
existing_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath + '.apple')
|
||||
renamed_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath)
|
||||
os.rename(existing_filepath, renamed_filepath)
|
||||
if _TF_CURRENT_BAZEL_VERSION is None or _TF_CURRENT_BAZEL_VERSION < 23000:
|
||||
print(
|
||||
'Building Bazel rules on Apple platforms requires Bazel 0.23 or later.')
|
||||
|
||||
|
||||
def main():
|
||||
global _TF_WORKSPACE_ROOT
|
||||
global _TF_BAZELRC
|
||||
global _TF_CURRENT_BAZEL_VERSION
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@ -1614,7 +1621,8 @@ def main():
|
||||
# environment variables.
|
||||
environ_cp = dict(os.environ)
|
||||
|
||||
check_bazel_version('0.19.0', '0.23.2')
|
||||
current_bazel_version = check_bazel_version('0.19.0', '0.23.2')
|
||||
_TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)
|
||||
|
||||
reset_tf_configure_bazelrc()
|
||||
|
||||
|
@ -50,6 +50,7 @@ WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.
|
||||
For more information, please see:
|
||||
* https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
|
||||
* https://github.com/tensorflow/addons
|
||||
* https://github.com/tensorflow/io (for I/O related ops)
|
||||
If you depend on functionality not listed there, please file an issue.
|
||||
"""
|
||||
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib',
|
||||
|
@ -88,15 +88,19 @@ Status ScaleAndTranslateGradHelper(const Scope& scope, const Operation& op,
|
||||
string kernel_type;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(op.node()->attrs(), "kernel_type", &kernel_type));
|
||||
bool antialias;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "antialias", &antialias));
|
||||
grad_outputs->push_back(internal::ScaleAndTranslateGrad(
|
||||
scope, grad_inputs[0], op.input(0), op.input(2), op.input(3),
|
||||
internal::ScaleAndTranslateGrad::KernelType(kernel_type)));
|
||||
internal::ScaleAndTranslateGrad::KernelType(kernel_type)
|
||||
.Antialias(antialias)));
|
||||
|
||||
grad_outputs->push_back(NoGradient());
|
||||
grad_outputs->push_back(NoGradient());
|
||||
grad_outputs->push_back(NoGradient());
|
||||
return scope.status();
|
||||
}
|
||||
|
||||
REGISTER_GRADIENT_OP("ScaleAndTranslate", ScaleAndTranslateGradHelper);
|
||||
|
||||
Status CropAndResizeGradHelper(const Scope& scope, const Operation& op,
|
||||
|
@ -196,29 +196,106 @@ class ScaleAndTranslateGradTest : public ::testing::Test {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MakeOp(const Tensor& x_data, const Input& y_shape, Output* x,
|
||||
Output* y) {
|
||||
void MakeOp(const Tensor& x_data, const Input& y_shape, Input scale,
|
||||
Input translation, const string& kernel_type, bool antialias,
|
||||
Output* x, Output* y) {
|
||||
*x = Const<T>(scope_, x_data);
|
||||
*y = ScaleAndTranslate(scope_, *x, y_shape, {1.8f, 2.1f}, {0.5f, 0.7f});
|
||||
*y = ScaleAndTranslate(scope_, *x, y_shape, scale, translation,
|
||||
ScaleAndTranslate::KernelType(kernel_type)
|
||||
.Antialias(antialias)
|
||||
.Antialias(antialias));
|
||||
TF_ASSERT_OK(scope_.status());
|
||||
}
|
||||
|
||||
template <typename X_T, typename Y_T, typename JAC_T>
|
||||
void TestResize() {
|
||||
TensorShape x_shape({1, 2, 3, 1});
|
||||
void TestScaleAndTranslate(const TensorShape x_shape, const int out_height,
|
||||
const int out_width, Input scale,
|
||||
Input translation, const string& kernel_type,
|
||||
bool antialias) {
|
||||
Tensor x_data = MakeData<X_T>(x_shape);
|
||||
Output x, y;
|
||||
MakeOp<X_T>(x_data, {4, 6}, &x, &y);
|
||||
MakeOp<X_T>(x_data, {out_height, out_width}, scale, translation,
|
||||
kernel_type, antialias, &x, &y);
|
||||
JAC_T max_error;
|
||||
TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, JAC_T>(
|
||||
scope_, x, x_data, y, {1, 4, 6, 1}, &max_error)));
|
||||
EXPECT_LT(max_error, 1e-3);
|
||||
scope_, x, x_data, y, {1, out_height, out_width, 1}, &max_error)));
|
||||
EXPECT_LT(max_error, 2e-3);
|
||||
}
|
||||
|
||||
const std::vector<Input> kScales = {Input{1.0f, 1.0f}, Input{0.37f, 0.47f},
|
||||
Input{2.1f, 2.1f}};
|
||||
const std::vector<Input> kTranslations = {
|
||||
Input{0.0f, 0.0f}, Input{3.14f, 1.19f}, Input{2.1f, 3.1f},
|
||||
Input{100.0f, 200.0f}};
|
||||
Scope scope_;
|
||||
};
|
||||
|
||||
TEST_F(ScaleAndTranslateGradTest, Works) { TestResize<float, float, float>(); }
|
||||
TEST_F(ScaleAndTranslateGradTest, TestGrads) {
|
||||
const std::vector<std::string> kKernelTypes = {"lanczos1", "lanczos3",
|
||||
"lanczos5", "gaussian"};
|
||||
constexpr int kOutHeight = 4;
|
||||
constexpr int kOutWidth = 6;
|
||||
|
||||
const TensorShape kXShape = TensorShape({1, 2, 3, 1});
|
||||
for (const Input scale : kScales) {
|
||||
for (const Input translation : kTranslations) {
|
||||
for (const std::string& kernel_type : kKernelTypes) {
|
||||
TestScaleAndTranslate<float, float, float>(
|
||||
kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type,
|
||||
true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ScaleAndTranslateGradTest, TestGradsWithoutAntialias) {
|
||||
constexpr int kOutHeight = 4;
|
||||
constexpr int kOutWidth = 6;
|
||||
|
||||
const TensorShape kXShape = TensorShape({1, 2, 3, 1});
|
||||
for (const Input scale : kScales) {
|
||||
for (const Input translation : kTranslations) {
|
||||
TestScaleAndTranslate<float, float, float>(kXShape, kOutHeight, kOutWidth,
|
||||
scale, translation, "lanczos3",
|
||||
false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ScaleAndTranslateGradTest, TestGradsWithSameShape) {
|
||||
const std::vector<std::string> kKernelTypes = {"lanczos3", "gaussian"};
|
||||
|
||||
constexpr int kOutHeight = 2;
|
||||
constexpr int kOutWidth = 3;
|
||||
|
||||
const TensorShape kXShape = TensorShape({1, 2, 3, 1});
|
||||
for (const Input scale : kScales) {
|
||||
for (const Input translation : kTranslations) {
|
||||
for (const std::string& kernel_type : kKernelTypes) {
|
||||
TestScaleAndTranslate<float, float, float>(
|
||||
kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type,
|
||||
true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ScaleAndTranslateGradTest, TestGradsWithSmallerShape) {
|
||||
const std::vector<std::string> kKernelTypes = {"lanczos3", "gaussian"};
|
||||
constexpr int kOutHeight = 2;
|
||||
constexpr int kOutWidth = 3;
|
||||
|
||||
const TensorShape kXShape = TensorShape({1, 4, 6, 1});
|
||||
for (const Input scale : kScales) {
|
||||
for (const Input translation : kTranslations) {
|
||||
for (const std::string& kernel_type : kKernelTypes) {
|
||||
TestScaleAndTranslate<float, float, float>(
|
||||
kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type,
|
||||
true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class CropAndResizeGradTest : public ::testing::Test {
|
||||
protected:
|
||||
@ -237,9 +314,9 @@ class CropAndResizeGradTest : public ::testing::Test {
|
||||
|
||||
template <typename T>
|
||||
void MakeOp(const Tensor& x_data, const Input& boxes, const Input& box_ind,
|
||||
const Input& crop_szie, Output* x, Output* y) {
|
||||
const Input& crop_size, Output* x, Output* y) {
|
||||
*x = Const<T>(scope_, x_data);
|
||||
*y = CropAndResize(scope_, *x, boxes, box_ind, crop_szie,
|
||||
*y = CropAndResize(scope_, *x, boxes, box_ind, crop_size,
|
||||
CropAndResize::Method("bilinear"));
|
||||
TF_ASSERT_OK(scope_.status());
|
||||
}
|
||||
|
@ -101,6 +101,15 @@ class XlaAssignVariableOp : public OpKernel {
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Identity").Device(DEVICE).TypeConstraint("T", DT_STRING), \
|
||||
IdentityOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Identity").Device(DEVICE).TypeConstraint<Variant>("T"), \
|
||||
IdentityOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("Identity") \
|
||||
.Device(DEVICE) \
|
||||
.TypeConstraint<ResourceHandle>("T") \
|
||||
.HostMemory("input") \
|
||||
.HostMemory("output"), \
|
||||
IdentityOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE), IdentityNOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), PlaceholderOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \
|
||||
@ -199,9 +208,7 @@ class XlaAssignVariableOp : public OpKernel {
|
||||
Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name(kArgOp).Device(DEVICE).HostMemory("output").TypeConstraint("T", \
|
||||
TYPES), \
|
||||
ArgOp); \
|
||||
Name(kArgOp).Device(DEVICE).TypeConstraint("T", TYPES), ArgOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name(kArgOp) \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("output") \
|
||||
@ -210,11 +217,8 @@ class XlaAssignVariableOp : public OpKernel {
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name(kArgOp).Device(DEVICE).TypeConstraint<Variant>("T"), ArgOp); \
|
||||
\
|
||||
REGISTER_KERNEL_BUILDER(Name(kRetOp) \
|
||||
.Device(DEVICE) \
|
||||
.TypeConstraint("T", TYPES) \
|
||||
.HostMemory("input"), \
|
||||
RetvalOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name(kRetOp).Device(DEVICE).TypeConstraint("T", TYPES), RetvalOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name(kRetOp) \
|
||||
.Device(DEVICE) \
|
||||
.TypeConstraint<ResourceHandle>("T") \
|
||||
|
@ -482,7 +482,7 @@ tf_xla_py_test(
|
||||
name = "fft_test",
|
||||
size = "medium",
|
||||
srcs = ["fft_test.py"],
|
||||
shard_count = 3,
|
||||
shard_count = 6,
|
||||
tags = ["optonly"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
|
@ -79,6 +79,24 @@ static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
|
||||
XLA_MAKE_BINARY(DivNoNan,
|
||||
DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper));
|
||||
|
||||
// Implementation of MulNoNan. Pseudo-code:
|
||||
// if (y == 0) {
|
||||
// return 0
|
||||
// } else {
|
||||
// return x * y;
|
||||
// }
|
||||
static xla::XlaOp MulNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
|
||||
xla::XlaOp y, const BCast& broadcast_helper) {
|
||||
std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
|
||||
auto zero = XlaHelpers::Zero(b, dtype);
|
||||
auto y_equals_0 = xla::Eq(y, zero);
|
||||
auto zeros = xla::ZerosLike(x);
|
||||
auto result = xla::Select(y_equals_0, zeros, xla::Mul(x, y));
|
||||
return result;
|
||||
}
|
||||
XLA_MAKE_BINARY(MulNoNan,
|
||||
MulNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper));
|
||||
|
||||
// Implementation of FloorDiv.
|
||||
//
|
||||
// For floating-point values, simply returns floor(x / y). For integers, does:
|
||||
|
@ -128,6 +128,20 @@ static void AllocateFlags() {
|
||||
flag_values->xla_cpu_enable_fast_math(),
|
||||
"Enable unsafe fast-math optimizations in the CPU compiler; "
|
||||
"this may produce faster code at the expense of some accuracy."),
|
||||
tensorflow::Flag(
|
||||
"xla_cpu_fast_math_honor_nans",
|
||||
bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_nans),
|
||||
flag_values->xla_cpu_fast_math_honor_nans(),
|
||||
"When xla_cpu_enable_fast_math is true then this controls whether we "
|
||||
"allow operations to produce NaNs. Ignored when "
|
||||
"xla_cpu_enable_fast_math is false."),
|
||||
tensorflow::Flag(
|
||||
"xla_cpu_fast_math_honor_infs",
|
||||
bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_infs),
|
||||
flag_values->xla_cpu_fast_math_honor_infs(),
|
||||
"When xla_cpu_enable_fast_math is true then this controls whether we "
|
||||
"allow operations to produce infinites. Ignored when "
|
||||
"xla_cpu_enable_fast_math is false."),
|
||||
tensorflow::Flag(
|
||||
"xla_gpu_enable_fast_min_max",
|
||||
bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max),
|
||||
|
@ -1663,15 +1663,18 @@ Applies a reduction function to one or more arrays in parallel.
|
||||
|
||||
<b> `Reduce(operands..., init_values..., computation, dimensions)` </b>
|
||||
|
||||
Arguments | Type | Semantics
|
||||
------------- | --------------------- | ---------------------------------------
|
||||
`operands` | Sequence of N `XlaOp` | N arrays of types `T_0, ..., T_N`.
|
||||
`init_values` | Sequence of N `XlaOp` | N scalars of types `T_0, ..., T_N`.
|
||||
`computation` | `XlaComputation` | computation of type
|
||||
: : `T_0, ..., T_N, T_0, ..., T_N -> Collate(T_0, ..., T_N)`
|
||||
`dimensions` | `int64` array | unordered array of dimensions to reduce
|
||||
| Arguments | Type | Semantics |
|
||||
| ------------- | --------------------- | ------------------------------------ |
|
||||
| `operands` | Sequence of N `XlaOp` | N arrays of types `T_0, ..., T_N`. |
|
||||
| `init_values` | Sequence of N `XlaOp` | N scalars of types `T_0, ..., T_N`. |
|
||||
| `computation` | `XlaComputation` | computation of type `T_0, ..., T_N, |
|
||||
: : : T_0, ..., T_N ->` `Collate(T_0, ..., :
|
||||
: : : T_N)` :
|
||||
| `dimensions` | `int64` array | unordered array of dimensions to |
|
||||
: : : reduce :
|
||||
|
||||
Where:
|
||||
|
||||
* N is required to be greater or equal to 1.
|
||||
* All input arrays must have the same dimensions.
|
||||
* If `N = 1`, `Collate(T)` is `T`.
|
||||
@ -1681,10 +1684,10 @@ The output of the op is `Collate(Q_0, ..., Q_N)` where `Q_i` is an array of type
|
||||
`T_i`, the dimensions of which are described below.
|
||||
|
||||
This operation reduces one or more dimensions of each input array into scalars.
|
||||
The rank of each returned array is `rank(operand) - len(dimensions)`.
|
||||
`init_value` is the initial value used for every reduction and may be inserted
|
||||
The rank of each returned array is `rank(operand) - len(dimensions)`. The
|
||||
initial value used for every reduction is `init_value`, and it may be inserted
|
||||
anywhere during computation by the back-end. In most cases, `init_value` is an
|
||||
identity of the reduction function (for example, 0 for addition). The applied
|
||||
identity of the reduction function (for example, `0` for addition). The applied
|
||||
`computation` is always passed the `init_value` on the left-hand side.
|
||||
|
||||
The evaluation order of the reduction function is arbitrary and may be
|
||||
@ -1695,10 +1698,10 @@ Some reduction functions like addition are not strictly associative for floats.
|
||||
However, if the range of the data is limited, floating-point addition is close
|
||||
enough to being associative for most practical uses. It is possible to conceive
|
||||
of some completely non-associative reductions, however, and these will produce
|
||||
incorrect or unpredictable results in XLA reductions.
|
||||
incorrect or unpredictable results in XLA.
|
||||
|
||||
As an example, when reducing across one dimension in a single 1D array with
|
||||
values [10, 11, 12, 13], with reduction function `f` (this is `computation`)
|
||||
values `[10, 11, 12, 13]`, with reduction function `f` (this is `computation`)
|
||||
then that could be computed as
|
||||
|
||||
`f(10, f(11, f(12, f(init_value, 13)))`
|
||||
@ -1777,16 +1780,27 @@ preserved in the output, but some dimensions may get assigned new numbers (since
|
||||
the rank changes).
|
||||
|
||||
We can also reduce multiple dimensions. Add-reducing dimensions 0 and 1 produces
|
||||
the 1D array `| 20 28 36 |`.
|
||||
the 1D array `[20, 28, 36]`.
|
||||
|
||||
Reducing the 3D array over all its dimensions produces the scalar `84`.
|
||||
|
||||
When `N > 1`, reduce function application is slightly more complex, as it is
|
||||
applied simultaneously to all inputs. For example, consider the following
|
||||
reduction function, which can be used to compute the max and the argmax of a a
|
||||
1-D array in parallel:
|
||||
### Variadic Reduce
|
||||
|
||||
```
|
||||
When `N > 1`, reduce function application is slightly more complex, as it is
|
||||
applied simultaneously to all inputs. The operands are supplied to the
|
||||
computation in the following order:
|
||||
|
||||
* Running reduced value for the first operand
|
||||
* ...
|
||||
* Running reduced value for the N'th operand
|
||||
* Input value for the first operand
|
||||
* ...
|
||||
* Input value for the N'th operand
|
||||
|
||||
For example, consider the following reduction function, which can be used to
|
||||
compute the max and the argmax of a 1-D array in parallel:
|
||||
|
||||
```python
|
||||
f: (Float, Int, Float, Int) -> Float, Int
|
||||
f(max, argmax, value, index):
|
||||
if value >= argmax:
|
||||
@ -1798,6 +1812,7 @@ f(max, argmax, value, index):
|
||||
For 1-D Input arrays `V = Float[N], K = Int[N]`, and init values
|
||||
`I_V = Float, I_K = Int`, the result `f_(N-1)` of reducing across the only
|
||||
input dimension is equivalent to the following recursive application:
|
||||
|
||||
```
|
||||
f_0 = f(I_V, I_K, V_0, K_0)
|
||||
f_1 = f(f_0.first, f_0.second, V_1, K_1)
|
||||
|
@ -339,15 +339,15 @@ cc_library(
|
||||
srcs = ["ir_function.cc"],
|
||||
hdrs = ["ir_function.h"],
|
||||
deps = [
|
||||
":cpu_runtime",
|
||||
":ir_emission_utils",
|
||||
":shape_partition",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla/service/cpu:cpu_runtime",
|
||||
"//tensorflow/compiler/xla/service:hlo_module_config",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm//:core",
|
||||
@ -378,6 +378,7 @@ cc_library(
|
||||
":vector_support_library",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_module_config",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -409,10 +409,20 @@ auto memory_alignment = [](LogicalBuffer::Color) { return kMemoryAlignment; };
|
||||
llvm::TargetOptions CompilerTargetOptions(
|
||||
const HloModuleConfig& module_config) {
|
||||
llvm::TargetOptions target_options;
|
||||
llvm_ir::SetTargetOptions(
|
||||
/*fast_math_enabled=*/module_config.debug_options()
|
||||
.xla_cpu_enable_fast_math(),
|
||||
&target_options);
|
||||
// In LLVM backend flags, UnsafeFPMath does not explicitly imply NoInfs, etc.
|
||||
if (module_config.debug_options().xla_cpu_enable_fast_math()) {
|
||||
target_options.UnsafeFPMath = true;
|
||||
target_options.NoInfsFPMath =
|
||||
module_config.debug_options().xla_cpu_fast_math_honor_infs();
|
||||
target_options.NoNaNsFPMath =
|
||||
module_config.debug_options().xla_cpu_fast_math_honor_nans();
|
||||
target_options.NoSignedZerosFPMath = true;
|
||||
} else {
|
||||
target_options.UnsafeFPMath = false;
|
||||
target_options.NoInfsFPMath = false;
|
||||
target_options.NoNaNsFPMath = false;
|
||||
target_options.NoSignedZerosFPMath = false;
|
||||
}
|
||||
return target_options;
|
||||
}
|
||||
|
||||
|
@ -250,11 +250,6 @@ void DotOpEmitter::EmitTiledLlvmIrGemm() {
|
||||
std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) =
|
||||
GetGemmTileSize();
|
||||
|
||||
const bool enable_fast_math =
|
||||
hlo_module_config_.debug_options().xla_cpu_enable_fast_math();
|
||||
const bool optimize_for_size =
|
||||
options::OptimizeForSizeRequested(hlo_module_config_);
|
||||
|
||||
EmitSmallGemm(
|
||||
/*scalar_type=*/primitive_type,
|
||||
/*m=*/m, /*k=*/k, /*n=*/n,
|
||||
@ -262,9 +257,7 @@ void DotOpEmitter::EmitTiledLlvmIrGemm() {
|
||||
/*max_vector_count=*/tile_size_n_in_vector_width,
|
||||
/*min_vectorization_width=*/std::min<int64>(4, max_target_vector_width),
|
||||
/*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k, /*lhs=*/lhs,
|
||||
/*rhs=*/rhs, /*result=*/target, b_,
|
||||
/*enable_fast_math=*/enable_fast_math,
|
||||
/*optimize_for_size=*/optimize_for_size);
|
||||
/*rhs=*/rhs, /*result=*/target, b_, hlo_module_config_);
|
||||
}
|
||||
|
||||
void DotOpEmitter::EmitTiledLlvmIrGemv() {
|
||||
@ -323,11 +316,6 @@ void DotOpEmitter::EmitTiledLlvmIrGemv() {
|
||||
llvm::Value* rhs_op =
|
||||
swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();
|
||||
|
||||
const bool enable_fast_math =
|
||||
hlo_module_config_.debug_options().xla_cpu_enable_fast_math();
|
||||
const bool optimize_for_size =
|
||||
options::OptimizeForSizeRequested(hlo_module_config_);
|
||||
|
||||
const int target_vector_register_element_size =
|
||||
target_machine_features_.vector_register_num_elements(
|
||||
*b_->GetInsertBlock()->getParent(), primitive_type);
|
||||
@ -349,9 +337,7 @@ void DotOpEmitter::EmitTiledLlvmIrGemv() {
|
||||
/*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor,
|
||||
/*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
|
||||
/*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
|
||||
/*result=*/result_op, b_,
|
||||
/*enable_fast_math=*/enable_fast_math,
|
||||
/*optimize_for_size=*/optimize_for_size);
|
||||
/*result=*/result_op, b_, hlo_module_config_);
|
||||
} else {
|
||||
VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
|
||||
<< " and k = " << k;
|
||||
@ -361,9 +347,7 @@ void DotOpEmitter::EmitTiledLlvmIrGemv() {
|
||||
/*tile_cols=*/vector_register_element_size,
|
||||
/*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
|
||||
/*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
|
||||
/*result=*/result_op, b_,
|
||||
/*enable_fast_math=*/enable_fast_math,
|
||||
/*optimize_for_size=*/optimize_for_size);
|
||||
/*result=*/result_op, b_, hlo_module_config_);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -98,9 +98,7 @@ IrEmitter::IrEmitter(
|
||||
is_top_level_computation_(false),
|
||||
target_machine_features_(*target_machine_features),
|
||||
emit_code_for_msan_(emit_code_for_msan) {
|
||||
b_.setFastMathFlags(llvm_ir::GetFastMathFlags(
|
||||
/*fast_math_enabled=*/hlo_module_config_.debug_options()
|
||||
.xla_cpu_enable_fast_math()));
|
||||
b_.setFastMathFlags(llvm_ir::GetCpuFastMathFlags(hlo_module_config_));
|
||||
Status s = GatherComputationsByAllocationType(
|
||||
&hlo_module, &thread_local_computations_, &global_computations_);
|
||||
absl::c_sort(thread_local_computations_);
|
||||
@ -159,11 +157,9 @@ void IrEmitter::InitializeIrFunction(const string& function_name) {
|
||||
is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage
|
||||
: llvm::GlobalValue::InternalLinkage;
|
||||
// Create and initialize new IrFunction.
|
||||
compute_function_.reset(new IrFunction(
|
||||
function_name, linkage,
|
||||
options::OptimizeForSizeRequested(hlo_module_config_),
|
||||
hlo_module_config_.debug_options().xla_cpu_enable_fast_math(), module_,
|
||||
&b_, num_dynamic_loop_bounds_));
|
||||
compute_function_.reset(new IrFunction(function_name, linkage,
|
||||
hlo_module_config_, module_, &b_,
|
||||
num_dynamic_loop_bounds_));
|
||||
}
|
||||
|
||||
IrEmitter::~IrEmitter() {}
|
||||
@ -302,7 +298,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
|
||||
const Shape& shape = get_tuple_element->shape();
|
||||
emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
|
||||
shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
|
||||
GetEmittedValueFor(operand), &b_, module_);
|
||||
GetEmittedValueFor(operand), &b_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -322,7 +318,7 @@ Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
|
||||
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple_select));
|
||||
llvm_ir::EmitTupleSelect(GetIrArrayFor(tuple_select), GetIrArrayFor(pred),
|
||||
GetEmittedValueFor(on_true),
|
||||
GetEmittedValueFor(on_false), &b_, module_);
|
||||
GetEmittedValueFor(on_false), &b_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -345,8 +341,7 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
|
||||
assignment_.GetUniqueSlice(infeed, {1}));
|
||||
llvm::Value* token_address = EmitBufferPointer(
|
||||
token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1));
|
||||
llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_,
|
||||
module_);
|
||||
llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_);
|
||||
|
||||
if (data_shape.IsTuple()) {
|
||||
TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape));
|
||||
@ -377,7 +372,7 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
|
||||
}
|
||||
|
||||
llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_shape),
|
||||
tuple_element_addresses, &b_, module_);
|
||||
tuple_element_addresses, &b_);
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address));
|
||||
@ -498,7 +493,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
|
||||
ShapeUtil::GetTupleElementShape(operand_shape, i);
|
||||
llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement(
|
||||
tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape),
|
||||
value, &b_, module_);
|
||||
value, &b_);
|
||||
TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed,
|
||||
tuple_element_shape, tuple_element));
|
||||
}
|
||||
@ -621,8 +616,7 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) {
|
||||
GetProfileCountersArgument(), less_than_function});
|
||||
|
||||
if (sort->values_count() > 0) {
|
||||
llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_,
|
||||
module_);
|
||||
llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -633,7 +627,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
|
||||
for (auto operand : tuple->operands()) {
|
||||
base_ptrs.push_back(GetEmittedValueFor(operand));
|
||||
}
|
||||
llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &b_, module_);
|
||||
llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &b_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1349,7 +1343,7 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
|
||||
MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr,
|
||||
/*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape));
|
||||
}
|
||||
llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_, module_);
|
||||
llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -2289,7 +2283,7 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
|
||||
llvm::Value* addr = EmitBufferPointer(slice, elem_shape);
|
||||
base_ptrs.push_back(addr);
|
||||
}
|
||||
llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_, module_);
|
||||
llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_);
|
||||
}
|
||||
auto* output_address_arg =
|
||||
PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
|
||||
@ -2980,7 +2974,7 @@ Status IrEmitter::EmitTargetElementLoop(
|
||||
for (int64 i = 0; i < output_arrays.size(); ++i) {
|
||||
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
|
||||
}
|
||||
llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &b_, module_);
|
||||
llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &b_);
|
||||
|
||||
} else {
|
||||
if (ShouldEmitParallelLoopFor(*target_op)) {
|
||||
|
@ -43,15 +43,14 @@ static std::vector<llvm::Type*> GetComputeFunctionParams(
|
||||
|
||||
IrFunction::IrFunction(const string& function_name,
|
||||
llvm::Function::LinkageTypes linkage,
|
||||
const bool optimize_for_size_requested,
|
||||
const bool enable_fast_math, llvm::Module* llvm_module,
|
||||
llvm::IRBuilder<>* b, int64 num_dynamic_loop_bounds)
|
||||
const HloModuleConfig& module_config,
|
||||
llvm::Module* llvm_module, llvm::IRBuilder<>* b,
|
||||
int64 num_dynamic_loop_bounds)
|
||||
: b_(b),
|
||||
llvm_module_(llvm_module),
|
||||
caller_insert_point_guard_(*b),
|
||||
num_dynamic_loop_bounds_(num_dynamic_loop_bounds) {
|
||||
Initialize(function_name, linkage, optimize_for_size_requested,
|
||||
enable_fast_math);
|
||||
Initialize(function_name, linkage, module_config);
|
||||
}
|
||||
|
||||
IrFunction::~IrFunction() {
|
||||
@ -70,8 +69,7 @@ DynamicLoopBounds IrFunction::GetDynamicLoopBounds() {
|
||||
|
||||
void IrFunction::Initialize(const string& function_name,
|
||||
llvm::Function::LinkageTypes linkage,
|
||||
const bool optimize_for_size_requested,
|
||||
const bool enable_fast_math) {
|
||||
const HloModuleConfig& module_config) {
|
||||
// The function signature is:
|
||||
// void function(i8* retval, i8* run_options, i8** params, i8**
|
||||
// buffer_table,
|
||||
@ -142,11 +140,8 @@ void IrFunction::Initialize(const string& function_name,
|
||||
// Functions with local linkage get an inlining bonus. Because we know
|
||||
// a-priori that embedded functions (non-entry functions) will not have its
|
||||
// name resolved, give it local linkage.
|
||||
function_ =
|
||||
llvm_ir::CreateFunction(function_type, linkage,
|
||||
/*enable_fast_math=*/enable_fast_math,
|
||||
/*optimize_for_size=*/optimize_for_size_requested,
|
||||
function_name, llvm_module_);
|
||||
function_ = llvm_ir::CreateCpuFunction(function_type, linkage, module_config,
|
||||
function_name, llvm_module_);
|
||||
|
||||
// Set meaningful names for the function's arguments: useful for debugging.
|
||||
llvm::Function::arg_iterator arg_iter = function_->arg_begin();
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/Value.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
@ -52,8 +53,7 @@ namespace cpu {
|
||||
class IrFunction {
|
||||
public:
|
||||
IrFunction(const string& function_name, llvm::Function::LinkageTypes linkage,
|
||||
const bool optimize_for_size_requested,
|
||||
const bool enable_fast_math, llvm::Module* llvm_module,
|
||||
const HloModuleConfig& module_config, llvm::Module* llvm_module,
|
||||
llvm::IRBuilder<>* b, int64 num_dynamic_loop_bounds);
|
||||
~IrFunction();
|
||||
|
||||
@ -92,7 +92,7 @@ class IrFunction {
|
||||
// Initialize an llvm::Function with standard signature based on arguments.
|
||||
void Initialize(const string& function_name,
|
||||
llvm::Function::LinkageTypes linkage,
|
||||
bool optimize_for_size_requested, bool enable_fast_math);
|
||||
const HloModuleConfig& module_config);
|
||||
|
||||
// Emit ir to read and return the ir value for the dynamic loop bound at
|
||||
// 'offset' from the "dynamic_loop_bounds" argument of this function.
|
||||
|
@ -991,7 +991,7 @@ void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
|
||||
int64 tile_cols, int64 m, int64 k, llvm::Value* lhs,
|
||||
llvm::Value* rhs, llvm::Value* addend,
|
||||
llvm::Value* result, llvm::IRBuilder<>* b,
|
||||
bool enable_fast_math, bool optimize_for_size) {
|
||||
const HloModuleConfig& module_config) {
|
||||
RowMajorMatrixVectorProductEmitter::Config config(
|
||||
/*scalar_type=*/scalar_type,
|
||||
/*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols,
|
||||
@ -1001,8 +1001,7 @@ void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
|
||||
GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b);
|
||||
|
||||
KernelSupportLibrary::EmitAndCallOutlinedKernel(
|
||||
/*enable_fast_math=*/enable_fast_math,
|
||||
/*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(),
|
||||
module_config, b, config.GetCacheKey(),
|
||||
canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized,
|
||||
canonical_inputs.addend_canonicalized,
|
||||
canonical_inputs.result_canonicalized,
|
||||
@ -1019,7 +1018,7 @@ void EmitColumnMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
|
||||
int64 tile_cols, int64 m, int64 k, llvm::Value* lhs,
|
||||
llvm::Value* rhs, llvm::Value* addend,
|
||||
llvm::Value* result, llvm::IRBuilder<>* b,
|
||||
bool enable_fast_math, bool optimize_for_size) {
|
||||
const HloModuleConfig& module_config) {
|
||||
ColumnMajorMatrixVectorProductEmitter::Config config(
|
||||
/*scalar_type=*/scalar_type,
|
||||
/*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols,
|
||||
@ -1029,8 +1028,7 @@ void EmitColumnMajorGemv(PrimitiveType scalar_type, int64 tile_rows,
|
||||
GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b);
|
||||
|
||||
KernelSupportLibrary::EmitAndCallOutlinedKernel(
|
||||
/*enable_fast_math=*/enable_fast_math,
|
||||
/*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(),
|
||||
module_config, b, config.GetCacheKey(),
|
||||
canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized,
|
||||
canonical_inputs.addend_canonicalized,
|
||||
canonical_inputs.result_canonicalized,
|
||||
@ -1048,7 +1046,7 @@ void EmitSmallGemm(PrimitiveType scalar_type, int64 m, int64 k, int64 n,
|
||||
int64 min_vectorization_width, int64 tile_size_m,
|
||||
int64 tile_size_k, llvm::Value* lhs, llvm::Value* rhs,
|
||||
llvm::Value* result, llvm::IRBuilder<>* b,
|
||||
bool enable_fast_math, bool optimize_for_size) {
|
||||
const HloModuleConfig& module_config) {
|
||||
TiledSmallGemmEmitter::Config config(
|
||||
/*scalar_type=*/scalar_type,
|
||||
TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
|
||||
@ -1058,9 +1056,7 @@ void EmitSmallGemm(PrimitiveType scalar_type, int64 m, int64 k, int64 n,
|
||||
/*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k);
|
||||
|
||||
KernelSupportLibrary::EmitAndCallOutlinedKernel(
|
||||
/*enable_fast_math=*/enable_fast_math,
|
||||
/*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs,
|
||||
rhs, result,
|
||||
module_config, b, config.GetCacheKey(), lhs, rhs, result,
|
||||
[&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result) {
|
||||
TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs,
|
||||
/*rhs=*/rhs,
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TILED_DOT_EMITTER_H_
|
||||
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
@ -29,15 +30,15 @@ void EmitRowMajorGemv(PrimitiveType scalar_type, tensorflow::int64 tile_rows,
|
||||
tensorflow::int64 tile_cols, tensorflow::int64 m,
|
||||
tensorflow::int64 k, llvm::Value* lhs, llvm::Value* rhs,
|
||||
llvm::Value* addend, llvm::Value* result,
|
||||
llvm::IRBuilder<>* b, bool enable_fast_math,
|
||||
bool optimize_for_size);
|
||||
llvm::IRBuilder<>* b,
|
||||
const HloModuleConfig& module_config);
|
||||
|
||||
void EmitColumnMajorGemv(PrimitiveType scalar_type, tensorflow::int64 tile_rows,
|
||||
tensorflow::int64 tile_cols, tensorflow::int64 m,
|
||||
tensorflow::int64 k, llvm::Value* lhs,
|
||||
llvm::Value* rhs, llvm::Value* addend,
|
||||
llvm::Value* result, llvm::IRBuilder<>* b,
|
||||
bool enable_fast_math, bool optimize_for_size);
|
||||
const HloModuleConfig& module_config);
|
||||
|
||||
void EmitSmallGemm(PrimitiveType scalar_type, tensorflow::int64 m,
|
||||
tensorflow::int64 k, tensorflow::int64 n,
|
||||
@ -46,8 +47,7 @@ void EmitSmallGemm(PrimitiveType scalar_type, tensorflow::int64 m,
|
||||
tensorflow::int64 min_vectorization_width,
|
||||
tensorflow::int64 tile_size_m, tensorflow::int64 tile_size_k,
|
||||
llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result,
|
||||
llvm::IRBuilder<>* b, bool enable_fast_math,
|
||||
bool optimize_for_size);
|
||||
llvm::IRBuilder<>* b, const HloModuleConfig& module_config);
|
||||
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
||||
|
@ -2407,8 +2407,9 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
||||
case HloOpcode::kCopy:
|
||||
return [hlo, &operand_to_generator](
|
||||
const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
|
||||
IrArray::Index source_index = target_index;
|
||||
source_index.ClearLinearIndex();
|
||||
IrArray::Index source_index(target_index.multidim(),
|
||||
hlo->operand(0)->shape(),
|
||||
target_index.GetType());
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
llvm::Value * operand_value,
|
||||
operand_to_generator.at(hlo->operand(0))(source_index));
|
||||
|
@ -135,11 +135,11 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
|
||||
if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
|
||||
return llvm_ir::EmitGetTupleElement(
|
||||
gte->shape(), gte->tuple_index(), /*alignment=*/1,
|
||||
GetTypedIrValue(*gte->operand(0), {}, base_ptr), b_, module_);
|
||||
GetTypedIrValue(*gte->operand(0), {}, base_ptr), b_);
|
||||
}
|
||||
return llvm_ir::EmitGetTupleElement(
|
||||
gte->shape(), gte->tuple_index(), /*alignment=*/1,
|
||||
EmitGetTupleElement(gte->operand(0), base_ptr), b_, module_);
|
||||
EmitGetTupleElement(gte->operand(0), base_ptr), b_);
|
||||
}
|
||||
|
||||
// Returns true if `value` has a name that should not be changed.
|
||||
|
@ -115,7 +115,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
|
||||
get_tuple_element->shape(), get_tuple_element->tuple_index(),
|
||||
// TODO(b/26344050): tighten the alignment here
|
||||
// based on the real element type.
|
||||
/*alignment=*/1, GetBasePointer(*operand), &b_, module_));
|
||||
/*alignment=*/1, GetBasePointer(*operand), &b_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -144,7 +144,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
|
||||
for (const HloInstruction* operand : tuple->operands()) {
|
||||
base_ptrs.push_back(GetBasePointer(*operand));
|
||||
}
|
||||
llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_, module_);
|
||||
llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -434,7 +434,7 @@ Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
|
||||
llvm_ir::EmitTupleSelect(GetIrArray(*tuple_select, *tuple_select),
|
||||
GetIrArray(*pred, *tuple_select),
|
||||
GetBasePointer(*on_true), GetBasePointer(*on_false),
|
||||
&b_, module_);
|
||||
&b_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -123,7 +123,7 @@ Status IrEmitterNested::EmitTargetElementLoop(
|
||||
ConstructIrArrayForOutputs(hlo);
|
||||
TF_RETURN_IF_ERROR(
|
||||
llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop());
|
||||
llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_, module_);
|
||||
llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_);
|
||||
return Status::OK();
|
||||
}
|
||||
return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_)
|
||||
|
@ -2201,7 +2201,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
|
||||
// kernel *anyway*.
|
||||
std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(hlo);
|
||||
KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] {
|
||||
llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_);
|
||||
llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_);
|
||||
});
|
||||
|
||||
// For multioutput fusion, we need to emit each operand and the root.
|
||||
@ -2916,13 +2916,6 @@ void IrEmitterUnnested::EmitTileElementForReduction(
|
||||
reduction_info->GetKernelMappingScheme()->GetUnnormalizedIndex(
|
||||
index,
|
||||
GetFirstReduceInstruction(output_instructions)->operand(0)->shape());
|
||||
int num_partial_results = reduction_info->GetNumberOfPartialResults();
|
||||
if (num_partial_results > 1) {
|
||||
// Clear the linear index field of the IrArray::Index to enable the use of
|
||||
// GetElementPointer with array types. This enables the vectorization of
|
||||
// the computation for different partial results.
|
||||
input_index.ClearLinearIndex();
|
||||
}
|
||||
absl::Span<llvm::AllocaInst* const> partial_reduction_result_addresses =
|
||||
reduction_info->GetPartialResultAddresses();
|
||||
absl::Span<llvm::AllocaInst* const> reduction_input_addresses =
|
||||
@ -3103,8 +3096,7 @@ LaunchDimensions IrEmitterUnnested::EmitKernel(
|
||||
if (!reduction_info && unnested_hlo->IsMultiOutputFusion()) {
|
||||
KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] {
|
||||
llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo),
|
||||
ConstructIrArrayForOutputs(*unnested_hlo), &b_,
|
||||
module_);
|
||||
ConstructIrArrayForOutputs(*unnested_hlo), &b_);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -207,6 +207,39 @@ std::vector<HloInstruction*> HloModuleGroupUtil::RootInstructions(
|
||||
return roots;
|
||||
}
|
||||
|
||||
string HloModuleGroupUtil::CycleToString(HloInstruction* init_instruction) {
|
||||
std::vector<string> names;
|
||||
absl::flat_hash_set<HloInstruction*> seen;
|
||||
|
||||
std::function<bool(HloInstruction*)> helper =
|
||||
[&](HloInstruction* instruction) {
|
||||
if (seen.find(instruction) != seen.end()) {
|
||||
if (instruction == init_instruction) {
|
||||
names.push_back(instruction->name());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
seen.insert(instruction);
|
||||
for (HloInstruction* predecessor : GlobalPredecessors(instruction)) {
|
||||
bool init_found = helper(predecessor);
|
||||
if (init_found) {
|
||||
names.push_back(instruction->name());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
helper(init_instruction);
|
||||
std::vector<string> pieces;
|
||||
pieces.reserve(names.size());
|
||||
for (auto name : names) {
|
||||
pieces.push_back(name);
|
||||
}
|
||||
return absl::StrJoin(pieces, " --> ");
|
||||
}
|
||||
|
||||
Status HloModuleGroupUtil::VisitTopologicalOrder(
|
||||
VisitStates* visit_state, const VisitFunction& visit_function,
|
||||
HloInstruction* root) {
|
||||
@ -269,22 +302,9 @@ Status HloModuleGroupUtil::VisitTopologicalOrder(
|
||||
// a cycle. Generate an error with the list of instructions in the
|
||||
// cycle.
|
||||
if ((*visit_state)[predecessor] == VisitState::kVisiting) {
|
||||
string cyclic_instructions;
|
||||
for (const auto& state : *visit_state) {
|
||||
if (state.second == VisitState::kVisiting) {
|
||||
absl::StrAppend(&cyclic_instructions, state.first->ToString(),
|
||||
"\n");
|
||||
}
|
||||
}
|
||||
// TODO(b/64305524): Improve the error message to print out the
|
||||
// instructions in a deterministic order that forms the cycle.
|
||||
return FailedPrecondition(
|
||||
"Cross-computation cycle detected via communicating nodes. The "
|
||||
"cycle contains the node %s. The cycle is found among the "
|
||||
"following nodes. Note that the order of the nodes is arbitrary "
|
||||
"and that the list may include nodes that are not part of the "
|
||||
"cycle.\n%s",
|
||||
predecessor->ToString(), cyclic_instructions);
|
||||
"Cross-computation cycle detected via communicating nodes.\n%s",
|
||||
CycleToString(predecessor));
|
||||
}
|
||||
stack.push(predecessor);
|
||||
}
|
||||
|
@ -108,6 +108,8 @@ class HloModuleGroupUtil {
|
||||
HloInstruction* instruction, HloReachabilityMap* reachability_map);
|
||||
|
||||
private:
|
||||
string CycleToString(HloInstruction* instruction);
|
||||
|
||||
const HloModuleGroupMetadata& metadata_;
|
||||
};
|
||||
|
||||
|
@ -71,6 +71,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_module_config",
|
||||
"//tensorflow/compiler/xla/service:name_uniquer",
|
||||
"//tensorflow/compiler/xla/service/cpu:cpu_options",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/strings",
|
||||
@ -239,7 +240,7 @@ cc_library(
|
||||
hdrs = ["kernel_support_library.h"],
|
||||
deps = [
|
||||
":llvm_loop",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
|
||||
":llvm_util",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:core",
|
||||
],
|
||||
|
@ -121,9 +121,9 @@ Status FusedIrEmitter::HandleGetTupleElement(
|
||||
}
|
||||
|
||||
// Lookup tuple element pointer.
|
||||
return llvm_ir::EmitGetTupleElement(
|
||||
get_tuple_element->shape(), get_tuple_element->tuple_index(),
|
||||
/*alignment=*/1, tuple_ptr, b_, module_);
|
||||
return llvm_ir::EmitGetTupleElement(get_tuple_element->shape(),
|
||||
get_tuple_element->tuple_index(),
|
||||
/*alignment=*/1, tuple_ptr, b_);
|
||||
};
|
||||
|
||||
if (!get_tuple_element->shape().IsTuple()) {
|
||||
|
@ -331,7 +331,8 @@ llvm::Value* IrArray::Index::Linearize(absl::Span<const int64> dimensions,
|
||||
|
||||
llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
|
||||
llvm::IRBuilder<>* b,
|
||||
absl::string_view name) const {
|
||||
absl::string_view name,
|
||||
bool use_linear_index) const {
|
||||
if (ShapeUtil::IsScalar(shape_)) {
|
||||
// Special handling of scalars: a scalar pretends to have the same value for
|
||||
// every index, thus effectively implementing broadcasting of its value
|
||||
@ -340,7 +341,7 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
|
||||
}
|
||||
CHECK_EQ(index.size(), shape_.rank());
|
||||
|
||||
if (index.LinearValidOnShape(shape_)) {
|
||||
if (use_linear_index && index.LinearValidOnShape(shape_)) {
|
||||
llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
|
||||
return b->CreateInBoundsGEP(
|
||||
b->CreateBitCast(base_ptr_,
|
||||
@ -389,16 +390,20 @@ void IrArray::AnnotateLoadStoreInstructionWithMetadata(
|
||||
|
||||
llvm::Value* IrArray::EmitReadArrayElement(const Index& index,
|
||||
llvm::IRBuilder<>* b,
|
||||
absl::string_view name) const {
|
||||
llvm::Value* element_address = EmitArrayElementAddress(index, b, name);
|
||||
absl::string_view name,
|
||||
bool use_linear_index) const {
|
||||
llvm::Value* element_address =
|
||||
EmitArrayElementAddress(index, b, name, use_linear_index);
|
||||
llvm::LoadInst* load = b->CreateLoad(element_address);
|
||||
AnnotateLoadStoreInstructionWithMetadata(load);
|
||||
return load;
|
||||
}
|
||||
|
||||
void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value,
|
||||
llvm::IRBuilder<>* b) const {
|
||||
llvm::Value* element_address = EmitArrayElementAddress(index, b);
|
||||
llvm::IRBuilder<>* b,
|
||||
bool use_linear_index) const {
|
||||
llvm::Value* element_address =
|
||||
EmitArrayElementAddress(index, b, "", use_linear_index);
|
||||
llvm::StoreInst* store = b->CreateStore(value, element_address);
|
||||
AnnotateLoadStoreInstructionWithMetadata(store);
|
||||
}
|
||||
|
@ -168,8 +168,6 @@ class IrArray {
|
||||
return llvm::ConstantInt::get(index_type_, c);
|
||||
}
|
||||
|
||||
void ClearLinearIndex() { linear_ = nullptr; }
|
||||
|
||||
private:
|
||||
// Constructs an index from both a multi-dimensional index and a linear
|
||||
// index. 'shape' is the shape on which the index is used. 'index_type' is
|
||||
@ -227,7 +225,8 @@ class IrArray {
|
||||
// The optional name is useful for debugging when looking at
|
||||
// the emitted LLVM IR.
|
||||
llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b,
|
||||
absl::string_view name = "") const;
|
||||
absl::string_view name = "",
|
||||
bool use_linear_index = true) const;
|
||||
|
||||
// Attach metadata this IrArray instance knows about to "instruction".
|
||||
void AnnotateLoadStoreInstructionWithMetadata(
|
||||
@ -240,15 +239,23 @@ class IrArray {
|
||||
//
|
||||
// The optional name is useful for debugging when looking at
|
||||
// the emitted LLVM IR.
|
||||
// 'use_linear_index' can be used to specify whether the linear index (if
|
||||
// available) or the multi-dimensional index should be used.
|
||||
llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b,
|
||||
absl::string_view name = "") const;
|
||||
absl::string_view name = "",
|
||||
bool use_linear_index = true) const;
|
||||
|
||||
// Emit IR to write the given value to the array element at the given index.
|
||||
// 'use_linear_index' can be used to specify whether the linear index (if
|
||||
// available) or the multi-dimensional index should be used.
|
||||
void EmitWriteArrayElement(const Index& index, llvm::Value* value,
|
||||
llvm::IRBuilder<>* b) const;
|
||||
llvm::IRBuilder<>* b,
|
||||
bool use_linear_index = true) const;
|
||||
|
||||
// Returns a new IrArray whose shape is "new_shape" and base pointer is a
|
||||
// bitcast of the base pointer of "this" IrArray.
|
||||
// 'use_linear_index' can be used to specify whether the linear index (if
|
||||
// available) or the multi-dimensional index should be used.
|
||||
IrArray CastToShape(const Shape& new_shape, llvm::IRBuilder<>* b) const;
|
||||
|
||||
void AddAliasScopeMetadata(llvm::MDNode* alias_scope) {
|
||||
|
@ -70,7 +70,7 @@ Status KernelSupportLibrary::IfWithStatus(
|
||||
}
|
||||
|
||||
void KernelSupportLibrary::EmitAndCallOutlinedKernel(
|
||||
bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b,
|
||||
const HloModuleConfig& module_config, llvm::IRBuilder<>* b,
|
||||
absl::string_view kernel_name,
|
||||
KernelSupportLibrary::ArgumentVector arguments,
|
||||
const std::function<void(KernelSupportLibrary::ArgumentVector)>&
|
||||
@ -101,10 +101,9 @@ void KernelSupportLibrary::EmitAndCallOutlinedKernel(
|
||||
auto* function_type =
|
||||
llvm::FunctionType::get(b->getVoidTy(), arg_types, /*isVarArg=*/false);
|
||||
|
||||
function = llvm_ir::CreateFunction(
|
||||
function_type, llvm::GlobalValue::InternalLinkage,
|
||||
/*enable_fast_math=*/enable_fast_math,
|
||||
/*optimize_for_size=*/optimize_for_size, kernel_name, module);
|
||||
function = llvm_ir::CreateCpuFunction(function_type,
|
||||
llvm::GlobalValue::InternalLinkage,
|
||||
module_config, kernel_name, module);
|
||||
|
||||
llvm::IRBuilder<>::InsertPointGuard guard(*b);
|
||||
|
||||
|
@ -263,33 +263,33 @@ class KernelSupportLibrary {
|
||||
// in a nullptr llvm::Value* in its position to `kernel_body_generator`.
|
||||
// Currently we only support at most one nullptr value in `arguments`.
|
||||
static void EmitAndCallOutlinedKernel(
|
||||
bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b,
|
||||
const HloModuleConfig& module_config, llvm::IRBuilder<>* b,
|
||||
absl::string_view kernel_name, ArgumentVector arguments,
|
||||
const std::function<void(ArgumentVector)>& kernel_body_generator);
|
||||
|
||||
// Thin wrappers around the more general EmitAndCallOutlinedKernel above.
|
||||
static void EmitAndCallOutlinedKernel(
|
||||
bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b,
|
||||
const HloModuleConfig& module_config, llvm::IRBuilder<>* b,
|
||||
absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1,
|
||||
llvm::Value* arg2,
|
||||
const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*)>&
|
||||
kernel_body_generator) {
|
||||
EmitAndCallOutlinedKernel(
|
||||
enable_fast_math, optimize_for_size, b, kernel_name, {arg0, arg1, arg2},
|
||||
[&](ArgumentVector args) {
|
||||
kernel_body_generator(args[0], args[1], args[2]);
|
||||
});
|
||||
EmitAndCallOutlinedKernel(module_config, b, kernel_name, {arg0, arg1, arg2},
|
||||
[&](ArgumentVector args) {
|
||||
kernel_body_generator(args[0], args[1],
|
||||
args[2]);
|
||||
});
|
||||
}
|
||||
|
||||
static void EmitAndCallOutlinedKernel(
|
||||
bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b,
|
||||
const HloModuleConfig& module_config, llvm::IRBuilder<>* b,
|
||||
absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1,
|
||||
llvm::Value* arg2, llvm::Value* arg3,
|
||||
const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*,
|
||||
llvm::Value*)>& kernel_body_generator) {
|
||||
EmitAndCallOutlinedKernel(
|
||||
enable_fast_math, optimize_for_size, b, kernel_name,
|
||||
{arg0, arg1, arg2, arg3}, [&](ArgumentVector args) {
|
||||
module_config, b, kernel_name, {arg0, arg1, arg2, arg3},
|
||||
[&](ArgumentVector args) {
|
||||
kernel_body_generator(args[0], args[1], args[2], args[3]);
|
||||
});
|
||||
}
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "llvm/Transforms/Utils/Cloning.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
|
||||
#include "tensorflow/compiler/xla/service/dump.h"
|
||||
#include "tensorflow/compiler/xla/service/name_uniquer.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
@ -499,24 +500,25 @@ int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) {
|
||||
return ShapeUtil::ByteSizeOf(shape, pointer_size);
|
||||
}
|
||||
|
||||
llvm::FastMathFlags GetFastMathFlags(bool fast_math_enabled) {
|
||||
llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config) {
|
||||
llvm::FastMathFlags flags;
|
||||
if (fast_math_enabled) {
|
||||
// Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros,
|
||||
// AllowReciprocal, AllowContract, and ApproxFunc.
|
||||
flags.setFast();
|
||||
if (!module_config.debug_options().xla_cpu_enable_fast_math()) {
|
||||
return flags;
|
||||
}
|
||||
return flags;
|
||||
}
|
||||
|
||||
void SetTargetOptions(bool fast_math_enabled,
|
||||
llvm::TargetOptions* target_options) {
|
||||
// In LLVM backend flags, UnsafeFPMath does not explicitly imply
|
||||
// NoInfs, etc.
|
||||
target_options->UnsafeFPMath = fast_math_enabled;
|
||||
target_options->NoInfsFPMath = fast_math_enabled;
|
||||
target_options->NoNaNsFPMath = fast_math_enabled;
|
||||
target_options->NoSignedZerosFPMath = fast_math_enabled;
|
||||
// Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, AllowReciprocal,
|
||||
// AllowContract, and ApproxFunc.
|
||||
flags.setFast();
|
||||
|
||||
if (module_config.debug_options().xla_cpu_fast_math_honor_nans()) {
|
||||
flags.setNoNaNs(false);
|
||||
}
|
||||
|
||||
if (module_config.debug_options().xla_cpu_fast_math_honor_infs()) {
|
||||
flags.setNoInfs(false);
|
||||
}
|
||||
|
||||
return flags;
|
||||
}
|
||||
|
||||
std::map<int, llvm::MDNode*> MergeMetadata(
|
||||
@ -603,10 +605,11 @@ void DumpIrIfEnabled(const HloModule& hlo_module,
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Function* CreateFunction(llvm::FunctionType* function_type,
|
||||
llvm::GlobalValue::LinkageTypes linkage,
|
||||
bool enable_fast_math, bool optimize_for_size,
|
||||
absl::string_view name, llvm::Module* module) {
|
||||
llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
|
||||
llvm::GlobalValue::LinkageTypes linkage,
|
||||
const HloModuleConfig& module_config,
|
||||
absl::string_view name,
|
||||
llvm::Module* module) {
|
||||
llvm::Function* function =
|
||||
llvm::Function::Create(function_type, linkage, AsStringRef(name), module);
|
||||
function->setCallingConv(llvm::CallingConv::C);
|
||||
@ -616,17 +619,23 @@ llvm::Function* CreateFunction(llvm::FunctionType* function_type,
|
||||
// created by the JIT compiled code.
|
||||
function->setHasUWTable();
|
||||
|
||||
if (enable_fast_math) {
|
||||
if (module_config.debug_options().xla_cpu_enable_fast_math()) {
|
||||
function->addFnAttr("unsafe-fp-math", "true");
|
||||
function->addFnAttr("no-infs-fp-math", "true");
|
||||
function->addFnAttr("no-nans-fp-math", "true");
|
||||
function->addFnAttr("no-signed-zeros-fp-math", "true");
|
||||
|
||||
if (!module_config.debug_options().xla_cpu_fast_math_honor_nans()) {
|
||||
function->addFnAttr("no-nans-fp-math", "true");
|
||||
}
|
||||
|
||||
if (!module_config.debug_options().xla_cpu_fast_math_honor_infs()) {
|
||||
function->addFnAttr("no-infs-fp-math", "true");
|
||||
}
|
||||
}
|
||||
|
||||
// Add the OptimizeForSize attribute to the function if optimizing for size. This
|
||||
// controls internal behavior of some optimization passes (e.g. loop
|
||||
// unrolling).
|
||||
if (optimize_for_size) {
|
||||
if (cpu::options::OptimizeForSizeRequested(module_config)) {
|
||||
function->addFnAttr(llvm::Attribute::OptimizeForSize);
|
||||
}
|
||||
|
||||
|
@ -263,12 +263,7 @@ int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout);
|
||||
|
||||
// Gets an llvm::FastMathFlags that reflects the settings in the given
|
||||
// module config.
|
||||
llvm::FastMathFlags GetFastMathFlags(bool fast_math_enabled);
|
||||
|
||||
// Sets values in the given TargetOptions struct according to the given
|
||||
// compilation options.
|
||||
void SetTargetOptions(bool fast_math_enabled,
|
||||
llvm::TargetOptions* target_options);
|
||||
llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config);
|
||||
|
||||
// Computes a conservative union of the metadata in "a" and "b". For
|
||||
// aliasing-related metadata, this means the result can be applied to
|
||||
@ -287,10 +282,10 @@ std::map<int, llvm::MDNode*> MergeMetadata(
|
||||
void DumpIrIfEnabled(const HloModule& hlo_module,
|
||||
const llvm::Module& llvm_module, bool optimized);
|
||||
|
||||
llvm::Function* CreateFunction(llvm::FunctionType* function_type,
|
||||
llvm::GlobalValue::LinkageTypes linkage,
|
||||
bool enable_fast_math, bool optimize_for_size,
|
||||
absl::string_view name, llvm::Module* module);
|
||||
llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
|
||||
llvm::GlobalValue::LinkageTypes linkage,
|
||||
const HloModuleConfig& module_config,
|
||||
absl::string_view name, llvm::Module* module);
|
||||
|
||||
// Extracts the xla_backend_extra_options from `config` and passes those that
|
||||
// don't start with xla_ to LLVM.
|
||||
|
@ -29,9 +29,14 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace llvm_ir {
|
||||
|
||||
static llvm::Module* getModuleFromBuilder(llvm::IRBuilder<>* b) {
|
||||
return b->GetInsertBlock()->getModule();
|
||||
}
|
||||
|
||||
void EmitTupleSelect(const IrArray& select, const IrArray& pred,
|
||||
llvm::Value* on_true, llvm::Value* on_false,
|
||||
llvm::IRBuilder<>* b, llvm::Module* module) {
|
||||
llvm::IRBuilder<>* b) {
|
||||
llvm::Module* module = getModuleFromBuilder(b);
|
||||
CHECK(ShapeUtil::IsScalar(pred.GetShape()));
|
||||
|
||||
llvm::LoadInst* pred_value =
|
||||
@ -65,7 +70,8 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred,
|
||||
}
|
||||
|
||||
void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
|
||||
llvm::IRBuilder<>* b, llvm::Module* module) {
|
||||
llvm::IRBuilder<>* b) {
|
||||
llvm::Module* module = getModuleFromBuilder(b);
|
||||
for (size_t i = 0; i < operands.size(); ++i) {
|
||||
auto* store = b->CreateStore(
|
||||
b->CreatePointerCast(operands[i], PrimitiveTypeToIrType(TUPLE, module)),
|
||||
@ -76,18 +82,19 @@ void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
|
||||
}
|
||||
|
||||
void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
|
||||
llvm::IRBuilder<>* b, llvm::Module* module) {
|
||||
llvm::IRBuilder<>* b) {
|
||||
std::vector<llvm::Value*> buffer_ptrs;
|
||||
buffer_ptrs.reserve(buffers.size());
|
||||
absl::c_transform(
|
||||
buffers, std::back_inserter(buffer_ptrs),
|
||||
[](const llvm_ir::IrArray& buffer) { return buffer.GetBasePointer(); });
|
||||
llvm_ir::EmitTuple(tuple, buffer_ptrs, b, module);
|
||||
llvm_ir::EmitTuple(tuple, buffer_ptrs, b);
|
||||
}
|
||||
|
||||
llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
|
||||
int alignment, llvm::Value* operand,
|
||||
llvm::IRBuilder<>* b, llvm::Module* module) {
|
||||
llvm::IRBuilder<>* b) {
|
||||
llvm::Module* module = getModuleFromBuilder(b);
|
||||
llvm::Value* element_ptr =
|
||||
b->CreateInBoundsGEP(operand, {b->getInt64(0), b->getInt64(index)});
|
||||
llvm::LoadInst* src_buffer = b->CreateLoad(element_ptr);
|
||||
|
@ -61,17 +61,17 @@ namespace llvm_ir {
|
||||
// output[i] = pred ? tuple_on_true[i] : tuple_on_false[i]
|
||||
void EmitTupleSelect(const IrArray& select, const IrArray& pred,
|
||||
llvm::Value* on_true, llvm::Value* on_false,
|
||||
llvm::IRBuilder<>* b, llvm::Module* module);
|
||||
llvm::IRBuilder<>* b);
|
||||
|
||||
// A tuple is an array of pointers, one for each operand. Each pointer points to
|
||||
// the output buffer of its corresponding operand.
|
||||
void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
|
||||
llvm::IRBuilder<>* b, llvm::Module* module);
|
||||
llvm::IRBuilder<>* b);
|
||||
|
||||
// Similar to EmitTuple above, except that the output buffers are provided in
|
||||
// the form of IrArray.
|
||||
void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
|
||||
llvm::IRBuilder<>* b, llvm::Module* module);
|
||||
llvm::IRBuilder<>* b);
|
||||
|
||||
// A tuple is an array of pointers, one for each operand. Each pointer points to
|
||||
// the output buffer of its corresponding operand. A GetTupleElement instruction
|
||||
@ -79,7 +79,7 @@ void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
|
||||
// Returns an llvm value representing a pointer to the tuple element buffer.
|
||||
llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
|
||||
int alignment, llvm::Value* operand,
|
||||
llvm::IRBuilder<>* b, llvm::Module* module);
|
||||
llvm::IRBuilder<>* b);
|
||||
} // namespace llvm_ir
|
||||
} // namespace xla
|
||||
|
||||
|
@ -156,10 +156,21 @@ message DebugOptions {
//
// - Reducing the precision of operations (e.g. using an approximate sin
// function, or transforming x/y into x * (1/y)).
// - Assuming that operations never produce or consume NaN or +/- Inf.
// - Assuming that operations never produce or consume NaN or +/- Inf (this
// behavior can be adjusted using xla_cpu_fast_math_honor_{nans|infs}).
// - Assuming that +0 and -0 are indistinguishable.
bool xla_cpu_enable_fast_math = 99;

// When xla_cpu_enable_fast_math is true then this controls whether we allow
// operations to produce NaNs. Ignored when xla_cpu_enable_fast_math is
// false.
bool xla_cpu_fast_math_honor_nans = 120;

// When xla_cpu_enable_fast_math is true then this controls whether we allow
// operations to produce infinities. Ignored when xla_cpu_enable_fast_math is
// false.
bool xla_cpu_fast_math_honor_infs = 121;

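To make the interaction between these three fields concrete, here is a minimal usage sketch. It assumes the DebugOptions fields above are also registered as --xla_* flags parsed from the XLA_FLAGS environment variable, which is how DebugOptions are typically surfaced; that flag registration is not part of this diff.

import os

# Hedged sketch, not part of this change: enable CPU fast math but keep NaN
# semantics, while still allowing Inf-unsafe rewrites. XLA_FLAGS must be set
# before TensorFlow/XLA is initialized in the process.
os.environ["XLA_FLAGS"] = (
    "--xla_cpu_enable_fast_math=true "
    "--xla_cpu_fast_math_honor_nans=true "
    "--xla_cpu_fast_math_honor_infs=false"
)

import tensorflow as tf  # imported only after XLA_FLAGS is set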
// When true we lower the Minimum and Maximum hlos in the GPU backend such
// that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if this
// flag is true we don't propagate NaNs through Min and Max.
|
||||
@ -250,7 +261,7 @@ message DebugOptions {
|
||||
// END flags controlling dumping HLO modules.
|
||||
//
|
||||
|
||||
// Next id: 119
|
||||
// Next id: 121
|
||||
|
||||
// Extra options to pass to the compilation backend (e.g. LLVM); specific
|
||||
// interpretation of these values is left to the backend.
|
||||
|
@ -1149,9 +1149,9 @@ class GbdtTest(test_util.TensorFlowTestCase):
|
||||
expected_leaf_1 = [-3.4480, -3.4429, 13.8490, -3.45, -3.4508]
|
||||
expected_leaf_2 = [-1.2547, -1.3145, 1.52, 2.3875, -1.3264]
|
||||
self.assertArrayNear(expected_leaf_1,
|
||||
output.trees[0].nodes[1].leaf.vector.value, 2e-3)
|
||||
output.trees[0].nodes[1].leaf.vector.value, 3e-3)
|
||||
self.assertArrayNear(expected_leaf_2,
|
||||
output.trees[0].nodes[2].leaf.vector.value, 2e-3)
|
||||
output.trees[0].nodes[2].leaf.vector.value, 3e-3)
|
||||
|
||||
def testTrainFnMulticlassDiagonalHessian(self):
|
||||
"""Tests the GBDT train for multiclass diagonal hessian."""
|
||||
|
@ -64,9 +64,9 @@ py_library(
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:summary_op_util",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/distribute:summary_op_util",
|
||||
"//tensorflow/python/estimator:estimator_py",
|
||||
],
|
||||
)
|
||||
|
@ -25,11 +25,11 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
from tensorflow.compiler.jit.ops import xla_ops
|
||||
from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.python.distribute import summary_op_util
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import summary_op_util
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import compat
|
||||
|
@ -266,7 +266,7 @@ class CollectiveAllReduceStrategyTestBase(
|
||||
target=master_target) as sess:
|
||||
with d.scope():
|
||||
train_op = d.extended.call_for_each_replica(model_fn)
|
||||
train_op = d.group(d.unwrap(train_op))
|
||||
train_op = d.group(d.experimental_local_results(train_op))
|
||||
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(train_op)
|
||||
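The distribution-strategy test hunks above and below all apply one mechanical rename: the deprecated `unwrap` call becomes `experimental_local_results`. A minimal sketch of the pattern, with `strategy` and `model_fn` as placeholder names:

# Placeholder names; shows only the rename these hunks perform.
per_replica = strategy.extended.call_for_each_replica(model_fn)

# Before this change:
#   local_values = strategy.unwrap(per_replica)
# After this change:
local_values = strategy.experimental_local_results(per_replica)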
@ -294,7 +294,7 @@ class CollectiveAllReduceStrategyTestBase(
|
||||
|
||||
x = distribution.extended.call_for_each_replica(model_fn)
|
||||
reduced_x = distribution.reduce(reduce_util.ReduceOp.MEAN, x)
|
||||
x = distribution.unwrap(x)[0]
|
||||
x = distribution.experimental_local_results(x)[0]
|
||||
|
||||
sess.run(variables.global_variables_initializer())
|
||||
|
||||
|
@ -749,7 +749,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
|
||||
|
||||
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
|
||||
callbacks=[keras.callbacks.LearningRateScheduler(schedule)])
|
||||
grouped_models = distribution.unwrap(
|
||||
grouped_models = distribution.experimental_local_results(
|
||||
distributed_training_utils.get_distributed_model(
|
||||
model, ModeKeys.TRAIN))
|
||||
with distribution.scope():
|
||||
|
@ -220,7 +220,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def step_fn(ctx, inputs):
|
||||
del ctx # Unused
|
||||
fetches = distribution.unwrap(
|
||||
fetches = distribution.experimental_local_results(
|
||||
distribution.extended.call_for_each_replica(
|
||||
model_fn, args=(inputs,)))
|
||||
if update_ops_in_cross_replica_mode:
|
||||
@ -419,8 +419,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
# values that are of the same structure as non reduced losses. In
|
||||
# MirroredStrategy, this will be a list of losses, in TPUStrategy
|
||||
# it will be single tensor. Using `call_for_each_replica` followed
|
||||
# by `unwrap` gives us the desired initial value structure.
|
||||
not_reduced = distribution.unwrap(
|
||||
# by `experimental_local_results` gives us the desired initial
|
||||
# value structure.
|
||||
not_reduced = distribution.experimental_local_results(
|
||||
distribution.extended.call_for_each_replica(initial_loss))
|
||||
initial_loop_values = {
|
||||
"replica_loss_reduced": initial_loss(),
|
||||
@ -469,11 +470,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
def _verify_loss_output(self, initial_loss, loss_output, reduced,
|
||||
distribution):
|
||||
if not reduced:
|
||||
self.assertLen(distribution.unwrap(loss_output),
|
||||
self.assertLen(distribution.experimental_local_results(loss_output),
|
||||
distribution.num_replicas_in_sync)
|
||||
loss_tensor = distribution.reduce(reduce_util.ReduceOp.MEAN, loss_output)
|
||||
else:
|
||||
unwrapped_output = distribution.unwrap(loss_output)
|
||||
unwrapped_output = distribution.experimental_local_results(loss_output)
|
||||
self.assertLen(unwrapped_output, 1)
|
||||
loss_tensor = unwrapped_output[0]
|
||||
self.assertEqual(initial_loss.dtype, loss_tensor.dtype)
|
||||
|
@ -249,7 +249,7 @@ class MirroredStrategyVariableCreatorStackTest(
|
||||
distribution.scope(), \
|
||||
variable_scope.variable_creator_scope(main_thread_creator):
|
||||
result = distribution.extended.call_for_each_replica(model_fn)
|
||||
result = distribution.unwrap(result)
|
||||
result = distribution.experimental_local_results(result)
|
||||
expected = ("main_thread:thread_0", "main_thread:thread_1")
|
||||
self.assertEqual(expected, result)
|
||||
|
||||
@ -269,7 +269,7 @@ class MirroredStrategyCallForEachReplicaTest(test.TestCase):
|
||||
with distribution.scope():
|
||||
in_scope = ops.executing_eagerly_outside_functions()
|
||||
in_model_fn = distribution.extended.call_for_each_replica(model_fn)
|
||||
unwrapped = distribution.unwrap(in_model_fn)
|
||||
unwrapped = distribution.experimental_local_results(in_model_fn)
|
||||
self.assertEqual(in_scope, unwrapped[0])
|
||||
self.assertEqual(in_scope, originally)
|
||||
|
||||
@ -277,7 +277,7 @@ class MirroredStrategyCallForEachReplicaTest(test.TestCase):
|
||||
with func_graph.FuncGraph("fg").as_default(), distribution.scope():
|
||||
in_scope = ops.executing_eagerly_outside_functions()
|
||||
in_model_fn = distribution.extended.call_for_each_replica(model_fn)
|
||||
unwrapped = distribution.unwrap(in_model_fn)
|
||||
unwrapped = distribution.experimental_local_results(in_model_fn)
|
||||
self.assertEqual(in_scope, unwrapped[0])
|
||||
self.assertEqual(in_scope, originally)
|
||||
|
||||
@ -701,7 +701,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
|
||||
|
||||
# Apply updates
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate([y for x in ret_ops for y in distribution.unwrap(x)])
|
||||
self.evaluate([y for x in ret_ops # pylint: disable=g-complex-comprehension
|
||||
for y in distribution.experimental_local_results(x)])
|
||||
expected_sum = 0.0
|
||||
expected_mean = 0.0
|
||||
for i, d in enumerate(distribution.extended.worker_devices):
|
||||
@ -747,7 +748,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
|
||||
self.assertEqual(2, len(result))
|
||||
for v in result:
|
||||
self.assertIsInstance(v, values.DistributedValues)
|
||||
_, v1 = distribution.unwrap(v)
|
||||
_, v1 = distribution.experimental_local_results(v)
|
||||
self.assertStartsWith(v1._op.name, "replica_1/")
|
||||
|
||||
def testSyncOnReadVariableUpdate(self, distribution):
|
||||
@ -816,7 +817,7 @@ class MirroredStrategyNameScopeTest(test.TestCase):
|
||||
self.assertEqual(2, len(result))
|
||||
for v, name in zip(result, ["a", "b"]):
|
||||
self.assertIsInstance(v, values.DistributedValues)
|
||||
v0, v1 = distribution.unwrap(v)
|
||||
v0, v1 = distribution.experimental_local_results(v)
|
||||
self.assertEqual("main/foo/" + name + ":0", v0.name)
|
||||
self.assertEqual("main/replica_1/foo/" + name + ":0", v1.name)
|
||||
|
||||
@ -833,7 +834,7 @@ class MirroredStrategyNameScopeTest(test.TestCase):
|
||||
self.assertEqual(2, len(result))
|
||||
for v, name in zip(result, ["a", "b"]):
|
||||
self.assertIsInstance(v, values.DistributedValues)
|
||||
v0, v1 = distribution.unwrap(v)
|
||||
v0, v1 = distribution.experimental_local_results(v)
|
||||
self.assertEqual("foo/" + name + ":0", v0.name)
|
||||
self.assertEqual("replica_1/foo/" + name + ":0", v1.name)
|
||||
|
||||
@ -860,9 +861,9 @@ class MirroredStrategyNameScopeTest(test.TestCase):
|
||||
result_c = result[1]
|
||||
self.assertIsInstance(result_b, values.DistributedValues)
|
||||
self.assertIsInstance(result_c, values.DistributedValues)
|
||||
a0, a1 = distribution.unwrap(a)
|
||||
b0, b1 = distribution.unwrap(result_b)
|
||||
c0, c1 = distribution.unwrap(result_c)
|
||||
a0, a1 = distribution.experimental_local_results(a)
|
||||
b0, b1 = distribution.experimental_local_results(result_b)
|
||||
c0, c1 = distribution.experimental_local_results(result_c)
|
||||
self.assertEqual("main/a:0", a0.name)
|
||||
self.assertEqual("main/a/replica_1:0", a1.name)
|
||||
self.assertEqual("main/b:0", b0.name)
|
||||
@ -889,9 +890,9 @@ class MirroredStrategyNameScopeTest(test.TestCase):
|
||||
result_c = result[1]
|
||||
self.assertIsInstance(result_b, values.DistributedValues)
|
||||
self.assertIsInstance(result_c, values.DistributedValues)
|
||||
a0, a1 = distribution.unwrap(a)
|
||||
b0, b1 = distribution.unwrap(result_b)
|
||||
c0, c1 = distribution.unwrap(result_c)
|
||||
a0, a1 = distribution.experimental_local_results(a)
|
||||
b0, b1 = distribution.experimental_local_results(result_b)
|
||||
c0, c1 = distribution.experimental_local_results(result_c)
|
||||
self.assertEqual("a:0", a0.name)
|
||||
self.assertEqual("a/replica_1:0", a1.name)
|
||||
self.assertEqual("b:0", b0.name)
|
||||
@ -961,7 +962,7 @@ class MirroredVariableUpdateTest(test.TestCase):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "You must specify an aggregation method to update a "
|
||||
"MirroredVariable in Replica Context."):
|
||||
self.evaluate(distribution.unwrap(
|
||||
self.evaluate(distribution.experimental_local_results(
|
||||
distribution.extended.call_for_each_replica(model_fn)))
|
||||
|
||||
def testAssignMirroredVarReplicaContextWithSum(self, distribution):
|
||||
@ -983,7 +984,7 @@ class MirroredVariableUpdateTest(test.TestCase):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "A non-DistributedValues value 5.0 cannot be reduced "
|
||||
"with the given reduce op ReduceOp.SUM."):
|
||||
self.evaluate(distribution.unwrap(
|
||||
self.evaluate(distribution.experimental_local_results(
|
||||
distribution.extended.call_for_each_replica(model_fn)))
|
||||
|
||||
def testAssignMirroredVarCrossDeviceContext(self, distribution):
|
||||
@ -1015,7 +1016,7 @@ class MirroredVariableUpdateTest(test.TestCase):
|
||||
mirrored_var.dtype)
|
||||
return mirrored_var.assign(value)
|
||||
|
||||
self.evaluate(distribution.unwrap(
|
||||
self.evaluate(distribution.experimental_local_results(
|
||||
distribution.extended.call_for_each_replica(model_fn)))
|
||||
self.assertEqual(0.5, self.evaluate(mirrored_var))
|
||||
|
||||
@ -1033,7 +1034,7 @@ class MirroredVariableUpdateTest(test.TestCase):
|
||||
def model_fn():
|
||||
return mirrored_var.assign(5.0)
|
||||
|
||||
self.evaluate(distribution.unwrap(
|
||||
self.evaluate(distribution.experimental_local_results(
|
||||
distribution.extended.call_for_each_replica(model_fn)))
|
||||
self.assertEqual(5.0, self.evaluate(mirrored_var))
|
||||
|
||||
@ -1076,7 +1077,7 @@ class MirroredVariableUpdateTest(test.TestCase):
|
||||
mirrored_var.dtype)
|
||||
return mirrored_var.assign_add(value)
|
||||
|
||||
self.evaluate(distribution.unwrap(
|
||||
self.evaluate(distribution.experimental_local_results(
|
||||
distribution.extended.call_for_each_replica(model_fn)))
|
||||
self.assertEqual(1.5, self.evaluate(mirrored_var))
|
||||
|
||||
@ -1094,7 +1095,7 @@ class MirroredVariableUpdateTest(test.TestCase):
|
||||
def model_fn():
|
||||
return mirrored_var.assign_add(5.0)
|
||||
|
||||
self.evaluate(distribution.unwrap(
|
||||
self.evaluate(distribution.experimental_local_results(
|
||||
distribution.extended.call_for_each_replica(model_fn)))
|
||||
self.assertEqual(6.0, self.evaluate(mirrored_var))
|
||||
|
||||
@ -1129,7 +1130,7 @@ class MirroredVariableUpdateTest(test.TestCase):
|
||||
mirrored_var.dtype)
|
||||
return mirrored_var.assign_sub(value)
|
||||
|
||||
self.evaluate(distribution.unwrap(
|
||||
self.evaluate(distribution.experimental_local_results(
|
||||
distribution.extended.call_for_each_replica(model_fn)))
|
||||
self.assertEqual(4.5, self.evaluate(mirrored_var))
|
||||
|
||||
@ -1147,7 +1148,7 @@ class MirroredVariableUpdateTest(test.TestCase):
|
||||
def model_fn():
|
||||
return mirrored_var.assign_sub(1.0)
|
||||
|
||||
self.evaluate(distribution.unwrap(
|
||||
self.evaluate(distribution.experimental_local_results(
|
||||
distribution.extended.call_for_each_replica(model_fn)))
|
||||
self.assertEqual(4.0, self.evaluate(mirrored_var))
|
||||
|
||||
|
@ -56,7 +56,7 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase):
|
||||
var, assign = distribution.extended.call_for_each_replica(replica_fn)
|
||||
variables.global_variables_initializer().run()
|
||||
self.assertAllClose([10.0, 11.0], var.eval())
|
||||
sess.run(distribution.unwrap(assign))
|
||||
sess.run(distribution.experimental_local_results(assign))
|
||||
# Mean of val across calls to replica_fn().
|
||||
average_val = [1.0 + 0.5 * (replica_id[0] - 1),
|
||||
2.0 - 0.5 * (replica_id[0] - 1)]
|
||||
@ -82,7 +82,7 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase):
|
||||
var, assign_op = distribution.extended.call_for_each_replica(replica_fn)
|
||||
variables.global_variables_initializer().run()
|
||||
self.assertAllClose([0.0, 0.0], var.eval())
|
||||
sess.run(distribution.unwrap(assign_op))
|
||||
sess.run(distribution.experimental_local_results(assign_op))
|
||||
# Mean of val across calls to replica_fn().
|
||||
average_val = [1.0 + 0.5 * (replica_id[0] - 1),
|
||||
2.0 - 0.5 * (replica_id[0] - 1)]
|
||||
@ -155,7 +155,7 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase):
|
||||
var, assign = distribution.extended.call_for_each_replica(replica_fn)
|
||||
variables.global_variables_initializer().run()
|
||||
self.assertAllClose([10.0, 11.0], var.eval())
|
||||
sess.run(distribution.unwrap(assign))
|
||||
sess.run(distribution.experimental_local_results(assign))
|
||||
self.assertAllClose(
|
||||
[10 * 0.25 + 1. * (1 - 0.25), 11 * 0.25 + 2. * (1 - 0.25)],
|
||||
var.eval())
|
||||
|
@ -45,7 +45,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def run_step():
|
||||
return control_flow_ops.group(
|
||||
distribution.unwrap(
|
||||
distribution.experimental_local_results(
|
||||
distribution.extended.call_for_each_replica(
|
||||
model_fn, args=(iterator.get_next(),))))
|
||||
|
||||
|
@ -346,7 +346,7 @@ class DistributionTestBase(test.TestCase):
|
||||
|
||||
train_ops, value = strategy.extended.call_for_each_replica(model_fn)
|
||||
self.evaluate(strategy.group(train_ops))
|
||||
global_step_tensors = strategy.unwrap(value)
|
||||
global_step_tensors = strategy.experimental_local_results(value)
|
||||
global_step_values = self.evaluate(global_step_tensors)
|
||||
self.assertEqual((1,) * len(global_step_tensors), global_step_values)
|
||||
|
||||
@ -365,7 +365,8 @@ class DistributionTestBase(test.TestCase):
|
||||
|
||||
def run_and_concatenate(strategy, i):
|
||||
x, y = strategy.experimental_run(lambda z: z, i)
|
||||
x, y = self.evaluate((strategy.unwrap(x), strategy.unwrap(y)))
|
||||
x, y = self.evaluate((strategy.experimental_local_results(x),
|
||||
strategy.experimental_local_results(y)))
|
||||
return np.concatenate(x), np.concatenate(y)
|
||||
|
||||
x_1, y_1 = run_and_concatenate(strategy, i)
|
||||
@ -424,7 +425,8 @@ class OneDeviceDistributionTestBase(test.TestCase):
|
||||
|
||||
self.evaluate(inputs.initialize())
|
||||
outputs = self.evaluate(
|
||||
list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs))))
|
||||
list(map(strategy.experimental_local_results,
|
||||
strategy.experimental_run(comm_fn, inputs))))
|
||||
self.assertAllEqual([expected[0]], outputs[0])
|
||||
self.assertAllEqual([expected[1]], outputs[1])
|
||||
|
||||
@ -444,7 +446,8 @@ class OneDeviceDistributionTestBase(test.TestCase):
|
||||
self.evaluate(inputs.initialize())
|
||||
self.assertAllEqual(
|
||||
expected_grads,
|
||||
self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))
|
||||
self.evaluate(strategy.experimental_local_results(
|
||||
strategy.experimental_run(step, inputs))))
|
||||
|
||||
def _test_collective_comms_gradient_tape(
|
||||
self, strategy, comm_fn, inputs, expected_grads):
|
||||
@ -461,7 +464,8 @@ class OneDeviceDistributionTestBase(test.TestCase):
|
||||
self.evaluate(inputs.initialize())
|
||||
self.assertAllEqual(
|
||||
expected_grads,
|
||||
self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))
|
||||
self.evaluate(strategy.experimental_local_results(
|
||||
strategy.experimental_run(step, inputs))))
|
||||
|
||||
|
||||
class TwoDeviceDistributionTestBase(test.TestCase):
|
||||
@ -515,7 +519,8 @@ class TwoDeviceDistributionTestBase(test.TestCase):
|
||||
|
||||
self.evaluate(inputs.initialize())
|
||||
outputs = self.evaluate(
|
||||
list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs))))
|
||||
list(map(strategy.experimental_local_results,
|
||||
strategy.experimental_run(comm_fn, inputs))))
|
||||
self.assertAllEqual([expected[0], expected[0]], outputs[0])
|
||||
self.assertAllEqual([expected[1], expected[1]], outputs[1])
|
||||
|
||||
@ -535,7 +540,8 @@ class TwoDeviceDistributionTestBase(test.TestCase):
|
||||
self.evaluate(inputs.initialize())
|
||||
self.assertAllEqual(
|
||||
expected_grads,
|
||||
self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))
|
||||
self.evaluate(strategy.experimental_local_results(
|
||||
strategy.experimental_run(step, inputs))))
|
||||
|
||||
def _test_collective_comms_gradient_tape(
|
||||
self, strategy, comm_fn, inputs, expected_grads):
|
||||
@ -552,7 +558,8 @@ class TwoDeviceDistributionTestBase(test.TestCase):
|
||||
self.evaluate(inputs.initialize())
|
||||
self.assertAllEqual(
|
||||
expected_grads,
|
||||
self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))
|
||||
self.evaluate(strategy.experimental_local_results(
|
||||
strategy.experimental_run(step, inputs))))
|
||||
|
||||
|
||||
def _all_sum(value):
|
||||
|
@ -11,777 +11,17 @@
|
||||
"\n",
|
||||
"Licensed under the Apache License, Version 2.0 (the \"License\").\n",
|
||||
"\n",
|
||||
"# Pix2Pix: An example with tf.keras and eager\n",
|
||||
"\n",
|
||||
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n",
|
||||
"\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb\"\u003e\n",
|
||||
" \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n",
|
||||
"\u003c/td\u003e\u003ctd\u003e\n",
|
||||
"\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
|
||||
"# Pix2Pix"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "ITZuApL56Mny"
|
||||
"id": "c7W3j96p219v"
|
||||
},
|
||||
"source": [
|
||||
"This notebook demonstrates image to image translation using conditional GAN's, as described in [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004). Using this technique we can colorize black and white photos, convert google maps to google earth, etc. Here, we convert building facades to real buildings. We use [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager) to achieve this.\n",
|
||||
"\n",
|
||||
"In example, we will use the [CMP Facade Database](http://cmp.felk.cvut.cz/~tylecr1/facade/), helpfully provided by the [Center for Machine Perception](http://cmp.felk.cvut.cz/) at the [Czech Technical University in Prague](https://www.cvut.cz/). To keep our example short, we will use a preprocessed [copy](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/) of this dataset, created by the authors of the [paper](https://arxiv.org/abs/1611.07004) above.\n",
|
||||
"\n",
|
||||
"Each epoch takes around 58 seconds on a single P100 GPU.\n",
|
||||
"\n",
|
||||
"Below is the output generated after training the model for 200 epochs.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "e1_Y75QXJS6h"
|
||||
},
|
||||
"source": [
|
||||
"## Import TensorFlow and enable eager execution"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "YfIk2es3hJEd"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Import TensorFlow \u003e= 1.10 and enable eager execution\n",
|
||||
"import tensorflow as tf\n",
|
||||
"tf.enable_eager_execution()\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import time\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import PIL\n",
|
||||
"from IPython.display import clear_output"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "iYn4MdZnKCey"
|
||||
},
|
||||
"source": [
|
||||
"## Load the dataset\n",
|
||||
"\n",
|
||||
"You can download this dataset and similar datasets from [here](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets). As mentioned in the [paper](https://arxiv.org/abs/1611.07004) we apply random jittering and mirroring to the training dataset.\n",
|
||||
"* In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`\n",
|
||||
"* In random mirroring, the image is randomly flipped horizontally i.e left to right."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "Kn-k8kTXuAlv"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"path_to_zip = tf.keras.utils.get_file('facades.tar.gz',\n",
|
||||
" cache_subdir=os.path.abspath('.'),\n",
|
||||
" origin='https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz', \n",
|
||||
" extract=True)\n",
|
||||
"\n",
|
||||
"PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "2CbTEt448b4R"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"BUFFER_SIZE = 400\n",
|
||||
"BATCH_SIZE = 1\n",
|
||||
"IMG_WIDTH = 256\n",
|
||||
"IMG_HEIGHT = 256"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "tyaP4hLJ8b4W"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def load_image(image_file, is_train):\n",
|
||||
" image = tf.read_file(image_file)\n",
|
||||
" image = tf.image.decode_jpeg(image)\n",
|
||||
"\n",
|
||||
" w = tf.shape(image)[1]\n",
|
||||
"\n",
|
||||
" w = w // 2\n",
|
||||
" real_image = image[:, :w, :]\n",
|
||||
" input_image = image[:, w:, :]\n",
|
||||
"\n",
|
||||
" input_image = tf.cast(input_image, tf.float32)\n",
|
||||
" real_image = tf.cast(real_image, tf.float32)\n",
|
||||
"\n",
|
||||
" if is_train:\n",
|
||||
" # random jittering\n",
|
||||
" \n",
|
||||
" # resizing to 286 x 286 x 3\n",
|
||||
" input_image = tf.image.resize_images(input_image, [286, 286], \n",
|
||||
" align_corners=True, \n",
|
||||
" method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
|
||||
" real_image = tf.image.resize_images(real_image, [286, 286], \n",
|
||||
" align_corners=True, \n",
|
||||
" method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
|
||||
" \n",
|
||||
" # randomly cropping to 256 x 256 x 3\n",
|
||||
" stacked_image = tf.stack([input_image, real_image], axis=0)\n",
|
||||
" cropped_image = tf.random_crop(stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])\n",
|
||||
" input_image, real_image = cropped_image[0], cropped_image[1]\n",
|
||||
"\n",
|
||||
" if np.random.random() \u003e 0.5:\n",
|
||||
" # random mirroring\n",
|
||||
" input_image = tf.image.flip_left_right(input_image)\n",
|
||||
" real_image = tf.image.flip_left_right(real_image)\n",
|
||||
" else:\n",
|
||||
" input_image = tf.image.resize_images(input_image, size=[IMG_HEIGHT, IMG_WIDTH], \n",
|
||||
" align_corners=True, method=2)\n",
|
||||
" real_image = tf.image.resize_images(real_image, size=[IMG_HEIGHT, IMG_WIDTH], \n",
|
||||
" align_corners=True, method=2)\n",
|
||||
" \n",
|
||||
" # normalizing the images to [-1, 1]\n",
|
||||
" input_image = (input_image / 127.5) - 1\n",
|
||||
" real_image = (real_image / 127.5) - 1\n",
|
||||
"\n",
|
||||
" return input_image, real_image"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "PIGN6ouoQxt3"
|
||||
},
|
||||
"source": [
|
||||
"## Use tf.data to create batches, map(do preprocessing) and shuffle the dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "SQHmYSmk8b4b"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')\n",
|
||||
"train_dataset = train_dataset.shuffle(BUFFER_SIZE)\n",
|
||||
"train_dataset = train_dataset.map(lambda x: load_image(x, True))\n",
|
||||
"train_dataset = train_dataset.batch(1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "MS9J0yA58b4g"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_dataset = tf.data.Dataset.list_files(PATH+'test/*.jpg')\n",
|
||||
"test_dataset = test_dataset.map(lambda x: load_image(x, False))\n",
|
||||
"test_dataset = test_dataset.batch(1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "THY-sZMiQ4UV"
|
||||
},
|
||||
"source": [
|
||||
"## Write the generator and discriminator models\n",
|
||||
"\n",
|
||||
"* **Generator** \n",
|
||||
" * The architecture of generator is a modified U-Net.\n",
|
||||
" * Each block in the encoder is (Conv -\u003e Batchnorm -\u003e Leaky ReLU)\n",
|
||||
" * Each block in the decoder is (Transposed Conv -\u003e Batchnorm -\u003e Dropout(applied to the first 3 blocks) -\u003e ReLU)\n",
|
||||
" * There are skip connections between the encoder and decoder (as in U-Net).\n",
|
||||
" \n",
|
||||
"* **Discriminator**\n",
|
||||
" * The Discriminator is a PatchGAN.\n",
|
||||
" * Each block in the discriminator is (Conv -\u003e BatchNorm -\u003e Leaky ReLU)\n",
|
||||
" * The shape of the output after the last layer is (batch_size, 30, 30, 1)\n",
|
||||
" * Each 30x30 patch of the output classifies a 70x70 portion of the input image (such an architecture is called a PatchGAN).\n",
|
||||
" * Discriminator receives 2 inputs.\n",
|
||||
" * Input image and the target image, which it should classify as real.\n",
|
||||
" * Input image and the generated image (output of generator), which it should classify as fake. \n",
|
||||
" * We concatenate these 2 inputs together in the code (`tf.concat([inp, tar], axis=-1)`)\n",
|
||||
"\n",
|
||||
"* Shape of the input travelling through the generator and the discriminator is in the comments in the code.\n",
|
||||
"\n",
|
||||
"To learn more about the architecture and the hyperparameters you can refer the [paper](https://arxiv.org/abs/1611.07004).\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "tqqvWxlw8b4l"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"OUTPUT_CHANNELS = 3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "lFPI4Nu-8b4q"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class Downsample(tf.keras.Model):\n",
|
||||
" \n",
|
||||
" def __init__(self, filters, size, apply_batchnorm=True):\n",
|
||||
" super(Downsample, self).__init__()\n",
|
||||
" self.apply_batchnorm = apply_batchnorm\n",
|
||||
" initializer = tf.random_normal_initializer(0., 0.02)\n",
|
||||
"\n",
|
||||
" self.conv1 = tf.keras.layers.Conv2D(filters, \n",
|
||||
" (size, size), \n",
|
||||
" strides=2, \n",
|
||||
" padding='same',\n",
|
||||
" kernel_initializer=initializer,\n",
|
||||
" use_bias=False)\n",
|
||||
" if self.apply_batchnorm:\n",
|
||||
" self.batchnorm = tf.keras.layers.BatchNormalization()\n",
|
||||
" \n",
|
||||
" def call(self, x, training):\n",
|
||||
" x = self.conv1(x)\n",
|
||||
" if self.apply_batchnorm:\n",
|
||||
" x = self.batchnorm(x, training=training)\n",
|
||||
" x = tf.nn.leaky_relu(x)\n",
|
||||
" return x \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Upsample(tf.keras.Model):\n",
|
||||
" \n",
|
||||
" def __init__(self, filters, size, apply_dropout=False):\n",
|
||||
" super(Upsample, self).__init__()\n",
|
||||
" self.apply_dropout = apply_dropout\n",
|
||||
" initializer = tf.random_normal_initializer(0., 0.02)\n",
|
||||
"\n",
|
||||
" self.up_conv = tf.keras.layers.Conv2DTranspose(filters, \n",
|
||||
" (size, size), \n",
|
||||
" strides=2, \n",
|
||||
" padding='same',\n",
|
||||
" kernel_initializer=initializer,\n",
|
||||
" use_bias=False)\n",
|
||||
" self.batchnorm = tf.keras.layers.BatchNormalization()\n",
|
||||
" if self.apply_dropout:\n",
|
||||
" self.dropout = tf.keras.layers.Dropout(0.5)\n",
|
||||
"\n",
|
||||
" def call(self, x1, x2, training):\n",
|
||||
" x = self.up_conv(x1)\n",
|
||||
" x = self.batchnorm(x, training=training)\n",
|
||||
" if self.apply_dropout:\n",
|
||||
" x = self.dropout(x, training=training)\n",
|
||||
" x = tf.nn.relu(x)\n",
|
||||
" x = tf.concat([x, x2], axis=-1)\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Generator(tf.keras.Model):\n",
|
||||
" \n",
|
||||
" def __init__(self):\n",
|
||||
" super(Generator, self).__init__()\n",
|
||||
" initializer = tf.random_normal_initializer(0., 0.02)\n",
|
||||
" \n",
|
||||
" self.down1 = Downsample(64, 4, apply_batchnorm=False)\n",
|
||||
" self.down2 = Downsample(128, 4)\n",
|
||||
" self.down3 = Downsample(256, 4)\n",
|
||||
" self.down4 = Downsample(512, 4)\n",
|
||||
" self.down5 = Downsample(512, 4)\n",
|
||||
" self.down6 = Downsample(512, 4)\n",
|
||||
" self.down7 = Downsample(512, 4)\n",
|
||||
" self.down8 = Downsample(512, 4)\n",
|
||||
"\n",
|
||||
" self.up1 = Upsample(512, 4, apply_dropout=True)\n",
|
||||
" self.up2 = Upsample(512, 4, apply_dropout=True)\n",
|
||||
" self.up3 = Upsample(512, 4, apply_dropout=True)\n",
|
||||
" self.up4 = Upsample(512, 4)\n",
|
||||
" self.up5 = Upsample(256, 4)\n",
|
||||
" self.up6 = Upsample(128, 4)\n",
|
||||
" self.up7 = Upsample(64, 4)\n",
|
||||
"\n",
|
||||
" self.last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, \n",
|
||||
" (4, 4), \n",
|
||||
" strides=2, \n",
|
||||
" padding='same',\n",
|
||||
" kernel_initializer=initializer)\n",
|
||||
" \n",
|
||||
" @tf.contrib.eager.defun\n",
|
||||
" def call(self, x, training):\n",
|
||||
" # x shape == (bs, 256, 256, 3) \n",
|
||||
" x1 = self.down1(x, training=training) # (bs, 128, 128, 64)\n",
|
||||
" x2 = self.down2(x1, training=training) # (bs, 64, 64, 128)\n",
|
||||
" x3 = self.down3(x2, training=training) # (bs, 32, 32, 256)\n",
|
||||
" x4 = self.down4(x3, training=training) # (bs, 16, 16, 512)\n",
|
||||
" x5 = self.down5(x4, training=training) # (bs, 8, 8, 512)\n",
|
||||
" x6 = self.down6(x5, training=training) # (bs, 4, 4, 512)\n",
|
||||
" x7 = self.down7(x6, training=training) # (bs, 2, 2, 512)\n",
|
||||
" x8 = self.down8(x7, training=training) # (bs, 1, 1, 512)\n",
|
||||
"\n",
|
||||
" x9 = self.up1(x8, x7, training=training) # (bs, 2, 2, 1024)\n",
|
||||
" x10 = self.up2(x9, x6, training=training) # (bs, 4, 4, 1024)\n",
|
||||
" x11 = self.up3(x10, x5, training=training) # (bs, 8, 8, 1024)\n",
|
||||
" x12 = self.up4(x11, x4, training=training) # (bs, 16, 16, 1024)\n",
|
||||
" x13 = self.up5(x12, x3, training=training) # (bs, 32, 32, 512)\n",
|
||||
" x14 = self.up6(x13, x2, training=training) # (bs, 64, 64, 256)\n",
|
||||
" x15 = self.up7(x14, x1, training=training) # (bs, 128, 128, 128)\n",
|
||||
"\n",
|
||||
" x16 = self.last(x15) # (bs, 256, 256, 3)\n",
|
||||
" x16 = tf.nn.tanh(x16)\n",
|
||||
"\n",
|
||||
" return x16"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "ll6aNeQx8b4v"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class DiscDownsample(tf.keras.Model):\n",
|
||||
" \n",
|
||||
" def __init__(self, filters, size, apply_batchnorm=True):\n",
|
||||
" super(DiscDownsample, self).__init__()\n",
|
||||
" self.apply_batchnorm = apply_batchnorm\n",
|
||||
" initializer = tf.random_normal_initializer(0., 0.02)\n",
|
||||
"\n",
|
||||
" self.conv1 = tf.keras.layers.Conv2D(filters, \n",
|
||||
" (size, size), \n",
|
||||
" strides=2, \n",
|
||||
" padding='same',\n",
|
||||
" kernel_initializer=initializer,\n",
|
||||
" use_bias=False)\n",
|
||||
" if self.apply_batchnorm:\n",
|
||||
" self.batchnorm = tf.keras.layers.BatchNormalization()\n",
|
||||
" \n",
|
||||
" def call(self, x, training):\n",
|
||||
" x = self.conv1(x)\n",
|
||||
" if self.apply_batchnorm:\n",
|
||||
" x = self.batchnorm(x, training=training)\n",
|
||||
" x = tf.nn.leaky_relu(x)\n",
|
||||
" return x \n",
|
||||
"\n",
|
||||
"class Discriminator(tf.keras.Model):\n",
|
||||
" \n",
|
||||
" def __init__(self):\n",
|
||||
" super(Discriminator, self).__init__()\n",
|
||||
" initializer = tf.random_normal_initializer(0., 0.02)\n",
|
||||
" \n",
|
||||
" self.down1 = DiscDownsample(64, 4, False)\n",
|
||||
" self.down2 = DiscDownsample(128, 4)\n",
|
||||
" self.down3 = DiscDownsample(256, 4)\n",
|
||||
" \n",
|
||||
" # we are zero padding here with 1 because we need our shape to \n",
|
||||
" # go from (batch_size, 32, 32, 256) to (batch_size, 31, 31, 512)\n",
|
||||
" self.zero_pad1 = tf.keras.layers.ZeroPadding2D()\n",
|
||||
" self.conv = tf.keras.layers.Conv2D(512, \n",
|
||||
" (4, 4), \n",
|
||||
" strides=1, \n",
|
||||
" kernel_initializer=initializer, \n",
|
||||
" use_bias=False)\n",
|
||||
" self.batchnorm1 = tf.keras.layers.BatchNormalization()\n",
|
||||
" \n",
|
||||
" # shape change from (batch_size, 31, 31, 512) to (batch_size, 30, 30, 1)\n",
|
||||
" self.zero_pad2 = tf.keras.layers.ZeroPadding2D()\n",
|
||||
" self.last = tf.keras.layers.Conv2D(1, \n",
|
||||
" (4, 4), \n",
|
||||
" strides=1,\n",
|
||||
" kernel_initializer=initializer)\n",
|
||||
" \n",
|
||||
" @tf.contrib.eager.defun\n",
|
||||
" def call(self, inp, tar, training):\n",
|
||||
" # concatenating the input and the target\n",
|
||||
" x = tf.concat([inp, tar], axis=-1) # (bs, 256, 256, channels*2)\n",
|
||||
" x = self.down1(x, training=training) # (bs, 128, 128, 64)\n",
|
||||
" x = self.down2(x, training=training) # (bs, 64, 64, 128)\n",
|
||||
" x = self.down3(x, training=training) # (bs, 32, 32, 256)\n",
|
||||
"\n",
|
||||
" x = self.zero_pad1(x) # (bs, 34, 34, 256)\n",
|
||||
" x = self.conv(x) # (bs, 31, 31, 512)\n",
|
||||
" x = self.batchnorm1(x, training=training)\n",
|
||||
" x = tf.nn.leaky_relu(x)\n",
|
||||
" \n",
|
||||
" x = self.zero_pad2(x) # (bs, 33, 33, 512)\n",
|
||||
" # don't add a sigmoid activation here since\n",
|
||||
" # the loss function expects raw logits.\n",
|
||||
" x = self.last(x) # (bs, 30, 30, 1)\n",
|
||||
"\n",
|
||||
" return x"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "gDkA05NE6QMs"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# The call function of Generator and Discriminator have been decorated\n",
|
||||
"# with tf.contrib.eager.defun()\n",
|
||||
"# We get a performance speedup if defun is used (~25 seconds per epoch)\n",
|
||||
"generator = Generator()\n",
|
||||
"discriminator = Discriminator()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "0FMYgY_mPfTi"
|
||||
},
|
||||
"source": [
|
||||
"## Define the loss functions and the optimizer\n",
|
||||
"\n",
|
||||
"* **Discriminator loss**\n",
|
||||
" * The discriminator loss function takes 2 inputs; **real images, generated images**\n",
|
||||
" * real_loss is a sigmoid cross entropy loss of the **real images** and an **array of ones(since these are the real images)**\n",
|
||||
" * generated_loss is a sigmoid cross entropy loss of the **generated images** and an **array of zeros(since these are the fake images)**\n",
|
||||
" * Then the total_loss is the sum of real_loss and the generated_loss\n",
|
||||
" \n",
|
||||
"* **Generator loss**\n",
|
||||
" * It is a sigmoid cross entropy loss of the generated images and an **array of ones**.\n",
|
||||
" * The [paper](https://arxiv.org/abs/1611.07004) also includes L1 loss which is MAE (mean absolute error) between the generated image and the target image.\n",
|
||||
" * This allows the generated image to become structurally similar to the target image.\n",
|
||||
" * The formula to calculate the total generator loss = gan_loss + LAMBDA * l1_loss, where LAMBDA = 100. This value was decided by the authors of the [paper](https://arxiv.org/abs/1611.07004)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "cyhxTuvJyIHV"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"LAMBDA = 100"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "wkMNfBWlT-PV"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def discriminator_loss(disc_real_output, disc_generated_output):\n",
|
||||
" real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.ones_like(disc_real_output), \n",
|
||||
" logits = disc_real_output)\n",
|
||||
" generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.zeros_like(disc_generated_output), \n",
|
||||
" logits = disc_generated_output)\n",
|
||||
"\n",
|
||||
" total_disc_loss = real_loss + generated_loss\n",
|
||||
"\n",
|
||||
" return total_disc_loss"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "90BIcCKcDMxz"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def generator_loss(disc_generated_output, gen_output, target):\n",
|
||||
" gan_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.ones_like(disc_generated_output),\n",
|
||||
" logits = disc_generated_output) \n",
|
||||
" # mean absolute error\n",
|
||||
" l1_loss = tf.reduce_mean(tf.abs(target - gen_output))\n",
|
||||
"\n",
|
||||
" total_gen_loss = gan_loss + (LAMBDA * l1_loss)\n",
|
||||
"\n",
|
||||
" return total_gen_loss"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "iWCn_PVdEJZ7"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"generator_optimizer = tf.train.AdamOptimizer(2e-4, beta1=0.5)\n",
|
||||
"discriminator_optimizer = tf.train.AdamOptimizer(2e-4, beta1=0.5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "aKUZnDiqQrAh"
|
||||
},
|
||||
"source": [
|
||||
"## Checkpoints (Object-based saving)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "WJnftd5sQsv6"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"checkpoint_dir = './training_checkpoints'\n",
|
||||
"checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
|
||||
"checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n",
|
||||
" discriminator_optimizer=discriminator_optimizer,\n",
|
||||
" generator=generator,\n",
|
||||
" discriminator=discriminator)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "Rw1fkAczTQYh"
|
||||
},
|
||||
"source": [
|
||||
"## Training\n",
|
||||
"\n",
|
||||
"* We start by iterating over the dataset\n",
|
||||
"* The generator gets the input image and we get a generated output.\n",
|
||||
"* The discriminator receives the input_image and the generated image as the first input. The second input is the input_image and the target_image.\n",
|
||||
"* Next, we calculate the generator and the discriminator loss.\n",
|
||||
"* Then, we calculate the gradients of loss with respect to both the generator and the discriminator variables(inputs) and apply those to the optimizer.\n",
|
||||
"\n",
|
||||
"## Generate Images\n",
|
||||
"\n",
|
||||
"* After training, its time to generate some images!\n",
|
||||
"* We pass images from the test dataset to the generator.\n",
|
||||
"* The generator will then translate the input image into the output we expect.\n",
|
||||
"* Last step is to plot the predictions and **voila!**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "NS2GWywBbAWo"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"EPOCHS = 200"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "RmdVsmvhPxyy"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def generate_images(model, test_input, tar):\n",
|
||||
" # the training=True is intentional here since\n",
|
||||
" # we want the batch statistics while running the model\n",
|
||||
" # on the test dataset. If we use training=False, we will get \n",
|
||||
" # the accumulated statistics learned from the training dataset\n",
|
||||
" # (which we don't want)\n",
|
||||
" prediction = model(test_input, training=True)\n",
|
||||
" plt.figure(figsize=(15,15))\n",
|
||||
"\n",
|
||||
" display_list = [test_input[0], tar[0], prediction[0]]\n",
|
||||
" title = ['Input Image', 'Ground Truth', 'Predicted Image']\n",
|
||||
"\n",
|
||||
" for i in range(3):\n",
|
||||
" plt.subplot(1, 3, i+1)\n",
|
||||
" plt.title(title[i])\n",
|
||||
" # getting the pixel values between [0, 1] to plot it.\n",
|
||||
" plt.imshow(display_list[i] * 0.5 + 0.5)\n",
|
||||
" plt.axis('off')\n",
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "2M7LmLtGEMQJ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def train(dataset, epochs): \n",
|
||||
" for epoch in range(epochs):\n",
|
||||
" start = time.time()\n",
|
||||
"\n",
|
||||
" for input_image, target in dataset:\n",
|
||||
"\n",
|
||||
" with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n",
|
||||
" gen_output = generator(input_image, training=True)\n",
|
||||
"\n",
|
||||
" disc_real_output = discriminator(input_image, target, training=True)\n",
|
||||
" disc_generated_output = discriminator(input_image, gen_output, training=True)\n",
|
||||
"\n",
|
||||
" gen_loss = generator_loss(disc_generated_output, gen_output, target)\n",
|
||||
" disc_loss = discriminator_loss(disc_real_output, disc_generated_output)\n",
|
||||
"\n",
|
||||
" generator_gradients = gen_tape.gradient(gen_loss, \n",
|
||||
" generator.variables)\n",
|
||||
" discriminator_gradients = disc_tape.gradient(disc_loss, \n",
|
||||
" discriminator.variables)\n",
|
||||
"\n",
|
||||
" generator_optimizer.apply_gradients(zip(generator_gradients, \n",
|
||||
" generator.variables))\n",
|
||||
" discriminator_optimizer.apply_gradients(zip(discriminator_gradients, \n",
|
||||
" discriminator.variables))\n",
|
||||
"\n",
|
||||
" if epoch % 1 == 0:\n",
|
||||
" clear_output(wait=True)\n",
|
||||
" for inp, tar in test_dataset.take(1):\n",
|
||||
" generate_images(generator, inp, tar)\n",
|
||||
" \n",
|
||||
" # saving (checkpoint) the model every 20 epochs\n",
|
||||
" if (epoch + 1) % 20 == 0:\n",
|
||||
" checkpoint.save(file_prefix = checkpoint_prefix)\n",
|
||||
"\n",
|
||||
" print ('Time taken for epoch {} is {} sec\\n'.format(epoch + 1,\n",
|
||||
" time.time()-start))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "a1zZmKmvOH85"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train(train_dataset, EPOCHS)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "kz80bY3aQ1VZ"
|
||||
},
|
||||
"source": [
|
||||
"## Restore the latest checkpoint and test"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "4t4x69adQ5xb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# restoring the latest checkpoint in checkpoint_dir\n",
|
||||
"checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "1RGysMU_BZhx"
|
||||
},
|
||||
"source": [
|
||||
"## Testing on the entire test dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "KUgSnmy2nqSP"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Run the trained model on the entire test dataset\n",
|
||||
"for inp, tar in test_dataset:\n",
|
||||
" generate_images(generator, inp, tar)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "3AJXOByaZVOf"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
""
|
||||
"This notebook has been moved to [https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/pix2pix.ipynb](https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/pix2pix.ipynb)"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -30,7 +30,10 @@ _ffmpeg_so = loader.load_op_library(
|
||||
resource_loader.get_path_to_datafile('ffmpeg.so'))
|
||||
|
||||
|
||||
@deprecated('2018-09-04', 'This will be deleted and should not be used.')
|
||||
@deprecated('2018-09-04',
|
||||
'tf.contrib.ffmpeg will be removed in 2.0, the support for video '
|
||||
'and audio will continue to be provided in tensorflow-io: '
|
||||
'https://github.com/tensorflow/io')
|
||||
def decode_audio(contents, file_format=None, samples_per_second=None,
|
||||
channel_count=None, stream=None):
|
||||
"""Create an op that decodes the contents of an audio file.
|
||||
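For context, a hedged sketch of the deprecated call path this message steers away from; the keyword arguments come from the signature shown above, the file path is a placeholder, and the replacement lives in tensorflow-io as noted.

import tensorflow as tf
from tensorflow.contrib import ffmpeg  # removed in 2.0; see the tensorflow-io pointer above

binary = tf.read_file('clip.mp3')  # placeholder path
waveform = ffmpeg.decode_audio(binary, file_format='mp3',
                               samples_per_second=44100, channel_count=2)
# waveform evaluates to a float tensor shaped [samples, channel_count]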
@ -71,7 +74,10 @@ def decode_audio(contents, file_format=None, samples_per_second=None,
|
||||
ops.NotDifferentiable('DecodeAudio')
|
||||
|
||||
|
||||
@deprecated('2018-09-04', 'This will be deleted and should not be used.')
|
||||
@deprecated('2018-09-04',
|
||||
'tf.contrib.ffmpeg will be removed in 2.0, the support for video '
|
||||
'and audio will continue to be provided in tensorflow-io: '
|
||||
'https://github.com/tensorflow/io')
|
||||
def encode_audio(audio, file_format=None, samples_per_second=None):
|
||||
"""Creates an op that encodes an audio file using sampled audio from a tensor.
|
||||
|
||||
@ -98,13 +104,15 @@ def encode_audio(audio, file_format=None, samples_per_second=None):
|
||||
ops.NotDifferentiable('EncodeAudio')
|
||||
|
||||
|
||||
@deprecated('2018-09-04', 'This will be deleted and should not be used.')
|
||||
@deprecated('2018-09-04',
|
||||
'tf.contrib.ffmpeg will be removed in 2.0, the support for video '
|
||||
'and audio will continue to be provided in tensorflow-io: '
|
||||
'https://github.com/tensorflow/io')
|
||||
def decode_video(contents):
|
||||
"""Create an op that decodes the contents of a video file.
|
||||
|
||||
Args:
|
||||
contents: The binary contents of the video file to decode. This is a
|
||||
scalar.
|
||||
contents: The binary contents of the video file to decode. This is a scalar.
|
||||
|
||||
Returns:
|
||||
A rank-4 `Tensor` that has `[frames, height, width, 3]` RGB as output.
|
||||
|
@ -37,6 +37,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/control_flow.h"
|
||||
#include "tensorflow/core/graph/gradients.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/graph/optimizer_cse.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
@ -1734,6 +1735,9 @@ Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
|
||||
// added above or on all control return nodes (controlled by
|
||||
// `options.output_control_src` value). And nodes previously depend on
|
||||
// "callee" is changed to depend on "output_control_node".
|
||||
//
|
||||
// If `keep_node_fetchable` is `true` we always add an output control node, to
|
||||
// guarantee that executing a fetchable node will execute all side-effects.
|
||||
std::vector<Node*> outputs(caller->num_outputs());
|
||||
for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) {
|
||||
Node* ret = node_map[fbody->ret_nodes[i]->id()];
|
||||
@ -1754,29 +1758,53 @@ Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
|
||||
}
|
||||
g->RemoveNode(ret); // 'ret' is disconnected.
|
||||
}
|
||||
|
||||
Node* output_control_node = nullptr;
|
||||
bool has_control_outputs = absl::c_any_of(
|
||||
caller->out_edges(), [](const Edge* e) { return e->IsControlEdge(); });
|
||||
|
||||
if (has_control_outputs || options.keep_caller_fetchable) {
|
||||
output_control_node = no_op("output_control_node");
|
||||
if (options.output_control_src == OutputControlSrc::kDataOutputs) {
|
||||
for (Node* n : outputs) {
|
||||
g->AddControlEdge(n, output_control_node);
|
||||
}
|
||||
} else {
|
||||
for (Node* fbody_node : fbody->control_ret_nodes) {
|
||||
Node* n = node_map[fbody_node->id()];
|
||||
g->AddControlEdge(n, output_control_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const Edge* e : caller->out_edges()) {
|
||||
if (e->IsControlEdge()) {
|
||||
if (output_control_node == nullptr) {
|
||||
output_control_node = no_op("output_control_node");
|
||||
if (options.output_control_src ==
|
||||
InlineFunctionBodyOptions::OutputControlSource::kDataOutputs) {
|
||||
for (Node* n : outputs) {
|
||||
g->AddControlEdge(n, output_control_node);
|
||||
}
|
||||
} else {
|
||||
for (Node* fbody_node : fbody->control_ret_nodes) {
|
||||
Node* n = node_map[fbody_node->id()];
|
||||
g->AddControlEdge(n, output_control_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
g->AddControlEdge(output_control_node, e->dst());
|
||||
} else {
|
||||
g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input());
|
||||
}
|
||||
}
|
||||
g->RemoveNode(caller); // 'caller' is replaced with inlined nodes.
|
||||
|
||||
// ------------------------------------------------------------------------ //
|
||||
// Add an IdentityN node in-place of caller node to keep `caller` fetchable.
|
||||
|
||||
if (options.keep_caller_fetchable) {
|
||||
std::vector<NodeBuilder::NodeOut> output_tensors;
|
||||
absl::c_transform(outputs, std::back_inserter(output_tensors),
|
||||
[](Node* n) { return NodeBuilder::NodeOut(n, 0); });
|
||||
|
||||
Node* fetchable_node;
|
||||
TF_CHECK_OK(NodeBuilder(caller->name(), "IdentityN")
|
||||
.Device(caller->requested_device())
|
||||
.Input(output_tensors)
|
||||
.ControlInput(output_control_node)
|
||||
.Finalize(g, &fetchable_node));
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------ //
|
||||
// 'caller' is replaced with inlined function body nodes and maybe IdentityN
|
||||
// to keep it fetchable.
|
||||
g->RemoveNode(caller);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -173,6 +173,16 @@ struct InlineFunctionBodyOptions {
|
||||
// If 'true' function inlining will override explicitly specified devices
|
||||
// inside function body with the caller node device.
|
||||
bool override_device = false;
|
||||
// If 'true' function inlining will add an IdentityN node to the graph with
|
||||
// the same name as the caller node. It will have a control edge from inlined
|
||||
// 'output_control_node' and data edges from function output nodes. IdentityN
|
||||
// node will be placed on the same device as the caller node.
|
||||
// This is mostly for compatibility with Tensorflow v1 and sessions. When we
|
||||
// prepare a graph for execution in GraphExecutionState::MakeForBaseGraph we
|
||||
// don't know what nodes will be fetched, so we can't safely remove any of
|
||||
// them. When graph executed as a function it has 'Retval' nodes for each
|
||||
// fetched tensor, and we can safely inline function calls.
|
||||
bool keep_caller_fetchable = false;
|
||||
// For compatibility with Tensorflow v1 by default we will use data outputs.
|
||||
// Control returns were added to Tensorflow v2 with automatic control
|
||||
// dependencies tracking in Eager mode.
|
||||
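The control-return bookkeeping described in these comments is what lets TF 2 keep side effects alive when only data outputs are fetched; a small illustration in Python (not part of this diff):

import tensorflow as tf

v = tf.Variable(0.0)

@tf.function
def f(x):
  v.assign_add(1.0)  # produces no fetched data output...
  return x * x       # ...but automatic control dependencies still force it to run

print(f(tf.constant(3.0)).numpy())  # 9.0
print(v.numpy())                    # 1.0, even though only x * x was fetched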
|
@ -54,7 +54,9 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
using FDH = FunctionDefHelper;
|
||||
using FDH = ::tensorflow::FunctionDefHelper;
|
||||
|
||||
using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
|
||||
|
||||
Status GetOpSig(const string& op, const OpDef** sig) {
|
||||
return OpRegistry::Global()->LookUpOpDef(op, sig);
|
||||
@ -888,17 +890,13 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithInputControlEdges) {
|
||||
TEST_F(FunctionLibraryRuntimeTest,
|
||||
ExpandInlineFunctionsWithOutputControlEdges) {
|
||||
using test::function::NDef;
|
||||
using FDH = FunctionDefHelper;
|
||||
using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
|
||||
|
||||
// `add` node is not required to compute regular output `o`, but it must
|
||||
// execute because it is in `control_ret`.
|
||||
const FunctionDef func =
|
||||
FDH::Create("FunctionWithControlOutputs", {"i: float"}, {"o: float"}, {},
|
||||
{
|
||||
{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
|
||||
{{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}},
|
||||
},
|
||||
FDH::Create("AddAndMul", {"i: float"}, {"o: float"}, {},
|
||||
{{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
|
||||
{{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
|
||||
/*ret_def=*/{{"o", "ret:z:0"}},
|
||||
/*control_ret_def=*/{{"must_execute", "add"}});
|
||||
|
||||
@ -907,16 +905,16 @@ TEST_F(FunctionLibraryRuntimeTest,
|
||||
// Construct a graph for the function call:
|
||||
//
|
||||
// a = Arg[dtype=DT_FLOAT]
|
||||
// b = FunctionWithControlOutputs(a)
|
||||
// b = AddAndMul(a)
|
||||
// c = NoOp(^b)
|
||||
// ret = RetVal(b, ^c)
|
||||
const auto init_graph = [this](std::unique_ptr<Graph>* g) -> void {
|
||||
g->reset(new Graph(OpRegistry::Global()));
|
||||
*g = absl::make_unique<Graph>(OpRegistry::Global());
|
||||
|
||||
Scope s = Scope::NewRootScope();
|
||||
TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
|
||||
auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
|
||||
auto b = test::function::Call(&s, "b", "FunctionWithControlOutputs", {a});
|
||||
auto b = test::function::Call(&s, "b", "AddAndMul", {a});
|
||||
auto c = ops::NoOp(s.WithOpName("c"));
|
||||
auto ret = ops::_Retval(s.WithOpName("ret"), b, 0);
|
||||
s.graph()->AddControlEdge(b.node(), c.operation.node());
|
||||
@ -978,6 +976,55 @@ TEST_F(FunctionLibraryRuntimeTest,
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsAndKeepThemFetchable) {
|
||||
using test::function::NDef;
|
||||
|
||||
const FunctionDef func =
|
||||
FDH::Create("AddAndMul", {"i: float"}, {"o: float"}, {},
|
||||
{{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
|
||||
{{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
|
||||
/*ret_def=*/{{"o", "ret:z:0"}},
|
||||
/*control_ret_def=*/{{"must_execute", "add"}});
|
||||
Init({func});
|
||||
|
||||
// Construct a graph:
|
||||
// a = Arg[dtype=DT_FLOAT]
|
||||
// b = FunctionWithControlOutputs(a)
|
||||
std::unique_ptr<Graph> g = absl::make_unique<Graph>(OpRegistry::Global());
|
||||
|
||||
Scope s = Scope::NewRootScope();
|
||||
TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
|
||||
auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
|
||||
auto b = test::function::Call(&s, "b", "AddAndMul", {a});
|
||||
TF_ASSERT_OK(s.ToGraph(g.get()));
|
||||
|
||||
ExpandInlineFunctionsOptions opts;
|
||||
opts.native_options.keep_caller_fetchable = true;
|
||||
opts.native_options.output_control_src = OutputControlSrc::kControlOutputs;
|
||||
|
||||
const string input_node = "Func/b/input/_0";
|
||||
const string output_node = "Func/b/output/_1";
|
||||
const string output_control_node = "Func/b/output_control_node/_2";
|
||||
|
||||
ExpandInlineFunctions(flr0_, g.get(), opts);
|
||||
{
|
||||
GraphDef expected = test::function::GDef(
|
||||
{NDef("a", "_Arg", {}, {{"T", DT_FLOAT}, {"index", 0}}),
|
||||
NDef(input_node, "Identity", {"a"}, {{"T", DT_FLOAT}}),
|
||||
NDef("b/add", "Add", {input_node, input_node}, {{"T", DT_FLOAT}}),
|
||||
NDef("b/ret", "Mul", {input_node, input_node}, {{"T", DT_FLOAT}}),
|
||||
NDef(output_node, "Identity", {"b/ret"}, {{"T", DT_FLOAT}}),
|
||||
NDef(output_control_node, "NoOp", {"^b/add"}, {}),
|
||||
NDef("b", "IdentityN", {output_node, "^" + output_control_node},
|
||||
{{"T", DataTypeSlice{DT_FLOAT}}})},
|
||||
{func});
|
||||
|
||||
GraphDef actual;
|
||||
g->ToGraphDef(&actual);
|
||||
TF_EXPECT_GRAPH_EQ(expected, actual);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
|
||||
auto T = DT_INT32;
|
||||
FunctionDef stateful_func = FDH::Define(
|
||||
|
@ -147,11 +147,21 @@ string CondBuilder::NewName(const string& infix) {
|
||||
Status CondBuilder::AddInput(Node* src, int src_output) {
|
||||
Node* input;
|
||||
NodeDebugInfo debug_info(*src);
|
||||
// Colocate the Switch node with the `src` node.
|
||||
//
|
||||
// This is to avoid unnecessary Host<->Device copies between src and the
|
||||
// Switch node. This aligns with the implementation of legacy tf.cond in
|
||||
// control_flow_ops.py. The legacy impl colocates the Switch with the
|
||||
// input tensor which resets the device stack and forces the Switch to have
|
||||
// the same device as the input node (if set) and sets the colocation _class
|
||||
// attr. It also ignores the existing colocation constraints on the input node
|
||||
// using colocate_with(ignore_existing=True).
|
||||
TF_RETURN_IF_ERROR(NodeBuilder(NewName(src->name()), "Switch",
|
||||
graph_->op_registry(), &debug_info)
|
||||
.Input(src, src_output)
|
||||
.Input(pred_)
|
||||
.Device(if_op_->requested_device())
|
||||
.Device(src->requested_device())
|
||||
.Attr("_class", {src->name()})
|
||||
.Finalize(graph_, &input));
|
||||
then_call_builder_.Input(input, kThenBranch);
|
||||
else_call_builder_.Input(input, kElseBranch);
|
||||
|
@ -102,7 +102,7 @@ Status Placer::Run() {
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(3)) {
|
||||
DumpGraphToFile("placer_input", *graph_, nullptr, "/tmp");
|
||||
DumpGraphToFile("placer_input", *graph_, nullptr);
|
||||
for (const Node* node : graph_->op_nodes()) {
|
||||
VLOG(3) << " " << node->name() << ": requested: '"
|
||||
<< node->requested_device() << "' assigned: '"
|
||||
@ -226,7 +226,7 @@ Status Placer::Run() {
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(3)) {
|
||||
DumpGraphToFile("placer_output", *graph_, nullptr, "/tmp");
|
||||
DumpGraphToFile("placer_output", *graph_, nullptr);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -686,7 +686,7 @@ class SymbolicShapeRefiner {
|
||||
|
||||
// Perform inference on function body.
|
||||
GraphProperties gp(grappler_function_item);
|
||||
TF_RETURN_IF_ERROR(gp.InferStatically(true));
|
||||
TF_RETURN_IF_ERROR(gp.InferStatically(true, aggressive_shape_inference_));
|
||||
|
||||
// Add return nodes for output shapes.
|
||||
int output = 0;
|
||||
|
@ -1204,6 +1204,61 @@ TEST_F(GraphPropertiesTest, FunctionReturnTensorValue) {
|
||||
properties.GetInputProperties("MyFunc")[0].value());
|
||||
}
|
||||
|
||||
TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValue) {
|
||||
FunctionDefLibrary library;
|
||||
// Function that adds two input values.
|
||||
*library.add_function() = FunctionDefHelper::Create(
|
||||
"MyFunc", // Name
|
||||
{"x: int32", "y: int32"}, // Inputs
|
||||
{"out: int32"}, // Outputs
|
||||
{}, // Attrs
|
||||
{{{"a"}, "Add", {"x", "y"}, {{"T", DataType::DT_INT32}}}}, // Nodes
|
||||
{{"out", "a:z:0"}}); // Returns
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
TF_CHECK_OK(s.graph()->AddFunctionLibrary(library));
|
||||
|
||||
Output shape = ops::Const(s.WithOpName("shape"), {5, 7}, {2});
|
||||
auto _shape = tensorflow::ops::AsNodeOut(s, shape);
|
||||
auto builder =
|
||||
tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
|
||||
tensorflow::Node* func_op;
|
||||
TF_CHECK_OK(
|
||||
builder.Input(_shape).Input(_shape).Finalize(s.graph(), &func_op));
|
||||
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
{
|
||||
GraphProperties properties(item);
|
||||
// Without aggressive_shape_inference, the internal function does not
|
||||
// evaluate output value.
|
||||
TF_CHECK_OK(properties.InferStatically(
|
||||
/*assume_valid_feeds=*/true,
|
||||
/*aggressive_shape_inference=*/false));
|
||||
const auto out_props = properties.GetOutputProperties("MyFunc");
|
||||
const OpInfo::TensorProperties out_prop0 = out_props[0];
|
||||
EXPECT_EQ("int32: [2]", PropToString(out_prop0));
|
||||
EXPECT_FALSE(out_prop0.has_value());
|
||||
}
|
||||
|
||||
{
|
||||
GraphProperties properties(item);
|
||||
// With aggressive_shape_inference, output value is evaluated.
|
||||
TF_CHECK_OK(properties.InferStatically(
|
||||
/*assume_valid_feeds=*/true,
|
||||
/*aggressive_shape_inference=*/true));
|
||||
const auto out_props = properties.GetOutputProperties("MyFunc");
|
||||
const OpInfo::TensorProperties out_prop0 = out_props[0];
|
||||
EXPECT_EQ("int32: [2]", PropToString(out_prop0));
|
||||
EXPECT_TRUE(out_prop0.has_value());
|
||||
|
||||
ExpectTensorValues({10, 14}, out_prop0.value());
|
||||
ExpectTensorValues({5, 7},
|
||||
properties.GetInputProperties("MyFunc")[0].value());
|
||||
ExpectTensorValues({5, 7},
|
||||
properties.GetInputProperties("MyFunc")[1].value());
|
||||
}
|
||||
}
|
||||
|
||||
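The values asserted in the second block above are just elementwise arithmetic over the constant feed; for the record (illustration only, not part of the diff):

import tensorflow as tf
print(tf.add([5, 7], [5, 7]))  # [10 14], the tensor value aggressive inference materializes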
TEST_F(GraphPropertiesTest, FunctionWithScalarInput) {
|
||||
// Create graph with a function that takes a scalar value so that we use
|
||||
// Placeholder with scalar as for input to the function shape inference.
|
||||
|
@ -39,6 +39,7 @@ constexpr char kDepthwiseConv2dNativeBackpropInput[] =
|
||||
"DepthwiseConv2dNativeBackpropInput";
|
||||
constexpr char kMatMul[] = "MatMul";
|
||||
constexpr char kSparseMatMul[] = "SparseMatMul";
|
||||
constexpr char kSparseTensorDenseMatMul[] = "SparseTensorDenseMatMul";
|
||||
constexpr char kPlaceholder[] = "Placeholder";
|
||||
constexpr char kIdentity[] = "Identity";
|
||||
constexpr char kIdentityN[] = "IdentityN";
|
||||
@ -243,6 +244,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
|
||||
wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)},
|
||||
{kMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
|
||||
{kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
|
||||
{kSparseTensorDenseMatMul,
|
||||
wrap(&OpLevelCostEstimator::PredictSparseTensorDenseMatMul)},
|
||||
{kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)},
|
||||
{kQuantizedMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
|
||||
{kQuantizedMatMulV2, wrap(&OpLevelCostEstimator::PredictMatMul)},
|
||||
@ -1228,6 +1231,49 @@ Costs OpLevelCostEstimator::PredictMatMul(const OpContext& op_context) const {
|
||||
return costs;
|
||||
}
|
||||
|
||||
Costs OpLevelCostEstimator::PredictSparseTensorDenseMatMul(
|
||||
const OpContext& op_context) const {
|
||||
const auto& op_info = op_context.op_info;
|
||||
bool found_unknown_shapes = false;
|
||||
// input[0]: indices in sparse matrix a
|
||||
// input[1]: values in sparse matrix a
|
||||
// input[2]: shape of matrix a
|
||||
// input[3]: matrix b
|
||||
// See
|
||||
// https://github.com/tensorflow/tensorflow/blob/9a43dfeac5/tensorflow/core/ops/sparse_ops.cc#L85
|
||||
int64 num_elems_in_a =
|
||||
CalculateTensorElementCount(op_info.inputs(1), &found_unknown_shapes);
|
||||
auto b_matrix = op_info.inputs(3);
|
||||
auto b_matrix_shape =
|
||||
MaybeGetMinimumShape(b_matrix.shape(), 2, &found_unknown_shapes);
|
||||
int64 n_dim = b_matrix_shape.dim(1).size();
|
||||
|
||||
// Each element in A is multiplied and added with an element from each column
|
||||
// in b.
|
||||
const int64 op_count = kOpsPerMac * num_elems_in_a * n_dim;
|
||||
|
||||
int64 a_indices_input_size =
|
||||
CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
|
||||
int64 a_values_input_size =
|
||||
CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
|
||||
int64 a_shape_input_size =
|
||||
CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
|
||||
int64 b_input_size =
|
||||
num_elems_in_a * n_dim * DataTypeSize(BaseType(b_matrix.dtype()));
|
||||
double input_size = a_indices_input_size + a_values_input_size +
|
||||
a_shape_input_size + b_input_size;
|
||||
|
||||
double output_size = CalculateOutputSize(op_info, &found_unknown_shapes);
|
||||
|
||||
auto costs =
|
||||
PredictOpCountBasedCost(op_count, input_size, output_size, op_info);
|
||||
costs.inaccurate = found_unknown_shapes;
|
||||
costs.num_ops_with_unknown_shapes = found_unknown_shapes;
|
||||
costs.max_memory = output_size;
|
||||
|
||||
return costs;
|
||||
}
|
||||
|
||||
Costs OpLevelCostEstimator::PredictNoOp(const OpContext& op_context) const {
|
||||
const auto& op_info = op_context.op_info;
|
||||
VLOG(1) << "Op:" << op_info.op() << " Execution Time 0 (ns)";
|
||||
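A hedged back-of-the-envelope version of PredictSparseTensorDenseMatMul in Python; OPS_PER_MAC and the float32 element size are assumptions, only the scaling is lifted from the code above.

OPS_PER_MAC = 2  # assumption: one multiply plus one add per multiply-accumulate

def sparse_dense_matmul_estimate(nnz_a, b_cols, dtype_bytes=4):
  # Compute work scales with nnz(A) * cols(B); the bytes of B counted as read are
  # approximated the same way instead of charging for all of B.
  op_count = OPS_PER_MAC * nnz_a * b_cols
  b_bytes_touched = nnz_a * b_cols * dtype_bytes
  return op_count, b_bytes_touched

# Consistent with the test below: growing B's row count alone leaves the estimate unchanged.
print(sparse_dense_matmul_estimate(10, 100))  # (2000, 4000)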
|
@ -132,6 +132,7 @@ class OpLevelCostEstimator {
|
||||
Costs PredictConv2DBackpropFilter(const OpContext& op_context) const;
|
||||
Costs PredictFusedConv2DBiasActivation(const OpContext& op_context) const;
|
||||
Costs PredictMatMul(const OpContext& op_context) const;
|
||||
Costs PredictSparseTensorDenseMatMul(const OpContext& op_context) const;
|
||||
Costs PredictNoOp(const OpContext& op_context) const;
|
||||
Costs PredictIdentity(const OpContext& op_context) const;
|
||||
Costs PredictVariable(const OpContext& op_context) const;
|
||||
|
@ -119,6 +119,22 @@ OpContext DescribeBatchMatMul(const std::vector<int>& dims_a,
|
||||
return op_context;
|
||||
}
|
||||
|
||||
// Returns an OpInfo for a SparseTensorDenseMatMul
|
||||
OpContext DescribeSparseTensorDenseMatMul(const int nnz_a,
|
||||
const std::vector<int>& dims_b,
|
||||
const std::vector<int>& dims_out) {
|
||||
OpContext op_context;
|
||||
SetCpuDevice(&op_context.op_info);
|
||||
op_context.op_info.set_op("SparseTensorDenseMatMul");
|
||||
|
||||
DescribeArbitraryRankInput({nnz_a, 2}, DT_INT64, &op_context.op_info);
|
||||
DescribeArbitraryRankInput({nnz_a}, DT_FLOAT, &op_context.op_info);
|
||||
DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
|
||||
DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
|
||||
DescribeArbitraryRankOutput(dims_out, DT_FLOAT, &op_context.op_info);
|
||||
return op_context;
|
||||
}
|
||||
|
||||
// Wrangles the minimum number of proto fields to set up a 1D Tensor for cost
|
||||
// estimation purposes.
|
||||
void DescribeTensor1D(int dim0, OpInfo::TensorProperties* tensor) {
|
||||
@ -854,6 +870,58 @@ TEST_F(OpLevelCostEstimatorTest, BatchMatMul) {
|
||||
EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
|
||||
}
|
||||
|
||||
TEST_F(OpLevelCostEstimatorTest, SparseTensorDenseMatMul) {
|
||||
// Unknown shape cases
|
||||
{
|
||||
auto cost =
|
||||
PredictCosts(DescribeSparseTensorDenseMatMul(-1, {1, 1}, {1, 1}));
|
||||
EXPECT_EQ(1, cost.num_ops_total);
|
||||
EXPECT_TRUE(cost.inaccurate);
|
||||
EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
|
||||
}
|
||||
{
|
||||
auto cost =
|
||||
PredictCosts(DescribeSparseTensorDenseMatMul(1, {-1, 1}, {1, 1}));
|
||||
EXPECT_EQ(1, cost.num_ops_total);
|
||||
EXPECT_TRUE(cost.inaccurate);
|
||||
EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
|
||||
}
|
||||
{
|
||||
auto cost =
|
||||
PredictCosts(DescribeSparseTensorDenseMatMul(1, {1, -1}, {1, -1}));
|
||||
EXPECT_EQ(1, cost.num_ops_total);
|
||||
EXPECT_TRUE(cost.inaccurate);
|
||||
EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
|
||||
}
|
||||
{
|
||||
auto cost =
|
||||
PredictCosts(DescribeSparseTensorDenseMatMul(1, {1, 1}, {-1, 1}));
|
||||
EXPECT_EQ(1, cost.num_ops_total);
|
||||
EXPECT_TRUE(cost.inaccurate);
|
||||
EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
|
||||
}
|
||||
// Known shape cases
|
||||
{
|
||||
auto cost = PredictCosts(
|
||||
DescribeSparseTensorDenseMatMul(10, {1000, 100}, {50, 100}));
|
||||
EXPECT_EQ(1, cost.num_ops_total);
|
||||
EXPECT_FALSE(cost.inaccurate);
|
||||
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
|
||||
EXPECT_EQ(Costs::Duration(200), cost.compute_time);
|
||||
EXPECT_EQ(Costs::Duration(2422), cost.memory_time);
|
||||
}
|
||||
{
|
||||
// Same cost as above case because cost does not depend on k_dim
|
||||
auto cost = PredictCosts(
|
||||
DescribeSparseTensorDenseMatMul(10, {100000, 100}, {50, 100}));
|
||||
EXPECT_EQ(1, cost.num_ops_total);
|
||||
EXPECT_FALSE(cost.inaccurate);
|
||||
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
|
||||
EXPECT_EQ(Costs::Duration(200), cost.compute_time);
|
||||
EXPECT_EQ(Costs::Duration(2422), cost.memory_time);
|
||||
}
|
||||
}
|
||||
|
||||
void ExpectTensorShape(const std::vector<int64>& expected,
|
||||
const TensorShapeProto& tensor_shape_proto) {
|
||||
TensorShape tensor_shape_expected(expected);
|
||||
|
@ -38,7 +38,7 @@ namespace grappler {
|
||||
// b = Placeholder(..)
|
||||
// c = AddN([a, a, b])
|
||||
//
|
||||
// GraphView edges: [a:0 -> c:0, a:0 -> c:1, b:0 -> c:3]
|
||||
// GraphView edges: [a:0 -> c:0, a:0 -> c:1, b:0 -> c:2]
|
||||
// GraphTopologyView edges: [a -> c, b -> c]
|
||||
//
|
||||
// GraphView is used for exploring single node fanins and fanouts, and
|
||||
|
@ -499,12 +499,12 @@ inline int64 ConvolveScratchSize() {
|
||||
// convolution on the stream) and parameters, by running all possible
|
||||
// algorithms and measuring execution time.
|
||||
// TODO(ezhulenev): Move it to conv_ops_gpu.h and share with conv_ops.cc.
|
||||
template <typename T, typename ConvLaunch>
|
||||
Status FindBestConvolveAlgorithm(
|
||||
const FusedConvParameters& params, const ConvLaunch launch,
|
||||
OpKernelContext* context, se::Stream* stream,
|
||||
se::dnn::AlgorithmConfig* algorithm_config,
|
||||
std::vector<tensorflow::AutotuneResult>* results) {
|
||||
template <typename T, typename ConvLaunch, typename LogFunc>
|
||||
Status FindBestConvolveAlgorithm(const FusedConvParameters& params,
|
||||
const ConvLaunch launch,
|
||||
OpKernelContext* context, se::Stream* stream,
|
||||
const LogFunc& log,
|
||||
se::dnn::AlgorithmConfig* algorithm_config) {
|
||||
// Check if we already have an algorithm selected for the given parameters.
|
||||
if (AutoTuneFusedConv::GetInstance()->Find(params, algorithm_config)) {
|
||||
return Status::OK();
|
||||
@ -521,6 +521,7 @@ Status FindBestConvolveAlgorithm(
|
||||
"see if a warning log message was printed above.");
|
||||
}
|
||||
|
||||
std::vector<tensorflow::AutotuneResult> results;
|
||||
for (auto profile_algorithm : algorithms) {
|
||||
DnnScratchAllocator scratch_allocator(ConvolveScratchSize(), context);
|
||||
se::dnn::ProfileResult profile_result;
|
||||
@ -530,8 +531,8 @@ Status FindBestConvolveAlgorithm(
|
||||
&profile_result);
|
||||
|
||||
if (cudnn_launch_status && profile_result.is_valid()) {
|
||||
results->emplace_back();
|
||||
auto& result = results->back();
|
||||
results.emplace_back();
|
||||
auto& result = results.back();
|
||||
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
|
||||
result.mutable_conv()->set_tensor_ops_enabled(
|
||||
profile_algorithm.tensor_ops_enabled());
|
||||
@ -542,7 +543,9 @@ Status FindBestConvolveAlgorithm(
|
||||
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(BestCudnnConvAlgorithm(*results, algorithm_config));
|
||||
// Only log on an AutoTuneFusedConv cache miss.
|
||||
log(results);
|
||||
TF_RETURN_IF_ERROR(BestCudnnConvAlgorithm(results, algorithm_config));
|
||||
AutoTuneFusedConv::GetInstance()->Insert(params, *algorithm_config);
|
||||
return Status::OK();
|
||||
}
|
||||
@ -789,13 +792,14 @@ struct LaunchFusedConv2DOp<GPUDevice, T> {
|
||||
|
||||
se::dnn::AlgorithmConfig algorithm_config;
|
||||
if (cudnn_use_autotune) {
|
||||
std::vector<tensorflow::AutotuneResult> results;
|
||||
auto status =
|
||||
FindBestConvolveAlgorithm<T>(conv_parameters, launch, context, stream,
|
||||
&algorithm_config, &results);
|
||||
LogFusedConvAutotuneResults(context->op_kernel().def(), input,
|
||||
transformed_filter, transformed_output, bias,
|
||||
nullptr, stream->parent(), results);
|
||||
auto status = FindBestConvolveAlgorithm<T>(
|
||||
conv_parameters, launch, context, stream,
|
||||
[&](absl::Span<const tensorflow::AutotuneResult> results) {
|
||||
LogFusedConvAutotuneResults(
|
||||
context->op_kernel().def(), input, transformed_filter,
|
||||
transformed_output, bias, nullptr, stream->parent(), results);
|
||||
},
|
||||
&algorithm_config);
|
||||
OP_REQUIRES_OK(context, status);
|
||||
}
|
||||
|
||||
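Stripped of the cuDNN specifics, the control flow after this change is roughly the following hedged Python sketch: profile every candidate, invoke the log callback only on a cache miss, then cache the winner.

_autotune_cache = {}

def find_best_algorithm(params, algorithms, run, log):
  if params in _autotune_cache:
    return _autotune_cache[params]                       # cache hit: nothing is logged
  results = [(algo, run(algo)) for algo in algorithms]   # profile each candidate once
  log(results)                                           # logged exactly once, on the miss
  best_algo = min(results, key=lambda r: r[1])[0]
  _autotune_cache[params] = best_algo
  return best_algo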
|
@ -24,7 +24,8 @@ REGISTER5(BinaryOp, CPU, "TruncateDiv", functor::safe_div, uint8, uint16, int16,
|
||||
int32, int64);
|
||||
REGISTER6(BinaryOp, CPU, "RealDiv", functor::div, float, Eigen::half, double,
|
||||
bfloat16, complex64, complex128);
|
||||
REGISTER2(BinaryOp, CPU, "DivNoNan", functor::div_no_nan, float, double);
|
||||
REGISTER5(BinaryOp, CPU, "DivNoNan", functor::div_no_nan, Eigen::half, float,
|
||||
double, complex64, complex128);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
|
||||
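What these registrations expose at the Python level, as a quick illustration (not part of the diff): div_no_nan returns zero wherever the divisor is zero instead of producing inf or nan.

import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0])
y = tf.constant([2.0, 0.0, 0.0])
print(tf.math.div_no_nan(x, y))  # [0.5, 0.0, 0.0]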
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
DEFINE_UNARY2(tan, float, double);
|
||||
DEFINE_UNARY3(tan, Eigen::half, float, double);
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -19,7 +19,8 @@ namespace tensorflow {
|
||||
|
||||
REGISTER6(BinaryOp, CPU, "Mul", functor::mul, float, Eigen::half, double, uint8,
|
||||
int32, bfloat16);
|
||||
REGISTER2(BinaryOp, CPU, "MulNoNan", functor::mul_no_nan, float, double);
|
||||
REGISTER5(BinaryOp, CPU, "MulNoNan", functor::mul_no_nan, Eigen::half, float,
|
||||
double, complex64, complex128);
|
||||
|
||||
#if defined(__ANDROID_TYPES_SLIM__)
|
||||
// We only register the first type when we have multi-argument calls in the
|
||||
|
@ -16,11 +16,11 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
REGISTER4(UnaryOp, CPU, "Tan", functor::tan, float, double, complex64,
|
||||
complex128);
|
||||
REGISTER5(UnaryOp, CPU, "Tan", functor::tan, Eigen::half, float, double,
|
||||
complex64, complex128);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER2(UnaryOp, GPU, "Tan", functor::tan, float, double);
|
||||
REGISTER3(UnaryOp, GPU, "Tan", functor::tan, Eigen::half, float, double);
|
||||
#endif
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
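With the Eigen::half kernel registered above, a half-precision tangent no longer needs a cast on CPU; a hedged one-liner, assuming a build that includes this change:

import tensorflow as tf
print(tf.math.tan(tf.constant([0.5], dtype=tf.float16)))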
|
@ -540,12 +540,32 @@ tf_kernel_library(
|
||||
name = "tensor_dataset_op",
|
||||
srcs = ["tensor_dataset_op.cc"],
|
||||
deps = [
|
||||
":dataset_utils",
|
||||
"//tensorflow/core:dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "tensor_dataset_op_test",
|
||||
size = "small",
|
||||
srcs = ["tensor_dataset_op_test.cc"],
|
||||
deps = [
|
||||
":dataset_test_base",
|
||||
":dataset_utils",
|
||||
":iterator_ops",
|
||||
":tensor_dataset_op",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "tensor_slice_dataset_op",
|
||||
srcs = ["tensor_slice_dataset_op.cc"],
|
||||
|
@ -446,18 +446,26 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel {
|
||||
return s;
|
||||
}
|
||||
|
||||
// Select the fastest input to use based on the histograms of timings
|
||||
// of the completed iterations. The input with the best 90th percentile
|
||||
// iteration time is selected.
|
||||
void SelectFastestInputIndex() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
fastest_index_ = 0;
|
||||
|
||||
VLOG(2) << "90.0 percentile iteration time:";
|
||||
double best_percentile = histograms_[0].Percentile(kPercentile);
|
||||
VLOG(2) << "Branch 0: " << best_percentile;
|
||||
for (size_t i = 1, num_inputs = histograms_.size(); i < num_inputs;
|
||||
++i) {
|
||||
double percentile = histograms_[i].Percentile(kPercentile);
|
||||
VLOG(2) << "Branch " << i << ": " << percentile;
|
||||
if (percentile <= best_percentile) {
|
||||
best_percentile = percentile;
|
||||
fastest_index_ = i;
|
||||
}
|
||||
}
|
||||
VLOG(1) << "Selecting index " << fastest_index_
|
||||
<< " as the fastest index.";
|
||||
}
|
||||
|
||||
Status MakeCurrentIterator(IteratorContext* ctx, int64 branch_index,
|
||||
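The selection rule spelled out in the comment above, as a hedged Python sketch; a nearest-rank percentile stands in for the histogram's Percentile() call.

def select_fastest_index(times_per_branch, percentile=90.0):
  def pct(times, p):
    ordered = sorted(times)
    return ordered[int(round(p / 100.0 * (len(ordered) - 1)))]

  fastest_index = 0
  best = pct(times_per_branch[0], percentile)
  for i, times in enumerate(times_per_branch[1:], start=1):
    p = pct(times, percentile)
    if p <= best:  # '<=' as in the C++ loop, so later branches win ties
      best, fastest_index = p, i
  return fastest_index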
|
@ -325,15 +325,20 @@ class ChooseFastestDatasetOp : public DatasetOpKernel {
|
||||
void SelectFastestInputIndex() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
fastest_index_ = 0;
|
||||
|
||||
VLOG(2) << "90.0 percentile iteration time:";
|
||||
double best_percentile = histograms_[0].Percentile(kPercentile);
|
||||
VLOG(2) << "Branch 0: " << best_percentile;
|
||||
for (size_t i = 1, num_inputs = histograms_.size(); i < num_inputs;
|
||||
++i) {
|
||||
double percentile = histograms_[i].Percentile(kPercentile);
|
||||
VLOG(2) << "Branch " << i << ": " << percentile;
|
||||
if (percentile <= best_percentile) {
|
||||
best_percentile = percentile;
|
||||
fastest_index_ = i;
|
||||
}
|
||||
}
|
||||
VLOG(1) << "Selecting index " << fastest_index_
|
||||
<< " as the fastest index.";
|
||||
|
||||
fastest_input_impl_ = std::move(input_impls_[fastest_index_]);
|
||||
input_impls_.clear(); // Delete the unused iterators.
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
@ -26,15 +27,20 @@ namespace {
|
||||
|
||||
class TensorDatasetOp : public DatasetOpKernel {
|
||||
public:
|
||||
explicit TensorDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
|
||||
explicit TensorDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutput_types", &output_types_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
||||
}
|
||||
|
||||
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
|
||||
OpInputList inputs;
|
||||
OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs));
|
||||
// TODO(mrry): Validate that the shapes of the "components" tensors match
|
||||
// the "shapes" attr.;
|
||||
std::vector<Tensor> components(inputs.begin(), inputs.end());
|
||||
*output = new Dataset(ctx, std::move(components));
|
||||
OP_REQUIRES_OK(ctx,
|
||||
VerifyTypesMatch((*output)->output_dtypes(), output_types_));
|
||||
OP_REQUIRES_OK(ctx, VerifyShapesCompatible((*output)->output_shapes(),
|
||||
output_shapes_));
|
||||
}
|
||||
|
||||
private:
|
||||
@ -137,6 +143,9 @@ class TensorDatasetOp : public DatasetOpKernel {
|
||||
DataTypeVector dtypes_;
|
||||
std::vector<PartialTensorShape> shapes_;
|
||||
};
|
||||
|
||||
DataTypeVector output_types_;
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TensorDataset").Device(DEVICE_CPU),
|
||||
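The attrs checked above are supplied by the Python wrapper for Dataset.from_tensors; a hedged usage sketch (element_spec is the TF 2 spelling of that dtype/shape metadata):

import tensorflow as tf

ds = tf.data.Dataset.from_tensors((tf.constant([1, 2, 3]), tf.constant('a')))
print(ds.element_spec)
# (TensorSpec(shape=(3,), dtype=tf.int32, ...), TensorSpec(shape=(), dtype=tf.string, ...))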
|
535
tensorflow/core/kernels/data/tensor_dataset_op_test.cc
Normal file
@ -0,0 +1,535 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/variant.h"
|
||||
#include "tensorflow/core/framework/variant_tensor_data.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_test_base.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
#include "tensorflow/core/kernels/data/iterator_ops.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
constexpr char kNodeName[] = "tensor_dataset";
|
||||
constexpr char kOpName[] = "TensorDataset";
|
||||
|
||||
class TensorDatasetOpTest : public DatasetOpsTestBase {
|
||||
protected:
|
||||
// Creates a new TensorDataset op kernel.
|
||||
Status CreateTensorDatasetKernel(
|
||||
DataTypeVector dtypes, std::vector<PartialTensorShape> shapes,
|
||||
std::unique_ptr<OpKernel> *tensor_dataset_kernel) {
|
||||
std::vector<string> components;
|
||||
components.reserve(dtypes.size());
|
||||
for (int i = 0; i < dtypes.size(); i++) {
|
||||
components.emplace_back(strings::StrCat("component_", i));
|
||||
}
|
||||
node_def_ = test::function::NDef(
|
||||
kNodeName, kOpName, components,
|
||||
{{"Toutput_types", dtypes}, {"output_shapes", shapes}});
|
||||
TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, tensor_dataset_kernel));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Creates a new TensorDataset op kernel context.
|
||||
Status CreateTensorDatasetContext(OpKernel *const tensor_dataset_kernel,
|
||||
gtl::InlinedVector<TensorValue, 4> *inputs,
|
||||
std::unique_ptr<OpKernelContext> *context) {
|
||||
TF_RETURN_IF_ERROR(CheckOpKernelInput(*tensor_dataset_kernel, *inputs));
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateOpKernelContext(tensor_dataset_kernel, inputs, context));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
NodeDef node_def_;
|
||||
};
|
||||
|
||||
struct TestCase {
|
||||
std::vector<Tensor> components;
|
||||
std::vector<Tensor> expected_outputs;
|
||||
DataTypeVector expected_output_dtypes;
|
||||
std::vector<PartialTensorShape> expected_output_shapes;
|
||||
int64 expected_cardinality;
|
||||
std::vector<int> breakpoints;
|
||||
};
|
||||
|
||||
// Test case 1: test a dataset that represents a single tuple of plain tensors.
|
||||
TestCase PlainTensorsTestCase() {
|
||||
return {
|
||||
/*components*/
|
||||
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
|
||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({1, 3}), {1, 2, 3}),
|
||||
DatasetOpsTestBase::CreateTensor<double>(TensorShape({}), {37.0}),
|
||||
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1, 2}),
|
||||
{"a", "b"})},
|
||||
/*expected_outputs*/
|
||||
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
|
||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({1, 3}), {1, 2, 3}),
|
||||
DatasetOpsTestBase::CreateTensor<double>(TensorShape({}), {37.0}),
|
||||
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1, 2}),
|
||||
{"a", "b"})},
|
||||
/*expected_output_dtypes*/
|
||||
{DT_INT64, DT_INT64, DT_DOUBLE, DT_STRING},
|
||||
/*expected_output_shapes*/
|
||||
{PartialTensorShape({}), PartialTensorShape({1, 3}),
|
||||
PartialTensorShape({}), PartialTensorShape({1, 2})},
|
||||
/*expected_cardinality*/ 1,
|
||||
/*breakpoints*/ {0, 1, 2}};
|
||||
}
|
||||
|
||||
// Test case 2: test a dataset that represents a tuple of nested tensors.
|
||||
TestCase NestedTensorsTestCase() {
|
||||
return {
|
||||
/*components*/
|
||||
{DatasetOpsTestBase::CreateTensor<Variant>(
|
||||
TensorShape({}), {DatasetOpsTestBase::CreateTensor<double>(
|
||||
TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
|
||||
DatasetOpsTestBase::CreateTensor<Variant>(
|
||||
TensorShape({}), {DatasetOpsTestBase::CreateTensor<string>(
|
||||
TensorShape({1, 2}), {"a", "b"})}),
|
||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({1, 3}), {1, 2, 3})},
|
||||
/*expected_outputs*/
|
||||
{DatasetOpsTestBase::CreateTensor<Variant>(
|
||||
TensorShape({}), {DatasetOpsTestBase::CreateTensor<double>(
|
||||
TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
|
||||
DatasetOpsTestBase::CreateTensor<Variant>(
|
||||
TensorShape({}), {DatasetOpsTestBase::CreateTensor<string>(
|
||||
TensorShape({1, 2}), {"a", "b"})}),
|
||||
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({1, 3}), {1, 2, 3})},
|
||||
/*expected_output_dtypes*/
|
||||
{DT_VARIANT, DT_VARIANT, DT_INT64},
|
||||
/*expected_output_shapes*/
|
||||
{PartialTensorShape({}), PartialTensorShape({}),
|
||||
PartialTensorShape({1, 3})},
|
||||
/*expected_cardinality*/ 1,
|
||||
/*breakpoints*/ {0, 1, 2}};
|
||||
}
|
||||
|
||||
class ParametrizedTensorDatasetOpTest
|
||||
: public TensorDatasetOpTest,
|
||||
public ::testing::WithParamInterface<TestCase> {};
|
||||
|
||||
TEST_P(ParametrizedTensorDatasetOpTest, GetNext) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
const TestCase &test_case = GetParam();
|
||||
std::vector<Tensor> components = test_case.components;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
for (auto &component : components) {
|
||||
inputs.push_back(&component);
|
||||
}
|
||||
std::unique_ptr<OpKernel> tensor_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&tensor_dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> tensor_dataset_context;
|
||||
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
|
||||
&tensor_dataset_context));
|
||||
DatasetBase *tensor_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
|
||||
tensor_dataset_context.get(), &tensor_dataset));
|
||||
core::ScopedUnref scoped_unref(tensor_dataset);
|
||||
|
||||
std::unique_ptr<IteratorContext> iterator_context;
|
||||
TF_ASSERT_OK(
|
||||
CreateIteratorContext(tensor_dataset_context.get(), &iterator_context));
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
TF_ASSERT_OK(tensor_dataset->MakeIterator(iterator_context.get(), "Iterator",
|
||||
&iterator));
|
||||
bool end_of_sequence = false;
|
||||
std::vector<Tensor> out_tensors;
|
||||
while (!end_of_sequence) {
|
||||
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
|
||||
&end_of_sequence));
|
||||
}
|
||||
EXPECT_EQ(out_tensors.size(), test_case.expected_outputs.size());
|
||||
for (int i = 0; i < out_tensors.size(); ++i) {
|
||||
if (out_tensors[i].dtype() == DT_VARIANT) {
|
||||
// Currently `ExpectEqual()` does not support the variant tensor
|
||||
// yet, so we manually cast the variant to numeric/string tensor.
|
||||
const Tensor *output = out_tensors[i].scalar<Variant>()().get<Tensor>();
|
||||
const Tensor *expected_output =
|
||||
test_case.expected_outputs[i].scalar<Variant>()().get<Tensor>();
|
||||
TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
|
||||
} else {
|
||||
TF_EXPECT_OK(ExpectEqual(out_tensors[i], test_case.expected_outputs[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TensorDatasetOpTest, DatasetTypeString) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
const TestCase &test_case = PlainTensorsTestCase();
|
||||
std::vector<Tensor> components = test_case.components;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
for (auto &component : components) {
|
||||
inputs.push_back(&component);
|
||||
}
|
||||
std::unique_ptr<OpKernel> tensor_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&tensor_dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> tensor_dataset_context;
|
||||
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
|
||||
&tensor_dataset_context));
|
||||
DatasetBase *tensor_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
|
||||
tensor_dataset_context.get(), &tensor_dataset));
|
||||
core::ScopedUnref scoped_unref(tensor_dataset);
|
||||
|
||||
EXPECT_EQ(tensor_dataset->type_string(), kOpName);
|
||||
}
|
||||
|
||||
TEST_F(TensorDatasetOpTest, DatasetNodeName) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
const TestCase &test_case = PlainTensorsTestCase();
|
||||
std::vector<Tensor> components = test_case.components;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
for (auto &component : components) {
|
||||
inputs.push_back(&component);
|
||||
}
|
||||
std::unique_ptr<OpKernel> tensor_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&tensor_dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> tensor_dataset_context;
|
||||
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
|
||||
&tensor_dataset_context));
|
||||
DatasetBase *tensor_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
|
||||
tensor_dataset_context.get(), &tensor_dataset));
|
||||
core::ScopedUnref scoped_unref(tensor_dataset);
|
||||
|
||||
EXPECT_EQ(tensor_dataset->node_name(), kNodeName);
|
||||
}
|
||||
|
||||
TEST_F(TensorDatasetOpTest, DatasetOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
const TestCase &test_case = PlainTensorsTestCase();
|
||||
std::vector<Tensor> components = test_case.components;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
for (auto &component : components) {
|
||||
inputs.push_back(&component);
|
||||
}
|
||||
std::unique_ptr<OpKernel> tensor_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&tensor_dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> tensor_dataset_context;
|
||||
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
|
||||
&tensor_dataset_context));
|
||||
DatasetBase *tensor_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
|
||||
tensor_dataset_context.get(), &tensor_dataset));
|
||||
core::ScopedUnref scoped_unref(tensor_dataset);
|
||||
|
||||
EXPECT_EQ(tensor_dataset->output_dtypes(), test_case.expected_output_dtypes);
|
||||
}
|
||||
|
||||
TEST_F(TensorDatasetOpTest, DatasetOutputShapes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
const TestCase &test_case = PlainTensorsTestCase();
|
||||
std::vector<Tensor> components = test_case.components;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
for (auto &component : components) {
|
||||
inputs.push_back(&component);
|
||||
}
|
||||
std::unique_ptr<OpKernel> tensor_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&tensor_dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> tensor_dataset_context;
|
||||
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
|
||||
&tensor_dataset_context));
|
||||
DatasetBase *tensor_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
|
||||
tensor_dataset_context.get(), &tensor_dataset));
|
||||
core::ScopedUnref scoped_unref(tensor_dataset);
|
||||
|
||||
EXPECT_EQ(tensor_dataset->output_shapes().size(),
|
||||
test_case.expected_output_shapes.size());
|
||||
for (int i = 0; i < test_case.expected_output_shapes.size(); i++) {
|
||||
EXPECT_TRUE(test_case.expected_output_shapes[i].IsIdenticalTo(
|
||||
tensor_dataset->output_shapes()[i]));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TensorDatasetOpTest, Cardinality) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
const TestCase &test_case = PlainTensorsTestCase();
|
||||
std::vector<Tensor> components = test_case.components;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
for (auto &component : components) {
|
||||
inputs.push_back(&component);
|
||||
}
|
||||
std::unique_ptr<OpKernel> tensor_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&tensor_dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> tensor_dataset_context;
|
||||
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
|
||||
&tensor_dataset_context));
|
||||
DatasetBase *tensor_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
|
||||
tensor_dataset_context.get(), &tensor_dataset));
|
||||
core::ScopedUnref scoped_unref(tensor_dataset);
|
||||
|
||||
EXPECT_EQ(tensor_dataset->Cardinality(), test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_P(ParametrizedTensorDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
const TestCase &test_case = GetParam();
|
||||
std::vector<Tensor> components = test_case.components;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
for (auto &component : components) {
|
||||
inputs.push_back(&component);
|
||||
}
|
||||
std::unique_ptr<OpKernel> tensor_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&tensor_dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> tensor_dataset_context;
|
||||
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
|
||||
&tensor_dataset_context));
|
||||
DatasetBase *tensor_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
|
||||
tensor_dataset_context.get(), &tensor_dataset));
|
||||
core::ScopedUnref scoped_unref(tensor_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_context;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(tensor_dataset->Save(serialization_context.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParametrizedTensorDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
const TestCase &test_case = GetParam();
|
||||
std::vector<Tensor> components = test_case.components;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
for (auto &component : components) {
|
||||
inputs.push_back(&component);
|
||||
}
|
||||
std::unique_ptr<OpKernel> tensor_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&tensor_dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> tensor_dataset_context;
|
||||
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
|
||||
&tensor_dataset_context));
|
||||
DatasetBase *tensor_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
|
||||
tensor_dataset_context.get(), &tensor_dataset));
|
||||
core::ScopedUnref scoped_unref(tensor_dataset);
|
||||
|
||||
std::unique_ptr<IteratorContext> iterator_context;
|
||||
TF_ASSERT_OK(
|
||||
CreateIteratorContext(tensor_dataset_context.get(), &iterator_context));
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
TF_ASSERT_OK(tensor_dataset->MakeIterator(iterator_context.get(), "Iterator",
|
||||
&iterator));
|
||||
EXPECT_EQ(iterator->output_dtypes(), test_case.expected_output_dtypes);
|
||||
}
|
||||
|
||||
TEST_P(ParametrizedTensorDatasetOpTest, IteratorOutputShapes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
const TestCase &test_case = GetParam();
|
||||
std::vector<Tensor> components = test_case.components;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
for (auto &component : components) {
|
||||
inputs.push_back(&component);
|
||||
}
|
||||
std::unique_ptr<OpKernel> tensor_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&tensor_dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> tensor_dataset_context;
|
||||
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
|
||||
&tensor_dataset_context));
|
||||
DatasetBase *tensor_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
|
||||
tensor_dataset_context.get(), &tensor_dataset));
|
||||
core::ScopedUnref scoped_unref(tensor_dataset);
|
||||
|
||||
std::unique_ptr<IteratorContext> iterator_context;
|
||||
TF_ASSERT_OK(
|
||||
CreateIteratorContext(tensor_dataset_context.get(), &iterator_context));
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
TF_ASSERT_OK(tensor_dataset->MakeIterator(iterator_context.get(), "Iterator",
|
||||
&iterator));
|
||||
|
||||
EXPECT_EQ(iterator->output_shapes().size(),
|
||||
test_case.expected_output_shapes.size());
|
||||
for (int i = 0; i < test_case.expected_output_shapes.size(); ++i) {
|
||||
EXPECT_TRUE(test_case.expected_output_shapes[i].IsIdenticalTo(
|
||||
iterator->output_shapes()[i]));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TensorDatasetOpTest, IteratorOutputPrefix) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
const TestCase &test_case = PlainTensorsTestCase();
|
||||
std::vector<Tensor> components = test_case.components;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
for (auto &component : components) {
|
||||
inputs.push_back(&component);
|
||||
}
|
||||
std::unique_ptr<OpKernel> tensor_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&tensor_dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> tensor_dataset_context;
|
||||
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
|
||||
&tensor_dataset_context));
|
||||
DatasetBase *tensor_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
|
||||
tensor_dataset_context.get(), &tensor_dataset));
|
||||
core::ScopedUnref scoped_unref(tensor_dataset);
|
||||
|
||||
std::unique_ptr<IteratorContext> iterator_context;
|
||||
TF_ASSERT_OK(
|
||||
CreateIteratorContext(tensor_dataset_context.get(), &iterator_context));
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
TF_ASSERT_OK(tensor_dataset->MakeIterator(iterator_context.get(), "Iterator",
|
||||
&iterator));
|
||||
|
||||
EXPECT_EQ(iterator->prefix(), "Iterator::FromTensor");
|
||||
}
|
||||
|
||||
TEST_P(ParametrizedTensorDatasetOpTest, Roundtrip) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
const TestCase &test_case = GetParam();
|
||||
std::vector<Tensor> components = test_case.components;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
for (auto &component : components) {
|
||||
inputs.push_back(&component);
|
||||
}
|
||||
std::unique_ptr<OpKernel> tensor_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&tensor_dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> tensor_dataset_context;
|
||||
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
|
||||
&tensor_dataset_context));
|
||||
DatasetBase *tensor_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
|
||||
tensor_dataset_context.get(), &tensor_dataset));
|
||||
core::ScopedUnref scoped_unref(tensor_dataset);
|
||||
|
||||
std::unique_ptr<IteratorContext> iterator_ctx;
|
||||
TF_ASSERT_OK(
|
||||
CreateIteratorContext(tensor_dataset_context.get(), &iterator_ctx));
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
TF_ASSERT_OK(
|
||||
tensor_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
||||
bool end_of_sequence = false;
|
||||
std::vector<Tensor> out_tensors;
|
||||
int cur_iteration = 0;
|
||||
const std::vector<int> &breakpoints = test_case.breakpoints;
|
||||
for (int breakpoint : breakpoints) {
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
|
||||
TF_EXPECT_OK(writer.Flush());
|
||||
VariantTensorDataReader reader(&data);
|
||||
TF_EXPECT_OK(iterator->Restore(iterator_ctx.get(), &reader));
|
||||
|
||||
while (cur_iteration <= breakpoint) {
|
||||
TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors,
|
||||
&end_of_sequence));
|
||||
if (!end_of_sequence) {
|
||||
EXPECT_EQ(out_tensors.size(), test_case.expected_outputs.size());
|
||||
for (int i = 0; i < out_tensors.size(); ++i) {
|
||||
if (out_tensors[i].dtype() == DT_VARIANT) {
|
||||
// Currently `ExpectEqual()` does not support the variant tensor
|
||||
// yet, so we manually cast the variant to numeric/string tensor.
|
||||
const Tensor *output =
|
||||
out_tensors[i].scalar<Variant>()().get<Tensor>();
|
||||
const Tensor *expected_output =
|
||||
test_case.expected_outputs[i].scalar<Variant>()().get<Tensor>();
|
||||
TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
|
||||
} else {
|
||||
TF_EXPECT_OK(
|
||||
ExpectEqual(out_tensors[i], test_case.expected_outputs[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
cur_iteration++;
|
||||
}
|
||||
|
||||
if (breakpoint >= test_case.expected_cardinality) {
|
||||
EXPECT_TRUE(end_of_sequence);
|
||||
} else {
|
||||
EXPECT_FALSE(end_of_sequence);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
TensorDatasetOpTest, ParametrizedTensorDatasetOpTest,
|
||||
::testing::ValuesIn(std::vector<TestCase>({PlainTensorsTestCase(),
|
||||
NestedTensorsTestCase()})));
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
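For readers less familiar with the C++ test harness above: TensorDatasetOp backs tf.data.Dataset.from_tensors, and the Roundtrip test exercises saving and restoring the iterator state at a series of breakpoints. A minimal Python-level sketch of the kind of dataset the test constructs (assuming TF 2.x eager execution; the components here are illustrative, not the test's fixtures):

import tensorflow as tf

# from_tensors yields its components as a single element, hence the
# "Iterator::FromTensor" prefix checked by the test above.
dataset = tf.data.Dataset.from_tensors((tf.constant([1, 2, 3]), tf.constant("a")))
for element in dataset:
  print(element)  # exactly one element: the tuple of component tensors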
|
@ -82,10 +82,8 @@ Status ComputeSpansCore(OpKernelContext* context, const Kernel& kernel,
|
||||
const float col_f = x + 0.5f;
|
||||
const float sample_f = col_f * inv_scale + inv_translate;
|
||||
|
||||
// Don't sample when the sampling *kernel* is completely outside the
|
||||
// source image.
|
||||
if (sample_f < 0 - kernel.Radius() * kernel_scale ||
|
||||
sample_f > input_size + kernel.Radius() * kernel_scale) {
|
||||
// Don't sample when the sampling location is outside the source image.
|
||||
if (sample_f < 0 || sample_f > input_size) {
|
||||
// Add an empty span.
|
||||
starts_vec(x) = 0;
|
||||
continue;
|
||||
@ -169,11 +167,15 @@ Status ComputeGradSpansCore(OpKernelContext* context, const Spans& spans,
|
||||
auto grad_weights_vec = grad_spans->weights.vec<float>();
|
||||
grad_weights_vec.setZero();
|
||||
for (int input_index = 0; input_index < forward_input_size; ++input_index) {
|
||||
const int start_span = grad_components[input_index].front().index;
|
||||
grad_starts_vec(input_index) = start_span;
|
||||
for (const GradComponent& gc : grad_components[input_index]) {
|
||||
grad_weights_vec(input_index * grad_spans->span_size + gc.index -
|
||||
start_span) += gc.weight;
|
||||
if (!grad_components[input_index].empty()) {
|
||||
const int start_span = grad_components[input_index].front().index;
|
||||
grad_starts_vec(input_index) = start_span;
|
||||
for (const GradComponent& gc : grad_components[input_index]) {
|
||||
grad_weights_vec(input_index * grad_spans->span_size + gc.index -
|
||||
start_span) += gc.weight;
|
||||
}
|
||||
} else {
|
||||
grad_starts_vec(input_index) = 0;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -120,7 +120,8 @@ void Sample(const DynamicKernel& kernel, const bool antialias,
|
||||
1;
|
||||
|
||||
std::fill(dest, dest + channels, 0.0f);
|
||||
if (y_span_end <= y_span_start || x_span_end <= x_span_start) {
|
||||
if (sample_f.x() < 0.0f || sample_f.y() < 0.0f || sample_f.x() > in_width ||
|
||||
sample_f.y() > in_height) {
|
||||
return;
|
||||
}
|
||||
const Vector2f one_over_kernel_scale(1.0f / kernel_scale.x(),
|
||||
@ -170,6 +171,8 @@ void ScaleAndTranslateBaseline(const DynamicKernel& kernel,
|
||||
|
||||
const int64 out_height = output.dimension(1);
|
||||
const int64 out_width = output.dimension(2);
|
||||
const int64 in_height = images.dimension(1);
|
||||
const int64 in_width = images.dimension(2);
|
||||
|
||||
for (int b = 0; b < batch; ++b) {
|
||||
for (int64 y = 0; y < out_height; ++y) {
|
||||
@ -178,8 +181,13 @@ void ScaleAndTranslateBaseline(const DynamicKernel& kernel,
|
||||
for (int64 x = 0; x < out_width; ++x) {
|
||||
const float out_x_f = static_cast<float>(x) + 0.5;
|
||||
const float in_x_f = out_x_f * scale.x() + translate.x();
|
||||
Sample(kernel, antialias, images, b, scale, Vector2f(in_x_f, in_y_f),
|
||||
&output(b, y, x, 0));
|
||||
if (in_x_f < 0.0f || in_y_f < 0.0f || in_x_f > in_width ||
|
||||
in_y_f > in_height) {
|
||||
std::fill(&output(b, y, x, 0), &output(b, y, x + 1, 0), 0.0f);
|
||||
} else {
|
||||
Sample(kernel, antialias, images, b, scale, Vector2f(in_x_f, in_y_f),
|
||||
&output(b, y, x, 0));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -216,7 +216,7 @@ class StageOp : public OpKernel {
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_CPU), StageOp);
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_GPU), StageOp);
|
||||
#endif
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
@ -249,7 +249,7 @@ class UnstageOp : public OpKernel {
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_CPU), UnstageOp);
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_GPU), UnstageOp);
|
||||
#endif
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
@ -284,7 +284,7 @@ class StagePeekOp : public OpKernel {
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("StagePeek").Device(DEVICE_CPU), StagePeekOp);
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("StagePeek").HostMemory("index").Device(DEVICE_GPU), StagePeekOp);
|
||||
#endif
|
||||
@ -314,7 +314,7 @@ class StageSizeOp : public OpKernel {
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("StageSize").Device(DEVICE_CPU), StageSizeOp);
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
REGISTER_KERNEL_BUILDER(Name("StageSize").HostMemory("size").Device(DEVICE_GPU),
|
||||
StageSizeOp);
|
||||
#endif
|
||||
@ -339,7 +339,7 @@ class StageClearOp : public OpKernel {
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("StageClear").Device(DEVICE_CPU), StageClearOp);
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
REGISTER_KERNEL_BUILDER(Name("StageClear").Device(DEVICE_GPU), StageClearOp);
|
||||
#endif
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
|
@ -134,7 +134,7 @@ TF_CALL_half(REGISTER_CPU);
|
||||
TF_CALL_float(REGISTER_CPU);
|
||||
TF_CALL_double(REGISTER_CPU);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<Eigen::half>("T"),
|
||||
@ -147,7 +147,7 @@ REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<double>("T"),
|
||||
SoftmaxXentWithLogitsOp<GPUDevice, double>);
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
@ -54,4 +54,4 @@ template struct functor::XentFunctor<GPUDevice, double>;
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
@ -21990,6 +21990,34 @@ op {
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "DivNoNan"
|
||||
input_arg {
|
||||
name: "x"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "y"
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "z"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_HALF
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "DrawBoundingBoxes"
|
||||
input_arg {
|
||||
@ -40851,6 +40879,35 @@ op {
|
||||
}
|
||||
is_commutative: true
|
||||
}
|
||||
op {
|
||||
name: "MulNoNan"
|
||||
input_arg {
|
||||
name: "x"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "y"
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "z"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_HALF
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
}
|
||||
}
|
||||
}
|
||||
is_commutative: true
|
||||
}
|
||||
op {
|
||||
name: "MultiDeviceIterator"
|
||||
output_arg {
|
||||
|
@ -435,7 +435,7 @@ REGISTER_OP("MulNoNan")
|
||||
.Input("x: T")
|
||||
.Input("y: T")
|
||||
.Output("z: T")
|
||||
.Attr("T: {float, double}")
|
||||
.Attr("T: {half, float, double, complex64, complex128}")
|
||||
.SetIsCommutative()
|
||||
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
|
||||
|
||||
@ -460,7 +460,7 @@ REGISTER_OP("DivNoNan")
|
||||
.Input("x: T")
|
||||
.Input("y: T")
|
||||
.Output("z: T")
|
||||
.Attr("T: {float, double}")
|
||||
.Attr("T: {half, float, double, complex64, complex128}")
|
||||
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
|
||||
|
||||
REGISTER_OP("FloorDiv")
|
||||
|
@ -10019,8 +10019,11 @@ op {
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_HALF
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -20017,8 +20020,11 @@ op {
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_HALF
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_COMPLEX64
|
||||
type: DT_COMPLEX128
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -284,7 +284,7 @@ Status OAuthClient::ParseOAuthResponse(StringPiece response,
|
||||
return errors::FailedPrecondition("Unexpected Oauth token type: " +
|
||||
token_type);
|
||||
}
|
||||
int64 expires_in;
|
||||
int64 expires_in = 0;
|
||||
TF_RETURN_IF_ERROR(ReadJsonInt(root, "expires_in", &expires_in));
|
||||
*expiration_timestamp_sec = request_timestamp_sec + expires_in;
|
||||
TF_RETURN_IF_ERROR(ReadJsonString(root, "access_token", token));
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include <setjmp.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
|
||||
@ -228,7 +229,9 @@ void DecodeLocation(const float* encoded_location, const float* box_priors,
|
||||
}
|
||||
}
|
||||
|
||||
float DecodeScore(float encoded_score) { return 1 / (1 + exp(-encoded_score)); }
|
||||
float DecodeScore(float encoded_score) {
|
||||
return 1 / (1 + std::exp(-encoded_score));
|
||||
}
|
||||
|
||||
void DrawBox(const int image_width, const int image_height, int left, int top,
|
||||
int right, int bottom, tensorflow::TTypes<uint8>::Flat* image) {
|
||||
|
@ -72,7 +72,7 @@ class TextEmbeddingModel(tf.train.Checkpoint):
|
||||
normalized_sentences = tf.strings.regex_replace(
|
||||
input=sentences, pattern=r"\pP", rewrite="")
|
||||
normalized_sentences = tf.reshape(normalized_sentences, [-1])
|
||||
sparse_tokens = tf.string_split(normalized_sentences, " ")
|
||||
sparse_tokens = tf.strings.split(normalized_sentences, " ")
|
||||
|
||||
# Deal with a corner case: there is one empty sentence.
|
||||
sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant(""))
|
||||
|
@ -50,7 +50,7 @@ class TextRnnModel(tf.train.Checkpoint):
|
||||
# splitting on spaces.
|
||||
normalized_sentences = tf.strings.regex_replace(
|
||||
input=sentences, pattern=r"\pP", rewrite="")
|
||||
sparse_tokens = tf.string_split(normalized_sentences, " ")
|
||||
sparse_tokens = tf.strings.split(normalized_sentences, " ")
|
||||
|
||||
# Deal with a corner case: there is one empty sentence.
|
||||
sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant(""))
|
||||
|
@ -53,7 +53,7 @@ type Graph struct {
|
||||
c *C.TF_Graph
|
||||
}
|
||||
|
||||
// Graph execution options
|
||||
// The GraphImportOptions struct holds parameters for the ImportWithOptions function.
|
||||
type GraphImportOptions struct {
|
||||
// Node prefix
|
||||
Prefix string
|
||||
@ -170,7 +170,7 @@ func (g *Graph) Operation(name string) *Operation {
|
||||
|
||||
// Operations returns a list of all operations in the graph
|
||||
func (g *Graph) Operations() []Operation {
|
||||
var pos C.size_t = 0
|
||||
var pos C.size_t
|
||||
ops := []Operation{}
|
||||
for {
|
||||
cop := C.TF_GraphNextOperation(g.c, &pos)
|
||||
|
@ -5342,7 +5342,6 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":distribute",
|
||||
":framework",
|
||||
":framework_for_generated_wrappers",
|
||||
":platform",
|
||||
@ -5371,6 +5370,7 @@ py_library(
|
||||
":summary_ops_gen",
|
||||
":summary_ops_v2",
|
||||
":util",
|
||||
"//tensorflow/python/distribute:summary_op_util",
|
||||
"//tensorflow/python/eager:context",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
|
@ -88,20 +88,6 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
# TODO(mdan): Add a test specific to this converter.
|
||||
|
||||
|
||||
# TODO(mdan): Remove when updating the API.
|
||||
@tf_export('autograph.experimental.Verbosity')
|
||||
class Verbosity(enum.IntEnum):
|
||||
"""Represents conversion verbosity levels.
|
||||
|
||||
Attributes:
|
||||
BRIEF: No logging, minimal error messages.
|
||||
VERBOSE: Detailed logging of generated code, detailed error messages.
|
||||
"""
|
||||
|
||||
BRIEF = 0
|
||||
VERBOSE = 1
|
||||
|
||||
|
||||
@tf_export('autograph.experimental.Feature')
|
||||
class Feature(enum.Enum):
|
||||
"""Represents conversion options that can be toggled on or off.
|
||||
|
@ -377,16 +377,12 @@ def _is_not_callable(obj):
|
||||
return False
|
||||
|
||||
|
||||
# TODO(mdan): Remove obsolete args.
|
||||
@tf_export('autograph.to_graph')
|
||||
def to_graph(entity,
|
||||
recursive=True,
|
||||
arg_values=None,
|
||||
arg_types=None,
|
||||
experimental_optional_features=converter.Feature.ALL,
|
||||
experimental_strip_decorators=None,
|
||||
experimental_verbose=converter.Verbosity.BRIEF,
|
||||
experimental_partial_types=None):
|
||||
experimental_optional_features=converter.Feature.ALL):
|
||||
"""Converts a Python entity into a TensorFlow graph.
|
||||
|
||||
Also see: `tf.autograph.to_code`, `tf.function`.
|
||||
@ -442,9 +438,6 @@ def to_graph(entity,
|
||||
experimental_optional_features: `None`, a tuple of, or a single
|
||||
`tf.autograph.experimental.Feature` value. Controls the use of
|
||||
optional features in the conversion process.
|
||||
experimental_strip_decorators: Deprecated, unused.
|
||||
experimental_verbose: Deprecated, unused.
|
||||
experimental_partial_types: Deprecated, unused.
|
||||
|
||||
Returns:
|
||||
Same as `entity`, the converted Python function or class.
|
||||
@ -452,10 +445,6 @@ def to_graph(entity,
|
||||
Raises:
|
||||
ValueError: If the entity could not be converted.
|
||||
"""
|
||||
del experimental_strip_decorators
|
||||
del experimental_verbose
|
||||
del experimental_partial_types
|
||||
|
||||
try:
|
||||
program_ctx = converter.ProgramContext(
|
||||
options=converter.ConversionOptions(
|
||||
@ -520,8 +509,7 @@ def to_code(entity,
|
||||
arg_values=None,
|
||||
arg_types=None,
|
||||
indentation=' ',
|
||||
experimental_optional_features=converter.Feature.ALL,
|
||||
experimental_partial_types=None):
|
||||
experimental_optional_features=converter.Feature.ALL):
|
||||
"""Similar to `to_graph`, but returns Python source code as a string.
|
||||
|
||||
Also see: `tf.autograph.to_graph`.
|
||||
@ -544,13 +532,10 @@ def to_code(entity,
|
||||
experimental_optional_features: `None`, a tuple of, or a single
|
||||
`tf.autograph.experimental.Feature` value. Controls the use of
|
||||
optional features in the conversion process.
|
||||
experimental_partial_types: Deprecated, unused.
|
||||
|
||||
Returns:
|
||||
The converted code as string.
|
||||
"""
|
||||
del experimental_partial_types
|
||||
|
||||
program_ctx = converter.ProgramContext(
|
||||
options=converter.ConversionOptions(
|
||||
recursive=recursive,
|
||||
|
@ -27,7 +27,7 @@ import datetime
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 3, 14)
|
||||
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 3, 15)
|
||||
|
||||
|
||||
@tf_export("compat.forward_compatible")
|
||||
|
@ -19,70 +19,19 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.experimental.ops import random_ops
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import convert
|
||||
from tensorflow.python.data.ops import readers
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.data.util import structure
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
|
||||
from tensorflow.python.ops import gen_experimental_dataset_ops
|
||||
from tensorflow.python.ops import gen_stateless_random_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
class _ParallelInterleaveDataset(dataset_ops.UnaryDataset):
|
||||
"""A `Dataset` that maps a function over its input and flattens the result."""
|
||||
|
||||
def __init__(self, input_dataset, map_func, cycle_length, block_length,
|
||||
sloppy, buffer_output_elements, prefetch_input_elements):
|
||||
"""See `tf.data.experimental.parallel_interleave()` for details."""
|
||||
self._input_dataset = input_dataset
|
||||
self._map_func = dataset_ops.StructuredFunctionWrapper(
|
||||
map_func, self._transformation_name(), dataset=input_dataset)
|
||||
if not isinstance(self._map_func.output_structure,
|
||||
dataset_ops.DatasetStructure):
|
||||
raise TypeError("`map_func` must return a `Dataset` object.")
|
||||
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
|
||||
self._cycle_length = ops.convert_to_tensor(
|
||||
cycle_length, dtype=dtypes.int64, name="cycle_length")
|
||||
self._block_length = ops.convert_to_tensor(
|
||||
block_length, dtype=dtypes.int64, name="block_length")
|
||||
self._sloppy = ops.convert_to_tensor(
|
||||
sloppy, dtype=dtypes.bool, name="sloppy")
|
||||
self._buffer_output_elements = convert.optional_param_to_tensor(
|
||||
"buffer_output_elements",
|
||||
buffer_output_elements,
|
||||
argument_default=2 * block_length)
|
||||
self._prefetch_input_elements = convert.optional_param_to_tensor(
|
||||
"prefetch_input_elements",
|
||||
prefetch_input_elements,
|
||||
argument_default=2 * cycle_length)
|
||||
variant_tensor = ged_ops.experimental_parallel_interleave_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs,
|
||||
self._cycle_length,
|
||||
self._block_length,
|
||||
self._sloppy,
|
||||
self._buffer_output_elements,
|
||||
self._prefetch_input_elements,
|
||||
f=self._map_func.function,
|
||||
**dataset_ops.flat_structure(self))
|
||||
super(_ParallelInterleaveDataset, self).__init__(input_dataset,
|
||||
variant_tensor)
|
||||
|
||||
def _functions(self):
|
||||
return [self._map_func]
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
|
||||
def _transformation_name(self):
|
||||
return "tf.data.experimental.parallel_interleave()"
|
||||
|
||||
|
||||
@deprecation.deprecated(
|
||||
None,
|
||||
"Use `tf.data.Dataset.interleave(map_func, cycle_length, block_length, "
|
||||
@ -139,7 +88,7 @@ def parallel_interleave(map_func,
|
||||
`tf.data.Dataset.apply`.
|
||||
"""
|
||||
def _apply_fn(dataset):
|
||||
return _ParallelInterleaveDataset(
|
||||
return readers.ParallelInterleaveDataset(
|
||||
dataset, map_func, cycle_length, block_length, sloppy,
|
||||
buffer_output_elements, prefetch_input_elements)
|
||||
|
||||
@ -176,10 +125,11 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
# pylint: disable=protected-access
|
||||
return (ged_ops.experimental_directed_interleave_dataset(
|
||||
self._selector_input._variant_tensor,
|
||||
[data_input._variant_tensor for data_input in self._data_inputs],
|
||||
**dataset_ops.flat_structure(self)))
|
||||
return (
|
||||
gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
|
||||
self._selector_input._variant_tensor,
|
||||
[data_input._variant_tensor for data_input in self._data_inputs],
|
||||
**dataset_ops.flat_structure(self)))
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def _inputs(self):
|
||||
|
@ -25,6 +25,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.experimental.ops import error_ops
|
||||
from tensorflow.python.data.experimental.ops import interleave_ops
|
||||
from tensorflow.python.data.experimental.ops import optimization
|
||||
from tensorflow.python.data.experimental.ops import parsing_ops
|
||||
from tensorflow.python.data.experimental.ops import shuffle_ops
|
||||
@ -493,15 +494,9 @@ def make_csv_dataset_v2(
|
||||
return features
|
||||
|
||||
# Read files sequentially (if num_parallel_reads=1) or in parallel
|
||||
dataset = dataset.interleave(
|
||||
filename_to_dataset,
|
||||
cycle_length=num_parallel_reads,
|
||||
num_parallel_calls=num_parallel_reads)
|
||||
|
||||
if sloppy:
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_deterministic = False
|
||||
dataset = dataset.with_options(options)
|
||||
dataset = dataset.apply(
|
||||
interleave_ops.parallel_interleave(
|
||||
filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy))
|
||||
|
||||
dataset = _maybe_shuffle_and_repeat(
|
||||
dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
|
||||
@ -833,15 +828,11 @@ def make_batched_features_dataset_v2(file_pattern,
|
||||
reader_args = []
|
||||
|
||||
# Read files sequentially (if reader_num_threads=1) or in parallel
|
||||
dataset = dataset.interleave(
|
||||
lambda filename: reader(filename, *reader_args),
|
||||
cycle_length=reader_num_threads,
|
||||
num_parallel_calls=reader_num_threads)
|
||||
|
||||
if sloppy_ordering:
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_deterministic = False
|
||||
dataset = dataset.with_options(options)
|
||||
dataset = dataset.apply(
|
||||
interleave_ops.parallel_interleave(
|
||||
lambda filename: reader(filename, *reader_args),
|
||||
cycle_length=reader_num_threads,
|
||||
sloppy=sloppy_ordering))
|
||||
|
||||
# Extract values if the `Example` tensors are stored as key-value tuples.
|
||||
if dataset_ops.get_legacy_output_types(dataset) == (
|
||||
|
@ -1188,7 +1188,9 @@ class DatasetV2(object):
|
||||
"""
|
||||
dataset = transformation_func(self)
|
||||
if not isinstance(dataset, DatasetV2):
|
||||
raise TypeError("`transformation_func` must return a Dataset.")
|
||||
raise TypeError(
|
||||
"`transformation_func` must return a Dataset. Got {}.".format(
|
||||
dataset))
|
||||
dataset._input_datasets = [self] # pylint: disable=protected-access
|
||||
return dataset
|
||||
|
||||
@ -2121,7 +2123,9 @@ class SparseTensorSliceDataset(DatasetSource):
|
||||
def __init__(self, sparse_tensor):
|
||||
"""See `Dataset.from_sparse_tensor_slices()` for details."""
|
||||
if not isinstance(sparse_tensor, sparse_tensor_lib.SparseTensor):
|
||||
raise TypeError("`sparse_tensor` must be a `tf.SparseTensor` object.")
|
||||
raise TypeError(
|
||||
"`sparse_tensor` must be a `tf.SparseTensor` object. Was {}.".format(
|
||||
sparse_tensor))
|
||||
self._sparse_tensor = sparse_tensor
|
||||
|
||||
indices_shape = self._sparse_tensor.indices.get_shape()
|
||||
@ -2885,7 +2889,11 @@ def _default_padding(input_dataset):
|
||||
if t.base_dtype == dtypes.string:
|
||||
return ""
|
||||
elif t.base_dtype == dtypes.variant:
|
||||
raise TypeError("Unable to create padding for field of type 'variant'")
|
||||
error_msg = ("Unable to create padding for field of type 'variant' "
|
||||
"because t.base_type == dtypes.variant == "
|
||||
"{}.".format(
|
||||
t.base_dtype))
|
||||
raise TypeError(error_msg)
|
||||
else:
|
||||
return np.zeros_like(t.as_numpy_dtype())
|
||||
|
||||
@ -3066,7 +3074,9 @@ class FlatMapDataset(UnaryDataset):
|
||||
self._map_func = StructuredFunctionWrapper(
|
||||
map_func, self._transformation_name(), dataset=input_dataset)
|
||||
if not isinstance(self._map_func.output_structure, DatasetStructure):
|
||||
raise TypeError("`map_func` must return a `Dataset` object.")
|
||||
raise TypeError(
|
||||
"`map_func` must return a `Dataset` object. Got {}".format(
|
||||
type(self._map_func.output_structure)))
|
||||
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
|
||||
variant_tensor = gen_dataset_ops.flat_map_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
@ -3096,7 +3106,9 @@ class InterleaveDataset(UnaryDataset):
|
||||
self._map_func = StructuredFunctionWrapper(
|
||||
map_func, self._transformation_name(), dataset=input_dataset)
|
||||
if not isinstance(self._map_func.output_structure, DatasetStructure):
|
||||
raise TypeError("`map_func` must return a `Dataset` object.")
|
||||
raise TypeError(
|
||||
"`map_func` must return a `Dataset` object. Got {}".format(
|
||||
type(self._map_func.output_structure)))
|
||||
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
|
||||
self._cycle_length = ops.convert_to_tensor(
|
||||
cycle_length, dtype=dtypes.int64, name="cycle_length")
|
||||
@ -3124,8 +3136,7 @@ class InterleaveDataset(UnaryDataset):
|
||||
|
||||
|
||||
class ParallelInterleaveDataset(UnaryDataset):
|
||||
"""A `Dataset` that maps a function over its input and interleaves the result.
|
||||
"""
|
||||
"""A `Dataset` that maps a function over its input and interleaves the result."""
|
||||
|
||||
def __init__(self, input_dataset, map_func, cycle_length, block_length,
|
||||
num_parallel_calls):
|
||||
@ -3134,7 +3145,9 @@ class ParallelInterleaveDataset(UnaryDataset):
|
||||
self._map_func = StructuredFunctionWrapper(
|
||||
map_func, self._transformation_name(), dataset=input_dataset)
|
||||
if not isinstance(self._map_func.output_structure, DatasetStructure):
|
||||
raise TypeError("`map_func` must return a `Dataset` object.")
|
||||
raise TypeError(
|
||||
"`map_func` must return a `Dataset` object. Got {}".format(
|
||||
type(self._map_func.output_structure)))
|
||||
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
|
||||
self._cycle_length = ops.convert_to_tensor(
|
||||
cycle_length, dtype=dtypes.int64, name="cycle_length")
|
||||
@ -3177,7 +3190,10 @@ class FilterDataset(UnaryUnchangedStructureDataset):
|
||||
use_legacy_function=use_legacy_function)
|
||||
if not wrapped_func.output_structure.is_compatible_with(
|
||||
structure_lib.TensorStructure(dtypes.bool, [])):
|
||||
raise ValueError("`predicate` must return a scalar boolean tensor.")
|
||||
error_msg = ("`predicate` return type must be convertible to a scalar "
|
||||
"boolean tensor. Was {}.").format(
|
||||
wrapped_func.output_structure)
|
||||
raise ValueError(error_msg)
|
||||
self._predicate = wrapped_func
|
||||
variant_tensor = gen_dataset_ops.filter_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
|
@ -26,6 +26,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_dataset_ops
|
||||
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@ -118,6 +119,57 @@ class _TFRecordDataset(dataset_ops.DatasetSource):
|
||||
return structure.TensorStructure(dtypes.string, [])
|
||||
|
||||
|
||||
class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
|
||||
"""A `Dataset` that maps a function over its input and flattens the result."""
|
||||
|
||||
def __init__(self, input_dataset, map_func, cycle_length, block_length,
|
||||
sloppy, buffer_output_elements, prefetch_input_elements):
|
||||
"""See `tf.data.experimental.parallel_interleave()` for details."""
|
||||
self._input_dataset = input_dataset
|
||||
self._map_func = dataset_ops.StructuredFunctionWrapper(
|
||||
map_func, self._transformation_name(), dataset=input_dataset)
|
||||
if not isinstance(self._map_func.output_structure,
|
||||
dataset_ops.DatasetStructure):
|
||||
raise TypeError("`map_func` must return a `Dataset` object.")
|
||||
self._structure = self._map_func.output_structure._element_structure # pylint: disable=protected-access
|
||||
self._cycle_length = ops.convert_to_tensor(
|
||||
cycle_length, dtype=dtypes.int64, name="cycle_length")
|
||||
self._block_length = ops.convert_to_tensor(
|
||||
block_length, dtype=dtypes.int64, name="block_length")
|
||||
self._sloppy = ops.convert_to_tensor(
|
||||
sloppy, dtype=dtypes.bool, name="sloppy")
|
||||
self._buffer_output_elements = convert.optional_param_to_tensor(
|
||||
"buffer_output_elements",
|
||||
buffer_output_elements,
|
||||
argument_default=2 * block_length)
|
||||
self._prefetch_input_elements = convert.optional_param_to_tensor(
|
||||
"prefetch_input_elements",
|
||||
prefetch_input_elements,
|
||||
argument_default=2 * cycle_length)
|
||||
variant_tensor = ged_ops.experimental_parallel_interleave_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs,
|
||||
self._cycle_length,
|
||||
self._block_length,
|
||||
self._sloppy,
|
||||
self._buffer_output_elements,
|
||||
self._prefetch_input_elements,
|
||||
f=self._map_func.function,
|
||||
**dataset_ops.flat_structure(self))
|
||||
super(ParallelInterleaveDataset, self).__init__(input_dataset,
|
||||
variant_tensor)
|
||||
|
||||
def _functions(self):
|
||||
return [self._map_func]
|
||||
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._structure
|
||||
|
||||
def _transformation_name(self):
|
||||
return "tf.data.experimental.parallel_interleave()"
|
||||
|
||||
|
||||
@tf_export("data.TFRecordDataset", v1=[])
|
||||
class TFRecordDatasetV2(dataset_ops.DatasetV2):
|
||||
"""A `Dataset` comprising records from one or more TFRecord files."""
|
||||
@ -169,10 +221,10 @@ class TFRecordDatasetV2(dataset_ops.DatasetV2):
|
||||
if num_parallel_reads is None:
|
||||
self._impl = filenames.flat_map(read_one_file)
|
||||
else:
|
||||
self._impl = filenames.interleave(
|
||||
read_one_file,
|
||||
cycle_length=num_parallel_reads,
|
||||
num_parallel_calls=num_parallel_reads)
|
||||
self._impl = ParallelInterleaveDataset(
|
||||
filenames, read_one_file, cycle_length=num_parallel_reads,
|
||||
block_length=1, sloppy=False, buffer_output_elements=None,
|
||||
prefetch_input_elements=None)
|
||||
variant_tensor = self._impl._variant_tensor # pylint: disable=protected-access
|
||||
super(TFRecordDatasetV2, self).__init__(variant_tensor)
|
||||
|
||||
|
@ -469,6 +469,16 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "summary_op_util",
|
||||
srcs = ["summary_op_util.py"],
|
||||
deps = [
|
||||
":distribute_lib",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:tensor_util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "values",
|
||||
srcs = ["values.py"],
|
||||
|
@ -512,23 +512,50 @@ class DistributionStrategy(object):
|
||||
_require_cross_replica_or_default_context_extended(self._extended)
|
||||
return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access
|
||||
|
||||
@doc_controls.do_not_generate_docs # DEPRECATED, -> `DistributedValues`
|
||||
@doc_controls.do_not_generate_docs # DEPRECATED
|
||||
def unwrap(self, value):
|
||||
"""Returns the list of all per-replica values contained in `value`.
|
||||
"""Returns the list of all local per-replica values contained in `value`.
|
||||
|
||||
DEPRECATED: Please use `experimental_local_results` instead.
|
||||
|
||||
Note: This only returns values on the workers initiated by this client.
|
||||
When using a `Strategy` like
|
||||
`tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
|
||||
will be its own client, and this function will only return values
|
||||
computed on that worker.
|
||||
|
||||
Args:
|
||||
value: A value returned by `extended.call_for_each_replica()` or a
|
||||
variable created in `scope`.
|
||||
value: A value returned by `experimental_run()`,
|
||||
`extended.call_for_each_replica()`, or a variable created in `scope`.
|
||||
|
||||
Returns:
|
||||
A tuple of values contained in `value`. If `value` represents a single
|
||||
value, this returns `(value,).`
|
||||
"""
|
||||
return self._extended._unwrap(value) # pylint: disable=protected-access
|
||||
return self._extended._local_results(value) # pylint: disable=protected-access
|
||||
|
||||
@doc_controls.do_not_generate_docs # DEPRECATED, -> `DistributedValues`
|
||||
def experimental_local_results(self, value):
|
||||
"""Returns the list of all local per-replica values contained in `value`.
|
||||
|
||||
Note: This only returns values on the workers initiated by this client.
|
||||
When using a `Strategy` like
|
||||
`tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
|
||||
will be its own client, and this function will only return values
|
||||
computed on that worker.
|
||||
|
||||
Args:
|
||||
value: A value returned by `experimental_run()`, `experimental_run_v2()`,
|
||||
`extended.call_for_each_replica()`, or a variable created in `scope`.
|
||||
|
||||
Returns:
|
||||
A tuple of values contained in `value`. If `value` represents a single
|
||||
value, this returns `(value,).`
|
||||
"""
|
||||
return self._extended._local_results(value) # pylint: disable=protected-access
|
||||
|
||||
@doc_controls.do_not_generate_docs # DEPRECATED: TF v1.x only
|
||||
def group(self, value, name=None):
|
||||
"""Shortcut for `tf.group(self.unwrap(value))`."""
|
||||
"""Shortcut for `tf.group(self.experimental_local_results(value))`."""
|
||||
return self._extended._group(value, name) # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
@ -1067,7 +1094,7 @@ class DistributionStrategyExtended(object):
|
||||
# Called once in "cross-replica" context.
|
||||
def merge_fn(distribution, three_plus_replica_id):
|
||||
# sum the values across replicas
|
||||
return sum(distribution.unwrap(three_plus_replica_id))
|
||||
return sum(distribution.experimental_local_results(three_plus_replica_id))
|
||||
|
||||
# Called once per replica in `distribution`, in a "replica" context.
|
||||
def fn(three):
|
||||
@ -1082,7 +1109,8 @@ class DistributionStrategyExtended(object):
|
||||
...
|
||||
merged_results = distribution.call_for_each_replica(fn, args=[3])
|
||||
# merged_results has the values from every replica execution of `fn`.
|
||||
print(distribution.unwrap(merged_results)) # Prints a list
|
||||
# This statement prints a list:
|
||||
print(distribution.experimental_local_results(merged_results))
|
||||
```
|
||||
|
||||
Args:
|
||||
@ -1104,8 +1132,9 @@ class DistributionStrategyExtended(object):
|
||||
|
||||
def _reduce(self, reduce_op, value):
|
||||
# Default implementation until we have an implementation for each strategy.
|
||||
return self._unwrap(self._reduce_to(
|
||||
reduce_op, value, device_util.current() or "/device:CPU:0"))[0]
|
||||
return self._local_results(
|
||||
self._reduce_to(reduce_op, value,
|
||||
device_util.current() or "/device:CPU:0"))[0]
|
||||
|
||||
def reduce_to(self, reduce_op, value, destinations):
|
||||
"""Combine (via e.g. sum or mean) values across replicas.
|
||||
@ -1224,7 +1253,7 @@ class DistributionStrategyExtended(object):
|
||||
def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
|
||||
raise NotImplementedError("must be implemented in descendants")
|
||||
|
||||
def _unwrap(self, distributed_value):
|
||||
def _local_results(self, distributed_value):
|
||||
raise NotImplementedError("must be implemented in descendants")
|
||||
|
||||
def value_container(self, value):
|
||||
@ -1238,13 +1267,14 @@ class DistributionStrategyExtended(object):
|
||||
A container that `value` belongs to.
|
||||
If value does not belong to any container (including the case of
|
||||
container having been destroyed), returns the value itself.
|
||||
`value in unwrap(value_container(value))` will always be true.
|
||||
`value in experimental_local_results(value_container(value))` will
|
||||
always be true.
|
||||
"""
|
||||
raise NotImplementedError("must be implemented in descendants")
|
||||
|
||||
def _group(self, value, name=None):
|
||||
"""Shortcut for `tf.group(distribution.unwrap(value))`."""
|
||||
value = nest.flatten(self._unwrap(value))
|
||||
"""Implementation of `group`."""
|
||||
value = nest.flatten(self._local_results(value))
|
||||
|
||||
if len(value) != 1 or name is not None:
|
||||
return control_flow_ops.group(value, name=name)
|
||||
@ -1590,12 +1620,12 @@ class _DefaultDistributionExtended(DistributionStrategyExtended):
|
||||
if should_group:
|
||||
return result
|
||||
else:
|
||||
return nest.map_structure(self._unwrap, result)
|
||||
return nest.map_structure(self._local_results, result)
|
||||
|
||||
def read_var(self, replica_local_var):
|
||||
return array_ops.identity(replica_local_var)
|
||||
|
||||
def _unwrap(self, distributed_value):
|
||||
def _local_results(self, distributed_value):
|
||||
return (distributed_value,)
|
||||
|
||||
def value_container(self, value):
|
||||
|
@ -648,6 +648,7 @@ class MultiStepContext(object):
|
||||
def merge_fn(distribution, value):
|
||||
# NOTE(priyag): For non tensor outputs, we simply return all the values
|
||||
# in a list as reduction doesn't make sense on non tensors.
|
||||
self._non_tensor_outputs[name] = distribution.unwrap(value)
|
||||
self._non_tensor_outputs[name] = (
|
||||
distribution.experimental_local_results(value))
|
||||
distribution_strategy_context.get_replica_context().merge_call(
|
||||
merge_fn, args=(output,))
|
||||
|
@ -602,7 +602,7 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
|
||||
fn_result = fn(ctx, iterator.get_next())
|
||||
for (name, output) in ctx.last_step_outputs.items():
|
||||
# Convert all outputs to tensors, potentially from `DistributedValues`.
|
||||
ctx.last_step_outputs[name] = self._unwrap(output)
|
||||
ctx.last_step_outputs[name] = self._local_results(output)
|
||||
flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
|
||||
with ops.control_dependencies([fn_result]):
|
||||
return [i + 1] + flat_last_step_outputs
|
||||
@ -741,7 +741,7 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
|
||||
assert isinstance(replica_local_var, values.Mirrored)
|
||||
return array_ops.identity(replica_local_var.get())
|
||||
|
||||
def _unwrap(self, val):
|
||||
def _local_results(self, val):
|
||||
if isinstance(val, values.DistributedValues):
|
||||
return val.values
|
||||
return (val,)
|
||||
|
@ -159,13 +159,13 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended):
|
||||
if group:
|
||||
return result
|
||||
else:
|
||||
return nest.map_structure(self._unwrap, result)
|
||||
return nest.map_structure(self._local_results, result)
|
||||
|
||||
def read_var(self, replica_local_var):
|
||||
"""Read the aggregate value of a replica-local variable."""
|
||||
return array_ops.identity(replica_local_var)
|
||||
|
||||
def _unwrap(self, value):
|
||||
def _local_results(self, value):
|
||||
return (value,)
|
||||
|
||||
def value_container(self, value):
|
||||
|
@ -426,7 +426,7 @@ class ParameterServerStrategyExtended(
|
||||
if group:
|
||||
return result
|
||||
else:
|
||||
return nest.map_structure(self._unwrap, result)
|
||||
return nest.map_structure(self._local_results, result)
|
||||
|
||||
# TODO(yuefengz): does it need to call _select_single_value?
|
||||
def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
|
||||
@ -436,9 +436,9 @@ class ParameterServerStrategyExtended(
|
||||
if group:
|
||||
return result
|
||||
else:
|
||||
return nest.map_structure(self._unwrap, result)
|
||||
return nest.map_structure(self._local_results, result)
|
||||
|
||||
def _unwrap(self, val):
|
||||
def _local_results(self, val):
|
||||
if isinstance(val, values.DistributedValues):
|
||||
return val.values
|
||||
return (val,)
|
||||
|
48
tensorflow/python/distribute/summary_op_util.py
Normal file
48
tensorflow/python/distribute/summary_op_util.py
Normal file
@ -0,0 +1,48 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#==============================================================================
|
||||
"""Contains utility functions used by summary ops in distribution strategy."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
|
||||
|
||||
def skip_summary():
|
||||
"""Determines if summary should be skipped.
|
||||
|
||||
If using multiple replicas in distributed strategy, skip summaries on all
|
||||
replicas except the first one (replica_id=0).
|
||||
|
||||
Returns:
|
||||
True if the summary is skipped; False otherwise.
|
||||
"""
|
||||
|
||||
# TODO(priyag): Add a new optional argument that will provide multiple
|
||||
# alternatives to override default behavior. (e.g. run on last replica,
|
||||
# compute sum or mean across replicas).
|
||||
replica_context = distribution_strategy_context.get_replica_context()
|
||||
if not replica_context:
|
||||
return False
|
||||
# TODO(b/118385803): when replica_id of _TPUReplicaContext is properly
|
||||
# initialized, remember to change here as well.
|
||||
replica_id = replica_context.replica_id_in_sync_group
|
||||
if isinstance(replica_id, ops.Tensor):
|
||||
replica_id = tensor_util.constant_value(replica_id)
|
||||
return replica_id and replica_id > 0
|
@ -500,7 +500,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
|
||||
assert isinstance(var, values.TPUMirroredVariable)
|
||||
return var.read_value()
|
||||
|
||||
def _unwrap(self, val):
|
||||
def _local_results(self, val):
|
||||
if isinstance(val, values.DistributedValues):
|
||||
# Return in a deterministic order.
|
||||
return tuple(val.get(device=d) for d in sorted(val.devices))
|
||||
@ -589,7 +589,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
|
||||
if group:
|
||||
return result
|
||||
else:
|
||||
return nest.map_structure(self._unwrap, result)
|
||||
return nest.map_structure(self._local_results, result)
|
||||
|
||||
def _configure(self,
|
||||
session_config=None,
|
||||
|
@ -272,7 +272,7 @@ class DistributedValues(object):
|
||||
def device_map(self):
|
||||
return self._device_map
|
||||
|
||||
# TODO(josh11b): Replace unwrap with this?
|
||||
# TODO(josh11b): Replace experimental_local_results with this?
|
||||
@property
|
||||
def values(self):
|
||||
return self._values
|
||||
@ -622,8 +622,9 @@ def validate_colocate(v, extended):
|
||||
|
||||
def _apply_aggregation(strategy, value, aggregation, destinations):
|
||||
if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
|
||||
return strategy.extended.broadcast_to(strategy.unwrap(value)[0],
|
||||
destinations=destinations)
|
||||
return strategy.extended.broadcast_to(
|
||||
strategy.experimental_local_results(value)[0],
|
||||
destinations=destinations)
|
||||
reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
|
||||
return strategy.extended.reduce_to(reduce_op, value, destinations)
|
||||
|
||||
@ -824,7 +825,7 @@ class TPUMirroredVariable(trackable.Trackable):
|
||||
def device_map(self):
|
||||
return self._device_map
|
||||
|
||||
# TODO(josh11b): Replace unwrap with this?
|
||||
# TODO(josh11b): Replace experimental_local_results with this?
|
||||
@property
|
||||
def values(self):
|
||||
return self._values
|
||||
@ -1410,7 +1411,7 @@ def update_regroup(extended, device_map, updates, group):
|
||||
# so we can avoid all these nest operations.
|
||||
regrouped = regroup(device_map, updates, Mirrored)
|
||||
if not group:
|
||||
return nest.map_structure(extended._unwrap, regrouped) # pylint: disable=protected-access
|
||||
return nest.map_structure(extended._local_results, regrouped) # pylint: disable=protected-access
|
||||
grouped_flat = []
|
||||
for u in nest.flatten(regrouped):
|
||||
if isinstance(u, DistributedValues):
|
||||
|
@ -735,7 +735,7 @@ class ConcreteFunction(object):
|
||||
# In case of eager execution, function definition gets added to context
|
||||
# during construction itself.
|
||||
|
||||
# TODO(allel/shivaniagrawal): rename this to register to reflect the
|
||||
# TODO(allenl/shivaniagrawal): rename this to register to reflect the
|
||||
# method's functionality better. Remove register_gradient_functions argument
|
||||
# and figure out if these needs to be registered.
|
||||
|
||||
|
@ -73,6 +73,10 @@ def function_def_to_graph(fdef, input_shapes=None):
|
||||
func_graph.outputs = [
|
||||
func_graph.get_tensor_by_name(name) for name in output_tensor_names
|
||||
]
|
||||
func_graph.control_outputs = [
|
||||
func_graph.get_operation_by_name(fdef.control_ret[ret_name])
|
||||
for ret_name in fdef.signature.control_output
|
||||
]
|
||||
|
||||
return func_graph
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff.