|
19 | 19 | #include <imex/Dialect/Dist/IR/DistOps.h> |
20 | 20 | #include <imex/Dialect/PTensor/IR/PTensorOps.h> |
21 | 21 | #include <imex/Utils/PassUtils.h> |
| 22 | +#include <mlir/Dialect/Linalg/Utils/Utils.h> |
22 | 23 | #include <mlir/IR/Builders.h> |
23 | 24 |
|
24 | 25 | #include <pybind11/numpy.h> |
@@ -168,7 +169,12 @@ struct DeferredSetItem : public Deferred { |
168 | 169 | std::vector<::mlir::Value> stridesV(nd); |
169 | 170 | for (auto i = 0; i < nd; ++i) { |
170 | 171 | offsV[i] = ::imex::createIndex(loc, builder, offs[i]); |
171 | | - sizesV[i] = ::imex::createIndex(loc, builder, sizes[i]); |
| 172 | + if (sizes[i] == ALL_SIZE) { |
| 173 | + sizesV[i] = |
| 174 | + builder.create<::imex::ptensor::DimOp>(loc, av, i).getResult(); |
| 175 | + } else { |
| 176 | + sizesV[i] = ::imex::createIndex(loc, builder, sizes[i]); |
| 177 | + } |
172 | 178 | stridesV[i] = ::imex::createIndex(loc, builder, strides[i]); |
173 | 179 | } |
174 | 180 | // insertsliceop has no return value, so we just create the op... |
@@ -280,7 +286,12 @@ struct DeferredGetItem : public Deferred { |
280 | 286 | sizesV[i] = builder.getIndexAttr(sizes[i]); |
281 | 287 | shape[i] = sizes[i]; |
282 | 288 | } else { |
283 | | - sizesV[i] = ::imex::createIndex(loc, builder, sizes[i]); |
| 289 | + if (sizes[i] == ALL_SIZE) { |
| 290 | + sizesV[i] = |
| 291 | + builder.create<::imex::ptensor::DimOp>(loc, av, i).getResult(); |
| 292 | + } else { |
| 293 | + sizesV[i] = ::imex::createIndex(loc, builder, sizes[i]); |
| 294 | + } |
284 | 295 | } |
285 | 296 | } |
286 | 297 |
|
|
0 commit comments