@@ -117,9 +117,10 @@ static ::mlir::Type makeSignlessType(::mlir::Type type) {
117117
118118// convert ddpt's DTYpeId into MLIR type
119119static ::mlir::Type getTType (::mlir::OpBuilder &builder, DTypeId dtype,
120- ::mlir::SmallVector<int64_t > &lhShape,
121- ::mlir::SmallVector<int64_t > &ownShape,
122- ::mlir::SmallVector<int64_t > &rhShape,
120+ const ::mlir::SmallVector<int64_t > &gShape ,
121+ const ::mlir::SmallVector<int64_t > &lhShape,
122+ const ::mlir::SmallVector<int64_t > &ownShape,
123+ const ::mlir::SmallVector<int64_t > &rhShape,
123124 uint64_t team, bool balanced) {
124125 ::mlir::Type etyp;
125126
@@ -154,9 +155,7 @@ static ::mlir::Type getTType(::mlir::OpBuilder &builder, DTypeId dtype,
154155 };
155156
156157 if (team) {
157- if (ownShape.size ()) {
158- auto gShape = ownShape;
159- gShape [0 ] += lhShape[0 ] + rhShape[0 ];
158+ if (gShape .size ()) {
160159 return ::imex::dist::DistTensorType::get (gShape , etyp,
161160 {lhShape, ownShape, rhShape});
162161 } else {
@@ -183,8 +182,10 @@ ::mlir::Value DepManager::getDependent(::mlir::OpBuilder &builder,
183182 ownShape[i] = impl->local_shape ()[i];
184183 rhShape[i] = impl->rh_shape ()[i];
185184 }
186- auto typ = getTType (builder, fut.dtype (), lhShape, ownShape, rhShape,
187- fut.team (), fut.balanced ());
185+ auto typ = getTType (
186+ builder, fut.dtype (),
187+ ::mlir::SmallVector<int64_t >(impl->shape (), impl->shape () + rank),
188+ lhShape, ownShape, rhShape, fut.team (), fut.balanced ());
188189 _func.insertArgument (idx, typ, {}, loc);
189190 auto val = _func.getArgument (idx);
190191 _args.push_back ({guid, std::move (fut)});
@@ -516,7 +517,8 @@ JIT::JIT()
516517 crunner = crunner ? crunner : " libmlir_c_runner_utils.so" ;
517518 const char *idtr = getenv (" DDPT_IDTR_SO" );
518519 idtr = idtr ? idtr : " libidtr.so" ;
519- _sharedLibPaths = {idtr, crunner};
520+ _sharedLibPaths = {idtr, crunner,
521+ " /home/fschlimb/llvm/lib/libmlir_runner_utils.so" };
520522
521523 // detect target architecture
522524 auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost ();
@@ -562,14 +564,12 @@ void init() {
562564 ::mlir::registerConvertFuncToLLVMPass ();
563565 ::mlir::bufferization::registerBufferizationPasses ();
564566 ::mlir::arith::registerArithPasses ();
565- ::mlir::registerAffinePasses ();
566567 ::mlir::registerCanonicalizerPass ();
567568 ::mlir::registerConvertAffineToStandardPass ();
568569 ::mlir::registerFinalizeMemRefToLLVMConversionPass ();
569570 ::mlir::registerArithToLLVMConversionPass ();
570571 ::mlir::registerConvertMathToLLVMPass ();
571572 ::mlir::registerConvertControlFlowToLLVMPass ();
572- ::mlir::registerConvertLinalgToLLVMPass ();
573573 ::mlir::registerConvertOpenMPToLLVMPass ();
574574 ::mlir::memref::registerMemRefPasses ();
575575 ::mlir::registerReconcileUnrealizedCastsPass ();
0 commit comments