Support XlaCompiler arguments of type kConstant in XlaCompilationCache
These arguments types are generated in two ways: 1) XlaCompileOnDemandOp, based on the xla op kernel requirement, provides inputs with constant constraint during compilation. This allows the op to get compiled to HLO and then the actual HLO op can get executed on the required device. 2) _XlaCompileOp gets some of the inputs as "constants" variadic input. These are the constants coming from GuaranteeConst op. TPUCompileMlirOp doesn't support these constants yet. I also found a note in the old bridge to revisit this idea as it has limited utility. (Couldn't find the note now.) kConstant type arguments are supported by inlining the constant arguments after import and dropping those arguments from the signature. Mapping from the old set of arguments to the final list of arguments is set in the input_mapping of compilation result. Some of the compilation tests are disabled with this change as those were going through the old bridge and requires further changes to support those. PiperOrigin-RevId: 310040060 Change-Id: I121bb717e0376728bd727c7646116a04c77dcb53
This commit is contained in:
parent
9d15d75988
commit
ae422505db
|
@ -41,6 +41,7 @@ limitations under the License.
|
||||||
#include "tensorflow/core/graph/node_builder.h"
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
#include "tensorflow/core/lib/hash/hash.h"
|
#include "tensorflow/core/lib/hash/hash.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||||
#include "tensorflow/core/public/version.h"
|
#include "tensorflow/core/public/version.h"
|
||||||
|
@ -277,29 +278,25 @@ Status XlaCompilationCache::CompileSingleOp(
|
||||||
const NodeDef& node_def = ctx->op_kernel().def();
|
const NodeDef& node_def = ctx->op_kernel().def();
|
||||||
TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
|
TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
|
||||||
|
|
||||||
bool are_params = absl::c_all_of(args, [](const XlaCompiler::Argument arg) {
|
bool are_args_supported =
|
||||||
return arg.kind == XlaCompiler::Argument::kParameter;
|
absl::c_all_of(args, [](const XlaCompiler::Argument arg) {
|
||||||
|
return arg.kind == XlaCompiler::Argument::kConstant ||
|
||||||
|
arg.kind == XlaCompiler::Argument::kParameter;
|
||||||
});
|
});
|
||||||
const ConfigProto* config = ctx->function_library()->config_proto();
|
const ConfigProto* config = ctx->function_library()->config_proto();
|
||||||
bool use_mlir = config && config->experimental().enable_mlir_bridge();
|
bool use_mlir = config && config->experimental().enable_mlir_bridge();
|
||||||
// Use MLIR bridge if all the arguments are parameters.
|
// TODO(b/155596779): Understand the source of other argument types and
|
||||||
// TODO(hinsu): Support other argument types instead of silently falling
|
// depending on the source either support those or avoid these codepath.
|
||||||
// back to the XLA compiler.
|
if (!use_mlir || !are_args_supported) {
|
||||||
if (!are_params || !use_mlir) {
|
|
||||||
return compiler->CompileGraph(compile_options, node_def.name(),
|
return compiler->CompileGraph(compile_options, node_def.name(),
|
||||||
std::move(graph), args, result);
|
std::move(graph), args, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::InlinedVector<TensorShape, 4> arg_shapes;
|
|
||||||
arg_shapes.reserve(args.size());
|
|
||||||
for (const XlaCompiler::Argument& arg : args) {
|
|
||||||
arg_shapes.push_back(absl::get<TensorShape>(arg.shape));
|
|
||||||
}
|
|
||||||
GraphDebugInfo debug_info;
|
GraphDebugInfo debug_info;
|
||||||
return CompileGraphToXlaHlo(
|
return CompileGraphToXlaHlo(
|
||||||
*graph, {arg_shapes.data(), arg_shapes.size()},
|
*graph, {args.data(), args.size()}, options.device_type.type_string(),
|
||||||
options.device_type.type_string(), compile_options.use_tuple_arg,
|
compile_options.use_tuple_arg, *options.flib_def, debug_info,
|
||||||
*options.flib_def, debug_info, options.shape_representation_fn, result);
|
options.shape_representation_fn, result);
|
||||||
};
|
};
|
||||||
return CompileImpl(options, name, args, compile_op,
|
return CompileImpl(options, name, args, compile_op,
|
||||||
/*compile_threshold=*/absl::nullopt,
|
/*compile_threshold=*/absl::nullopt,
|
||||||
|
|
|
@ -1114,6 +1114,7 @@ COMPILE_MLIR_UTIL_DEPS = [
|
||||||
"//tensorflow/stream_executor/lib",
|
"//tensorflow/stream_executor/lib",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
|
":convert_tensor",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Prefer to link 'compile_mlir_util' library that also links necessary
|
# Prefer to link 'compile_mlir_util' library that also links necessary
|
||||||
|
|
|
@ -17,10 +17,13 @@ limitations under the License.
|
||||||
|
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||||
#include "mlir/IR/Function.h" // from @llvm-project
|
#include "mlir/IR/Function.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Location.h" // from @llvm-project
|
||||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||||
#include "mlir/IR/OpDefinition.h" // from @llvm-project
|
#include "mlir/IR/OpDefinition.h" // from @llvm-project
|
||||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||||
|
@ -35,6 +38,7 @@ limitations under the License.
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||||
|
@ -393,14 +397,47 @@ Status CompileSerializedMlirToXlaHlo(
|
||||||
std::move(custom_legalization_passes));
|
std::move(custom_legalization_passes));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Rewrites the given module with specified args. For each of the constant args,
|
||||||
|
// it gets inlined in the "main' function and the corresponding argument is
|
||||||
|
// removed from the signature.
|
||||||
|
// Returns the original indices for the other arguments on success.
|
||||||
|
static StatusOr<std::vector<int>> RewriteWithArgs(
|
||||||
|
mlir::ModuleOp module, llvm::ArrayRef<const XlaCompiler::Argument> args) {
|
||||||
|
mlir::FuncOp main_fn = module.lookupSymbol<mlir::FuncOp>("main");
|
||||||
|
std::vector<int> params;
|
||||||
|
|
||||||
|
auto builder = mlir::OpBuilder(main_fn.getBody());
|
||||||
|
std::vector<int> args_to_erase;
|
||||||
|
for (int idx = 0; idx < args.size(); idx++) {
|
||||||
|
const XlaCompiler::Argument& xla_arg = args[idx];
|
||||||
|
mlir::BlockArgument mlir_arg = main_fn.getArgument(idx);
|
||||||
|
if (xla_arg.kind != XlaCompiler::Argument::kConstant) {
|
||||||
|
params.push_back(idx);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_ASSIGN_OR_RETURN(auto value_attr,
|
||||||
|
ConvertTensor(xla_arg.constant_value, &builder));
|
||||||
|
// TODO(hinsu): Use the actual location of the constant.
|
||||||
|
auto constant = builder.create<mlir::TF::ConstOp>(
|
||||||
|
mlir::UnknownLoc::get(module.getContext()), value_attr);
|
||||||
|
mlir_arg.replaceAllUsesWith(constant);
|
||||||
|
args_to_erase.push_back(idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int idx : llvm::reverse(args_to_erase)) main_fn.eraseArgument(idx);
|
||||||
|
return params;
|
||||||
|
}
|
||||||
|
|
||||||
Status CompileGraphToXlaHlo(
|
Status CompileGraphToXlaHlo(
|
||||||
const Graph& graph, llvm::ArrayRef<TensorShape> arg_shapes,
|
const Graph& graph, llvm::ArrayRef<const XlaCompiler::Argument> args,
|
||||||
llvm::StringRef device_type, bool use_tuple_args,
|
llvm::StringRef device_type, bool use_tuple_args,
|
||||||
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
||||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||||
XlaCompiler::CompilationResult* compilation_result,
|
XlaCompiler::CompilationResult* compilation_result,
|
||||||
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
|
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
|
||||||
RegisterDialects();
|
RegisterDialects();
|
||||||
|
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
GraphImportConfig config;
|
GraphImportConfig config;
|
||||||
config.graph_as_function = true;
|
config.graph_as_function = true;
|
||||||
|
@ -408,10 +445,19 @@ Status CompileGraphToXlaHlo(
|
||||||
ConvertGraphToMlir(graph, debug_info, flib_def, config, &context);
|
ConvertGraphToMlir(graph, debug_info, flib_def, config, &context);
|
||||||
if (!module_or.ok()) return module_or.status();
|
if (!module_or.ok()) return module_or.status();
|
||||||
|
|
||||||
return CompileMlirToXlaHlo(module_or.ValueOrDie().get(), arg_shapes,
|
mlir::ModuleOp module = module_or.ValueOrDie().get();
|
||||||
device_type, use_tuple_args,
|
TF_ASSIGN_OR_RETURN(std::vector<int> remaining_params,
|
||||||
shape_representation_fn, compilation_result,
|
RewriteWithArgs(module, {args.data(), args.size()}));
|
||||||
std::move(custom_legalization_passes));
|
llvm::SmallVector<TensorShape, 4> arg_shapes;
|
||||||
|
arg_shapes.reserve(args.size());
|
||||||
|
for (unsigned idx : remaining_params)
|
||||||
|
arg_shapes.push_back(absl::get<TensorShape>(args[idx].shape));
|
||||||
|
|
||||||
|
auto status = CompileMlirToXlaHlo(
|
||||||
|
module, arg_shapes, device_type, use_tuple_args, shape_representation_fn,
|
||||||
|
compilation_result, std::move(custom_legalization_passes));
|
||||||
|
compilation_result->input_mapping = remaining_params;
|
||||||
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
|
@ -71,7 +71,7 @@ Status CompileSerializedMlirToXlaHlo(
|
||||||
|
|
||||||
// Same as the above but takes input as TensorFlow Graph.
|
// Same as the above but takes input as TensorFlow Graph.
|
||||||
Status CompileGraphToXlaHlo(
|
Status CompileGraphToXlaHlo(
|
||||||
const Graph& graph, llvm::ArrayRef<TensorShape> arg_shapes,
|
const Graph& graph, llvm::ArrayRef<const XlaCompiler::Argument> args,
|
||||||
llvm::StringRef device_type, bool use_tuple_args,
|
llvm::StringRef device_type, bool use_tuple_args,
|
||||||
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
||||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||||
|
|
|
@ -455,8 +455,12 @@ TEST(CompileGraphToXlaHlo, Basic) {
|
||||||
test::graph::Retval(&graph, 0, arg);
|
test::graph::Retval(&graph, 0, arg);
|
||||||
|
|
||||||
XlaCompiler::CompilationResult result;
|
XlaCompiler::CompilationResult result;
|
||||||
|
XlaCompiler::Argument compiler_arg;
|
||||||
|
compiler_arg.kind = XlaCompiler::Argument::kParameter;
|
||||||
|
compiler_arg.shape = TensorShape();
|
||||||
|
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
CompileGraphToXlaHlo(graph, /*arg_shapes=*/{TensorShape()}, "XLA_CPU_JIT",
|
CompileGraphToXlaHlo(graph, /*args=*/{compiler_arg}, "XLA_CPU_JIT",
|
||||||
/*use_tuple_args=*/false, flib_def, GraphDebugInfo(),
|
/*use_tuple_args=*/false, flib_def, GraphDebugInfo(),
|
||||||
/*shape_representation_fn=*/nullptr, &result));
|
/*shape_representation_fn=*/nullptr, &result));
|
||||||
|
|
||||||
|
|
|
@ -1102,6 +1102,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||||
x,
|
x,
|
||||||
expected=np.matmul(x, x.transpose([0, 1, 3, 2])))
|
expected=np.matmul(x, x.transpose([0, 1, 3, 2])))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"TODO(b/155097273): Handle complex dtype constants")
|
||||||
def testExpandDims(self):
|
def testExpandDims(self):
|
||||||
for dtype in self.numeric_types:
|
for dtype in self.numeric_types:
|
||||||
self._testBinary(
|
self._testBinary(
|
||||||
|
@ -1199,6 +1201,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||||
np.full([1, 1, 3, 5], 3., dtype=np.float32),
|
np.full([1, 1, 3, 5], 3., dtype=np.float32),
|
||||||
expected=np.full([4, 5, 1, 2, 5], 18., dtype=np.float32))
|
expected=np.full([4, 5, 1, 2, 5], 18., dtype=np.float32))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"TODO(b/155097273): Handle complex dtype constants")
|
||||||
def testPad(self):
|
def testPad(self):
|
||||||
for dtype, pad_type in itertools.product(
|
for dtype, pad_type in itertools.product(
|
||||||
self.numeric_types, [np.int32, np.int64]):
|
self.numeric_types, [np.int32, np.int64]):
|
||||||
|
@ -1230,6 +1234,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||||
[7, 7, 7, 7, 7, 7]],
|
[7, 7, 7, 7, 7, 7]],
|
||||||
dtype=dtype))
|
dtype=dtype))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"Requires concatenate op support in MlirHloBuilder")
|
||||||
def testSymmetricMirrorPad(self):
|
def testSymmetricMirrorPad(self):
|
||||||
mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "SYMMETRIC")
|
mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "SYMMETRIC")
|
||||||
for dtype in self.numeric_types:
|
for dtype in self.numeric_types:
|
||||||
|
@ -1261,6 +1267,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||||
np.array([[0, 0], [0, 0]], dtype=np.int32),
|
np.array([[0, 0], [0, 0]], dtype=np.int32),
|
||||||
expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype))
|
expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"Requires concatenate op support in MlirHloBuilder")
|
||||||
def testReflectMirrorPad(self):
|
def testReflectMirrorPad(self):
|
||||||
mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT")
|
mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT")
|
||||||
for dtype in self.numeric_types:
|
for dtype in self.numeric_types:
|
||||||
|
@ -1335,6 +1343,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||||
],
|
],
|
||||||
dtype=dtype))
|
dtype=dtype))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"TODO(b/155097273): Handle complex dtype constants")
|
||||||
def testReshape(self):
|
def testReshape(self):
|
||||||
for dtype in self.numeric_types:
|
for dtype in self.numeric_types:
|
||||||
self._testBinary(
|
self._testBinary(
|
||||||
|
@ -1414,6 +1424,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||||
],
|
],
|
||||||
equality_test=self.ListsAreClose)
|
equality_test=self.ListsAreClose)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("TODO(b/155097657): Debug incorrect answer")
|
||||||
def testTile(self):
|
def testTile(self):
|
||||||
for dtype in self.numeric_types:
|
for dtype in self.numeric_types:
|
||||||
self._testBinary(
|
self._testBinary(
|
||||||
|
@ -1466,6 +1477,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||||
[1, 2]],
|
[1, 2]],
|
||||||
dtype=dtype))
|
dtype=dtype))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"TODO(b/155097273): Handle complex dtype constants")
|
||||||
def testTranspose(self):
|
def testTranspose(self):
|
||||||
for dtype in self.numeric_types:
|
for dtype in self.numeric_types:
|
||||||
self._testBinary(
|
self._testBinary(
|
||||||
|
@ -1484,6 +1497,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||||
np.array([1, 0], dtype=np.int32),
|
np.array([1, 0], dtype=np.int32),
|
||||||
expected=np.array([[1, 3], [2, 4]], dtype=dtype))
|
expected=np.array([[1, 3], [2, 4]], dtype=dtype))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"TODO(b/155097273): Handle complex dtype constants")
|
||||||
def testConjugateTranspose(self):
|
def testConjugateTranspose(self):
|
||||||
for dtype in self.complex_types:
|
for dtype in self.complex_types:
|
||||||
self._testBinary(
|
self._testBinary(
|
||||||
|
@ -1521,6 +1536,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||||
np.array([[4, 5, 6], [40, 50, 60]], dtype=dtype),
|
np.array([[4, 5, 6], [40, 50, 60]], dtype=dtype),
|
||||||
expected=np.array([[-3, 6, -3], [60, -120, 60]], dtype=dtype))
|
expected=np.array([[-3, 6, -3], [60, -120, 60]], dtype=dtype))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"Define BroadcastArgs op in TF and const fold it")
|
||||||
def testBroadcastArgs(self):
|
def testBroadcastArgs(self):
|
||||||
self._testBinary(array_ops.broadcast_dynamic_shape,
|
self._testBinary(array_ops.broadcast_dynamic_shape,
|
||||||
np.array([2, 3, 5], dtype=np.int32),
|
np.array([2, 3, 5], dtype=np.int32),
|
||||||
|
@ -1572,6 +1589,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||||
np.array([2, 1, 5], dtype=np.int32),
|
np.array([2, 1, 5], dtype=np.int32),
|
||||||
expected=np.array([2, 3, 5], dtype=np.int32))
|
expected=np.array([2, 3, 5], dtype=np.int32))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("Error handling")
|
||||||
|
def testBroadcastArgsError(self):
|
||||||
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
|
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
|
||||||
"Incompatible shapes"):
|
"Incompatible shapes"):
|
||||||
self._testBinary(array_ops.broadcast_dynamic_shape,
|
self._testBinary(array_ops.broadcast_dynamic_shape,
|
||||||
|
@ -1579,6 +1598,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||||
np.array([4, 5, 6], dtype=np.int32),
|
np.array([4, 5, 6], dtype=np.int32),
|
||||||
expected=None)
|
expected=None)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"Requires BroadcastInDim method in MlirHloBuilder")
|
||||||
def testBroadcastTo(self):
|
def testBroadcastTo(self):
|
||||||
for dtype in self.all_types:
|
for dtype in self.all_types:
|
||||||
x = np.random.randint(0, high=100, size=[2, 3])
|
x = np.random.randint(0, high=100, size=[2, 3])
|
||||||
|
|
|
@ -47,6 +47,7 @@ class GatherNdTest(xla_test.XLATestCase):
|
||||||
np.array([8, 1, 2, 3, 7, 5], dtype=dtype),
|
np.array([8, 1, 2, 3, 7, 5], dtype=dtype),
|
||||||
np.array([[4], [4], [0]], np.int32)))
|
np.array([[4], [4], [0]], np.int32)))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("Error handling")
|
||||||
def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self):
|
def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self):
|
||||||
with self.session():
|
with self.session():
|
||||||
params = np.ones((3, 3), dtype=np.float32)
|
params = np.ones((3, 3), dtype=np.float32)
|
||||||
|
|
|
@ -195,7 +195,8 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||||
args=(np.array([1, 2, 3], dtype=dtype),),
|
args=(np.array([1, 2, 3], dtype=dtype),),
|
||||||
expected=np.array([-1, -2, -3], dtype=dtype))
|
expected=np.array([-1, -2, -3], dtype=dtype))
|
||||||
|
|
||||||
@test_util.disable_mlir_bridge('Not supported yet')
|
@test_util.disable_mlir_bridge(
|
||||||
|
'Requires XlaPad op shape inference to have static result types')
|
||||||
def testPad(self):
|
def testPad(self):
|
||||||
for dtype in self.numeric_types:
|
for dtype in self.numeric_types:
|
||||||
|
|
||||||
|
@ -308,6 +309,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
|
lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('Not supported yet')
|
||||||
def testDynamicSlice(self):
|
def testDynamicSlice(self):
|
||||||
for dtype in self.numeric_types:
|
for dtype in self.numeric_types:
|
||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
|
@ -320,6 +322,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||||
[[673, 674], [683, 684], [693, 694]]]),
|
[[673, 674], [683, 684], [693, 694]]]),
|
||||||
dtype=dtype))
|
dtype=dtype))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('Not supported yet')
|
||||||
def testDynamicSliceWithIncorrectStartIndicesShape(self):
|
def testDynamicSliceWithIncorrectStartIndicesShape(self):
|
||||||
with self.session() as session:
|
with self.session() as session:
|
||||||
with self.test_scope():
|
with self.test_scope():
|
||||||
|
@ -333,6 +336,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||||
(r'start_indices must be a vector with length equal to input rank, '
|
(r'start_indices must be a vector with length equal to input rank, '
|
||||||
r'but input rank is 3 and start_indices has shape \[2\].*'))
|
r'but input rank is 3 and start_indices has shape \[2\].*'))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('Not supported yet')
|
||||||
def testDynamicSliceWithIncorrectSizeIndicesShape(self):
|
def testDynamicSliceWithIncorrectSizeIndicesShape(self):
|
||||||
with self.session() as session:
|
with self.session() as session:
|
||||||
with self.test_scope():
|
with self.test_scope():
|
||||||
|
|
Loading…
Reference in New Issue