|
101 | 101 | #include "llvm/Support/TargetSelect.h" |
102 | 102 | // #include "llvm/Support/raw_ostream.h" |
103 | 103 |
|
| 104 | +#include "ddptensor/itac.hpp" |
| 105 | + |
104 | 106 | namespace jit { |
105 | 107 |
|
106 | 108 | static ::mlir::Type makeSignlessType(::mlir::Type type) { |
@@ -261,6 +263,22 @@ uint64_t DepManager::handleResult(::mlir::OpBuilder &builder) { |
261 | 263 | ++idx; |
262 | 264 | } |
263 | 265 |
|
| 266 | + if (HAS_ITAC()) { |
| 267 | + int vtExeSym, vtDDPTClass; |
| 268 | + VT(VT_classdef, "ddpt", &vtDDPTClass); |
| 269 | + VT(VT_funcdef, "execute", vtDDPTClass, &vtExeSym); |
| 270 | + auto s = builder.create<::mlir::arith::ConstantOp>( |
| 271 | + loc, builder.getI32IntegerAttr(vtExeSym)); |
| 272 | + auto end = builder.create<::mlir::func::CallOp>( |
| 273 | + builder.getUnknownLoc(), "VT_end", |
| 274 | + ::mlir::TypeRange(builder.getIntegerType(32)), ::mlir::ValueRange(s)); |
| 275 | + mlir::OpBuilder::InsertionGuard guard(builder); |
| 276 | + builder.setInsertionPointToStart(end->getBlock()); |
| 277 | + (void)builder.create<::mlir::func::CallOp>( |
| 278 | + builder.getUnknownLoc(), "VT_begin", |
| 279 | + ::mlir::TypeRange(builder.getIntegerType(32)), ::mlir::ValueRange(s)); |
| 280 | + } |
| 281 | + |
264 | 282 | // add return statement |
265 | 283 | auto ret_value = builder.create<::mlir::func::ReturnOp>( |
266 | 284 | builder.getUnknownLoc(), ret_values); |
@@ -348,28 +366,55 @@ std::vector<intptr_t> JIT::run(::mlir::ModuleOp &module, |
348 | 366 | const std::string &fname, |
349 | 367 | std::vector<void *> &inp, size_t osz) { |
350 | 368 |
|
| 369 | + int vtDDPTClass, vtHashSym, vtEEngineSym, vtRunSym, vtHashGenSym; |
| 370 | + if (HAS_ITAC()) { |
| 371 | + VT(VT_classdef, "ddpt", &vtDDPTClass); |
| 372 | + VT(VT_funcdef, "lookup_cache", vtDDPTClass, &vtHashSym); |
| 373 | + VT(VT_funcdef, "gen_sha", vtDDPTClass, &vtHashGenSym); |
| 374 | + VT(VT_funcdef, "eengine", vtDDPTClass, &vtEEngineSym); |
| 375 | + VT(VT_funcdef, "run", vtDDPTClass, &vtRunSym); |
| 376 | + VT(VT_begin, vtEEngineSym); |
| 377 | + |
| 378 | + ::mlir::OpBuilder builder(module->getContext()); |
| 379 | + ::mlir::OpBuilder::InsertionGuard guard(builder); |
| 380 | + builder.setInsertionPoint(module.getBody(), |
| 381 | + std::prev(module.getBody()->end())); |
| 382 | + auto intTyp = builder.getIntegerType(32); |
| 383 | + auto funcType = builder.getFunctionType({intTyp}, {intTyp}); |
| 384 | + builder.create<::mlir::func::FuncOp>(module.getLoc(), "VT_begin", funcType) |
| 385 | + .setPrivate(); |
| 386 | + builder.create<::mlir::func::FuncOp>(module.getLoc(), "VT_end", funcType) |
| 387 | + .setPrivate(); |
| 388 | + } |
| 389 | + |
351 | 390 | ::mlir::ExecutionEngine *enginePtr; |
352 | 391 | std::unique_ptr<::mlir::ExecutionEngine> tmpEngine; |
353 | 392 |
|
354 | 393 | if (_useCache) { |
| 394 | + VT(VT_begin, vtHashGenSym); |
355 | 395 | static std::map<std::array<unsigned char, 20>, |
356 | 396 | std::unique_ptr<::mlir::ExecutionEngine>> |
357 | 397 | engineCache; |
358 | 398 |
|
359 | 399 | llvm::raw_sha1_ostream shaOS; |
360 | 400 | module->print(shaOS); |
361 | 401 | auto cksm = shaOS.sha1(); |
| 402 | + VT(VT_end, vtHashGenSym); |
362 | 403 |
|
| 404 | + VT(VT_begin, vtHashSym); |
363 | 405 | if (auto search = engineCache.find(cksm); search == engineCache.end()) { |
364 | 406 | engineCache[cksm] = createExecutionEngine(module); |
365 | 407 | } else { |
366 | 408 | if (_verbose) |
367 | 409 | std::cerr << "cached..." << std::endl; |
368 | 410 | } |
369 | 411 | enginePtr = engineCache[cksm].get(); |
| 412 | + VT(VT_end, vtHashSym); |
370 | 413 | } else { |
| 414 | + VT(VT_begin, vtHashSym); |
371 | 415 | tmpEngine = createExecutionEngine(module); |
372 | 416 | enginePtr = tmpEngine.get(); |
| 417 | + VT(VT_end, vtHashSym); |
373 | 418 | } |
374 | 419 |
|
375 | 420 | auto expectedFPtr = |
@@ -398,6 +443,7 @@ std::vector<intptr_t> JIT::run(::mlir::ModuleOp &module, |
398 | 443 | // call function |
399 | 444 | (*jittedFuncPtr)(args.data()); |
400 | 445 |
|
| 446 | + VT(VT_end, vtEEngineSym); |
401 | 447 | return out; |
402 | 448 | } |
403 | 449 |
|
|
0 commit comments