Skip to content

Commit c34efab

Browse files
Merge branch 'main' into keshavvinayak01/torch-aten-flex_attention
2 parents b0e8585 + 327b6b7 commit c34efab

File tree

16 files changed

+323
-66
lines changed

16 files changed

+323
-66
lines changed

build_tools/ci/build_posix.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ cmake -S "$repo_root/externals/llvm-project/llvm" -B "$build_dir" \
4141
-GNinja \
4242
-DCMAKE_BUILD_TYPE=Release \
4343
-DPython3_EXECUTABLE="$(which python3)" \
44+
-DPython_EXECUTABLE="$(which python3)" \
4445
-DLLVM_ENABLE_ASSERTIONS=ON \
4546
-DTORCH_MLIR_ENABLE_WERROR_FLAG=ON \
4647
-DCMAKE_INSTALL_PREFIX="$install_dir" \

build_tools/python_deploy/build_linux_packages.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ function build_in_tree() {
244244
-DTORCH_MLIR_SRC_PYTORCH_BRANCH=${TORCH_MLIR_SRC_PYTORCH_BRANCH} \
245245
-DTM_PYTORCH_INSTALL_WITHOUT_REBUILD=${TM_PYTORCH_INSTALL_WITHOUT_REBUILD} \
246246
-DPython3_EXECUTABLE="$(which python3)" \
247+
-DPython_EXECUTABLE="$(which python3)" \
247248
/main_checkout/torch-mlir/externals/llvm-project/llvm
248249
cmake --build /main_checkout/torch-mlir/build --target tools/torch-mlir/all
249250
ccache -s
@@ -387,6 +388,7 @@ function build_out_of_tree() {
387388
-DLLVM_TARGETS_TO_BUILD=host \
388389
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
389390
-DPython3_EXECUTABLE="$(which python3)" \
391+
-DPython_EXECUTABLE="$(which python3)" \
390392
/main_checkout/torch-mlir/externals/llvm-project/llvm
391393
cmake --build /main_checkout/torch-mlir/llvm-build
392394
fi
@@ -409,6 +411,7 @@ function build_out_of_tree() {
409411
-DTORCH_MLIR_SRC_PYTORCH_BRANCH=${TORCH_MLIR_SRC_PYTORCH_BRANCH} \
410412
-DTM_PYTORCH_INSTALL_WITHOUT_REBUILD=${TM_PYTORCH_INSTALL_WITHOUT_REBUILD} \
411413
-DPython3_EXECUTABLE="$(which python3)" \
414+
-DPython_EXECUTABLE="$(which python3)" \
412415
/main_checkout/torch-mlir
413416
cmake --build /main_checkout/torch-mlir/build_oot
414417
ccache -s

build_tools/write_env_file.sh

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,21 @@ portable_realpath() {
1313

1414
td="$(portable_realpath "$(dirname "$0")"/..)"
1515
build_dir="$(portable_realpath "${TORCH_MLIR_BUILD_DIR:-$td/build}")"
16-
python_packages_dir="$build_dir/python_packages"
16+
17+
in_tree_pkg_dir="${build_dir}/tools/torch-mlir/python_packages"
18+
out_of_tree_pkg_dir="${build_dir}/python_packages"
19+
20+
if [[ ! -d "${in_tree_pkg_dir}" && ! -d "${out_of_tree_pkg_dir}" ]]; then
21+
echo "Couldn't find in-tree or out-of-tree build, exiting."
22+
exit 1
23+
fi
24+
25+
# The `-nt` check works even if one of the two directories is missing.
26+
if [[ "${in_tree_pkg_dir}" -nt "${out_of_tree_pkg_dir}" ]]; then
27+
python_packages_dir="${in_tree_pkg_dir}"
28+
else
29+
python_packages_dir="${out_of_tree_pkg_dir}"
30+
fi
1731

1832
write_env_file() {
1933
echo "Updating $build_dir/.env file"

docs/development.md

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -194,30 +194,48 @@ TIP: add multiple target options to stack build phases
194194

195195
### Setup Python Environment to export the built Python packages
196196

197+
When CMake is configured with `-DMLIR_ENABLE_BINDINGS_PYTHON=ON`, the python packages will typically be located in either:
198+
199+
1. `./build/tools/torch-mlir/python_packages/` if doing an in-tree build.
200+
2. `./build/python_packages/` if doing an out-of-tree build.
201+
202+
For the following sections, let `python_pkg_dir` represent whichever of the above is relevant for your build setup. On Linux and macOS, you can run `./build_tools/write_env_file.sh` to generate a file `./.env` in your root source directory with the correct `PYTHONPATH`.
203+
197204
#### Linux and macOS
198205

206+
To get the base `PYTHONPATH`, run:
207+
199208
```shell
200-
export PYTHONPATH=`pwd`/build/python_packages/torch_mlir:`pwd`/test/python/fx_importer
209+
./build_tools/write_env_file.sh
210+
source ./.env && export PYTHONPATH
211+
```
212+
213+
To run fx_importer tests, you can append the following:
214+
215+
```
216+
export PYTHONPATH="${PYTHONPATH}:$PWD/test/python/fx_importer"
201217
```
202218
203219
#### Windows PowerShell
204220
221+
To get the base `PYTHONPATH`, identify your `python_pkg_dir` from above and set this variable in your environment:
222+
223+
```shell
224+
$env:PYTHONPATH = "<python_pkg_dir>/torch_mlir"
225+
```
226+
227+
To run fx_importer tests, you can append the following:
228+
205229
```shell
206-
$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/test/python/fx_importer"
230+
$env:PYTHONPATH += ";$PWD/test/python/fx_importer"
207231
```
208232
209233
### Testing MLIR output in various dialects
210234
211235
To test the MLIR output to torch dialect, you can use `test/python/fx_importer/basic_test.py`.
212236
213237
Make sure you have activated the virtualenv and set the `PYTHONPATH` above
214-
(if running on Windows, modify the environment variable as shown above):
215-
216-
```shell
217-
source mlir_venv/bin/activate
218-
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/test/python/fx_importer
219-
python test/python/fx_importer/basic_test.py
220-
```
238+
(if running on Windows, modify the environment variable as shown above).
221239
222240
This will display the basic example in TORCH dialect.
223241
@@ -226,10 +244,10 @@ using torchscript with the example `projects/pt1/examples/torchscript_resnet18_a
226244
This path doesn't give access to the current generation work that is being driven via the fx_importer
227245
and may lead to errors.
228246
229-
Same as above, but with different python path and example:
247+
The base `PYTHONPATH` should be set as above, then the example can be run with the following command (similar on Windows):
230248
231249
```shell
232-
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/projects/pt1/examples
250+
export PYTHONPATH="${PYTHONPATH}:$PWD/projects/pt1/examples"
233251
python projects/pt1/examples/torchscript_resnet18_all_output_types.py
234252
```
235253

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 101 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,9 +1391,13 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
13911391
return success();
13921392
}
13931393

1394-
if (numSpatialDims != 2)
1394+
if (numSpatialDims != 2 && numSpatialDims != 3)
13951395
return rewriter.notifyMatchFailure(
1396-
op, "unimplemented: only 1D and 2D grouped convolution supported");
1396+
op, "unimplemented: only 2D and 3D grouped convolution supported");
1397+
if (numSpatialDims == 3 && inputZp) {
1398+
return rewriter.notifyMatchFailure(
1399+
op, "unimplemented: quantized 3D grouped convolution not supported");
1400+
}
13971401

13981402
// Grouped case, use the grouped conv linalg op
13991403
auto expandGroups = [&](Value tensor, size_t dim) {
@@ -1435,21 +1439,101 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
14351439
weight = transposed ? weight : expandWeight(weight);
14361440
auto expandOutputTensor = expandGroups(outputTensor, 1);
14371441

1438-
// TODO: add 1D and 3D case
1439-
if (!inputZp) {
1440-
conv = rewriter
1441-
.create<linalg::Conv2DNgchwGfchwOp>(
1442-
loc, expandOutputTensor.getResultType(),
1443-
ValueRange{paddedInputExpanded, weight},
1444-
expandOutputTensor.getResult(), stridesAttr, dilationAttr)
1445-
.getResult(0);
1446-
} else {
1447-
conv = rewriter
1448-
.create<linalg::Conv2DNgchwGfchwQOp>(
1449-
loc, expandOutputTensor.getResultType(),
1450-
ValueRange{paddedInputExpanded, weight, inputZp, weightZp},
1451-
expandOutputTensor.getResult(), stridesAttr, dilationAttr)
1452-
.getResult(0);
1442+
if (numSpatialDims == 2) {
1443+
// 2D grouped convolution
1444+
if (!inputZp) {
1445+
conv =
1446+
rewriter
1447+
.create<linalg::Conv2DNgchwGfchwOp>(
1448+
loc, expandOutputTensor.getResultType(),
1449+
ValueRange{paddedInputExpanded, weight},
1450+
expandOutputTensor.getResult(), stridesAttr, dilationAttr)
1451+
.getResult(0);
1452+
} else {
1453+
conv =
1454+
rewriter
1455+
.create<linalg::Conv2DNgchwGfchwQOp>(
1456+
loc, expandOutputTensor.getResultType(),
1457+
ValueRange{paddedInputExpanded, weight, inputZp, weightZp},
1458+
expandOutputTensor.getResult(), stridesAttr, dilationAttr)
1459+
.getResult(0);
1460+
}
1461+
} else if (numSpatialDims == 3) {
1462+
// MLIR does not have a named 3D grouped convolution op, so we use
1463+
// linalg.generic instead.
1464+
AffineExpr d0, d1, d2, d3, d4, d5, d6, d7, d8, d9;
1465+
bindDims(context, d0, d1, d2, d3, d4, d5, d6, d7, d8, d9);
1466+
1467+
SmallVector<AffineExpr> inputExprs = {
1468+
d0, // N
1469+
d1, // G
1470+
d6, // C/G
1471+
d3 * strideInts[0] + d7 * dilationInts[0], // D
1472+
d4 * strideInts[1] + d8 * dilationInts[1], // H
1473+
d5 * strideInts[2] + d9 * dilationInts[2] // W
1474+
};
1475+
1476+
SmallVector<AffineExpr> weightExprs = {
1477+
d1, // G
1478+
d2, // F/G
1479+
d6, // C/G
1480+
d7, // KD
1481+
d8, // KH
1482+
d9 // KW
1483+
};
1484+
1485+
SmallVector<AffineExpr> outputExprs = {
1486+
d0, // N
1487+
d1, // G
1488+
d2, // F/G
1489+
d3, // OD
1490+
d4, // OH
1491+
d5, // OW
1492+
};
1493+
1494+
SmallVector<AffineMap> indexingMaps = {
1495+
AffineMap::get(10, 0, inputExprs, rewriter.getContext()),
1496+
AffineMap::get(10, 0, weightExprs, rewriter.getContext()),
1497+
AffineMap::get(10, 0, outputExprs, rewriter.getContext())};
1498+
1499+
SmallVector<utils::IteratorType> iteratorTypes = {
1500+
utils::IteratorType::parallel, // N
1501+
utils::IteratorType::parallel, // G
1502+
utils::IteratorType::parallel, // F/G
1503+
utils::IteratorType::parallel, // OD
1504+
utils::IteratorType::parallel, // OH
1505+
utils::IteratorType::parallel, // OW
1506+
utils::IteratorType::reduction, // C/G
1507+
utils::IteratorType::reduction, // KD
1508+
utils::IteratorType::reduction, // KH
1509+
utils::IteratorType::reduction // KW
1510+
};
1511+
1512+
conv =
1513+
rewriter
1514+
.create<linalg::GenericOp>(
1515+
loc, expandOutputTensor.getResultType(),
1516+
ValueRange{paddedInputExpanded, weight},
1517+
expandOutputTensor.getResult(), indexingMaps, iteratorTypes,
1518+
[&](OpBuilder &b, Location loc, ValueRange args) {
1519+
Value input = args[0];
1520+
Value weight = args[1];
1521+
Value output = args[2];
1522+
1523+
// Convert input and weight to accumulator type if needed
1524+
Type accType = output.getType();
1525+
if (input.getType() != accType) {
1526+
input = b.create<arith::ExtFOp>(loc, accType, input);
1527+
}
1528+
if (weight.getType() != accType) {
1529+
weight = b.create<arith::ExtFOp>(loc, accType, weight);
1530+
}
1531+
1532+
Value mul = b.create<arith::MulFOp>(loc, input, weight);
1533+
Value add = b.create<arith::AddFOp>(loc, mul, output);
1534+
b.create<linalg::YieldOp>(loc, add);
1535+
})
1536+
.getResult(0);
14531537
}
14541538
conv = rewriter.create<tensor::CollapseShapeOp>(
14551539
loc, outputTensor.getType(), conv,

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6435,6 +6435,11 @@ class ConvertAtenConstPatternOp : public OpConversionPattern<AtenOpT> {
64356435
for (auto s : shape)
64366436
size *= s;
64376437

6438+
if (size == 0) {
6439+
return rewriter.notifyMatchFailure(
6440+
op, "Shape must not have a dimension of size zero");
6441+
}
6442+
64386443
SmallVector<int32_t> values(size, fillVal);
64396444
auto constOp =
64406445
tosa::getConstTensor<int32_t>(rewriter, op, values, shape).value();

lib/Dialect/Torch/Transforms/Passes.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,16 @@ void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline(
7070

7171
void mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline(
7272
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
73+
// Inline func.call operations created by higher-order ops like while_loop
74+
// to conform to the linalg-on-tensors backend contract.
75+
pm.addPass(createInlinerPass());
7376
pm.addNestedPass<func::FuncOp>(
7477
createReduceOpVariantsPass(options.extraLibrary));
7578
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
7679
if (options.decompose) {
7780
pm.addNestedPass<func::FuncOp>(
7881
Torch::createDecomposeComplexOpsPass(options.backendLegalOps));
82+
pm.addNestedPass<func::FuncOp>(Torch::createRecomposeComplexOpsPass());
7983
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
8084
}
8185
}

lib/RefBackend/RefBackend.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ static std::string getConsumeReturnFunctionNameForReturnTypes(TypeRange types) {
109109
tokens.push_back(getTypeToken(type));
110110

111111
return std::accumulate(tokens.begin(), tokens.end(), std::string(),
112-
[](std::string &a, std::string &b) {
112+
[](std::string a, std::string b) {
113113
return a.empty() ? b : (a + "_" + b);
114114
});
115115
}

0 commit comments

Comments
 (0)