Skip to content

Commit 5c1f315

Browse files
committed
[PWGEM/Dilepton] update matchingMFT.cxx to support ML
1 parent c6c9dcd commit 5c1f315

4 files changed

Lines changed: 268 additions & 30 deletions

File tree

PWGEM/Dilepton/DataModel/lmeeMLTables.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,11 +160,11 @@ DECLARE_SOA_COLUMN(NClustersMFT, nClustersMFT, uint8_t); //!
160160
DECLARE_SOA_COLUMN(IsPrimary, isPrimary, bool); //!
161161
DECLARE_SOA_COLUMN(IsCorrectMatch, isCorrectMatch, bool); //!
162162

163-
DECLARE_SOA_COLUMN(NMFTs, nMFTs, uint16_t); //! number of MFTsa tracks per collision
163+
DECLARE_SOA_COLUMN(MultMFT, multMFT, uint16_t); //! number of MFTsa tracks per collision
164164
} // namespace emmlfwdtrack
165165

166166
DECLARE_SOA_TABLE(EMFwdTracksForML, "AOD", "EMFWDTRKML", //!
167-
o2::soa::Index<>, collision::PosZ, /*collision::NumContrib,*/ mult::MultFT0C, /*evsel::NumTracksInTimeRange,*/ evsel::SumAmpFT0CInTimeRange, emmltrack::HadronicRate, emmlfwdtrack::NMFTs,
167+
o2::soa::Index<>, collision::PosZ, /*collision::NumContrib,*/ mult::MultFT0C, /*evsel::NumTracksInTimeRange,*/ evsel::SumAmpFT0CInTimeRange, emmltrack::HadronicRate, emmlfwdtrack::MultMFT,
168168
// fwdtrack::TrackType,
169169

170170
emmlfwdtrack::Signed1PtMFTatMP, emmlfwdtrack::TglMFTatMP, emmlfwdtrack::PhiMFTatMP,

PWGEM/Dilepton/Tasks/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ o2physics_add_dpl_workflow(study-mc-truth
117117

118118
o2physics_add_dpl_workflow(matching-mft
119119
SOURCES matchingMFT.cxx
120-
PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::AnalysisCore O2::GlobalTracking
120+
PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::AnalysisCore O2::GlobalTracking O2Physics::MLCore
121121
COMPONENT_NAME Analysis)
122122

123123
o2physics_add_dpl_workflow(tagging-hfe

PWGEM/Dilepton/Tasks/matchingMFT.cxx

Lines changed: 107 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,17 @@
1313
/// \brief a task to study matching MFT-[MCH-MID] in MC
1414
/// \author daiki.sekihata@cern.ch
1515

16+
#include "PWGEM/Dilepton/Utils/MlResponseFwdTrack.h"
17+
1618
#include "Common/CCDB/EventSelectionParams.h"
1719
#include "Common/CCDB/RCTSelectionFlags.h"
20+
#include "Common/Core/RecoDecay.h"
1821
#include "Common/Core/fwdtrackUtilities.h"
1922
#include "Common/DataModel/Centrality.h"
2023
#include "Common/DataModel/CollisionAssociationTables.h"
2124
#include "Common/DataModel/EventSelection.h"
2225
#include "Common/DataModel/Multiplicity.h"
26+
#include "Tools/ML/MlResponse.h"
2327

2428
#include <CCDB/BasicCCDBManager.h>
2529
#include <DataFormatsParameters/GRPMagField.h>
@@ -111,6 +115,18 @@ struct matchingMFT {
111115
Configurable<float> matchingZ{"matchingZ", -77.5, "z position where matching is performed"};
112116
Configurable<bool> cfgApplyPreselectionInBestMatch{"cfgApplyPreselectionInBestMatch", false, "flag to apply preselection in find best match function"};
113117

118+
// configuration for matching with ML
119+
Configurable<bool> useMLmatching{"useMLmatching", false, "Flag to use ML for matching between MFT and MCH-MID"};
120+
Configurable<std::vector<std::string>> onnxFileNames{"onnxFileNames", std::vector<std::string>{"filename"}, "ONNX file names for each bin (if not from CCDB full path)"};
121+
Configurable<std::vector<std::string>> onnxPathsCCDB{"onnxPathsCCDB", std::vector<std::string>{"path"}, "Paths of models on CCDB"};
122+
Configurable<std::vector<double>> binsMl{"binsMl", std::vector<double>{0.1, 0.15, 0.2, 0.25, 0.4, 0.8, 1.6, 2.0, 20}, "Bin limits for ML application"};
123+
Configurable<std::vector<double>> cutsMl{"cutsMl", std::vector<double>{0.95, 0.95, 0.7, 0.7, 0.8, 0.8, 0.7, 0.7}, "ML cuts per bin"};
124+
Configurable<std::vector<std::string>> namesInputFeatures{"namesInputFeatures", std::vector<std::string>{"multFT0C", "ptMCHMID", "rSigned1Pt", "dEta", "dPhi", "dX", "dY", "chi2MatchMCHMFT"}, "Names of ML model input features"};
125+
Configurable<std::string> nameBinningFeature{"nameBinningFeature", "multFT0C", "Names of ML model binning feature"};
126+
Configurable<int64_t> timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB. Exceptions: > 0 for the specific timestamp, 0 gets the run dependent timestamp"};
127+
Configurable<bool> loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"};
128+
Configurable<bool> enableOptimizations{"enableOptimizations", false, "Enables the ONNX extended model-optimization: sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED)"};
129+
114130
struct : ConfigurableGroup {
115131
std::string prefix = "eventcut_group";
116132
Configurable<float> cfgZvtxMin{"cfgZvtxMin", -10.f, "min. Zvtx"};
@@ -151,10 +167,28 @@ struct matchingMFT {
151167
} eventcuts;
152168

153169
o2::aod::rctsel::RCTFlagsChecker rctChecker;
170+
o2::analysis::MlResponseFwdTrack<float> mlResponseFwdTrack;
154171

155172
HistogramRegistry fRegistry{"fRegistry"};
156173
static constexpr std::string_view muon_types[5] = {"MFTMCHMID/", "MFTMCHMIDOtherMatch/", "MFTMCH/", "MCHMID/", "MCH/"};
157174

175+
struct matchedCandidate {
176+
float multFT0C{0};
177+
float multMFT{0};
178+
float ptMCHMID{0};
179+
float rSigned1Pt{1e+10};
180+
float dEta{1e+10};
181+
float dPhi{1e+10};
182+
float dX{1e+10};
183+
float dY{1e+10};
184+
float chi2MatchMCHMFT{1e+10};
185+
186+
float sigmaPhiMFT{1e+10};
187+
float sigmaTglMFT{1e+10};
188+
float sigmaPhiMCHMID{1e+10};
189+
float sigmaTglMCHMID{1e+10};
190+
};
191+
158192
void init(o2::framework::InitContext&)
159193
{
160194
if (doprocessWithoutFTTCA && doprocessWithFTTCA) {
@@ -169,6 +203,31 @@ struct matchingMFT {
169203
rctChecker.init(eventcuts.cfgRCTLabel.value, eventcuts.cfgCheckZDC.value, eventcuts.cfgTreatLimitedAcceptanceAsBad.value);
170204

171205
addHistograms();
206+
207+
if (useMLmatching) {
208+
static constexpr int nClassesMl = 2;
209+
const std::vector<int> cutDirMl = {o2::cuts_ml::CutNot, o2::cuts_ml::CutSmaller};
210+
const std::vector<std::string> labelsClasses = {"Background", "Signal"};
211+
const uint32_t nBinsMl = binsMl.value.size() - 1;
212+
const std::vector<std::string> labelsBins(nBinsMl, "bin");
213+
double cutsMlArr[nBinsMl][nClassesMl];
214+
for (uint32_t i = 0; i < nBinsMl; i++) {
215+
cutsMlArr[i][0] = 0.0;
216+
cutsMlArr[i][1] = cutsMl.value[i];
217+
}
218+
o2::framework::LabeledArray<double> cutsMl = {cutsMlArr[0], nBinsMl, nClassesMl, labelsBins, labelsClasses};
219+
220+
mlResponseFwdTrack.configure(binsMl.value, cutsMl, cutDirMl, nClassesMl);
221+
if (loadModelsFromCCDB) {
222+
ccdbApi.init(ccdburl);
223+
mlResponseFwdTrack.setModelPathsCCDB(onnxFileNames.value, ccdbApi, onnxPathsCCDB.value, timestampCCDB.value);
224+
} else {
225+
mlResponseFwdTrack.setModelPathsLocal(onnxFileNames.value);
226+
}
227+
mlResponseFwdTrack.cacheInputFeaturesIndices(namesInputFeatures);
228+
mlResponseFwdTrack.cacheBinningIndex(nameBinningFeature);
229+
mlResponseFwdTrack.init(enableOptimizations.value);
230+
} // end of ML configuration
172231
}
173232

174233
o2::ccdb::CcdbApi ccdbApi;
@@ -415,28 +474,6 @@ struct matchingMFT {
415474
return (clmap > 0);
416475
}
417476

418-
// template <typename T>
419-
// float meanClusterSizeMFT(T const& track)
420-
// {
421-
// uint64_t mftClusterSizesAndTrackFlags = track.mftClusterSizesAndTrackFlags();
422-
// uint16_t clsSize = 0;
423-
// uint16_t n = 0;
424-
// for (unsigned int layer = 0; layer < 10; layer++) {
425-
// uint16_t size_per_layer = (mftClusterSizesAndTrackFlags >> (layer * 6)) & 0x3f;
426-
// clsSize += size_per_layer;
427-
// if (size_per_layer > 0) {
428-
// n++;
429-
// }
430-
// // LOGF(info, "track.globalIndex() = %d, layer = %d, size_per_layer = %d", track.globalIndex(), layer, size_per_layer);
431-
// }
432-
433-
// if (n > 0) {
434-
// return static_cast<float>(clsSize) / static_cast<float>(n) * std::fabs(std::sin(std::atan(track.tgl())));
435-
// } else {
436-
// return 0.f;
437-
// }
438-
// }
439-
440477
template <typename TFwdTracks, typename TMFTTracks, typename TCollision, typename TFwdTrack, typename TMFTrackCov>
441478
void getDxDyAtMatchingPlane(TCollision const& collision, TFwdTrack const& fwdtrack, TMFTrackCov const& mftCovs, float& dx, float& dy)
442479
{
@@ -817,8 +854,8 @@ struct matchingMFT {
817854
std::vector<std::tuple<int, int, int>> vec_min_chi2MatchMCHMFT; // std::pair<globalIndex of global muon, globalIndex of matched MCH-MID, globalIndex of MFT> -> chi2MatchMCHMFT;
818855
// std::map<std::tuple<int, int, int>, bool> mapCorrectMatch;
819856

820-
template <typename TCollision, typename TFwdTrack, typename TFwdTracks, typename TMFTTracks>
821-
void findBestMatchPerMCHMID(TCollision const& collision, TFwdTrack const& fwdtrack, TFwdTracks const& fwdtracks, TMFTTracks const&)
857+
template <bool withMFTCov = false, typename TCollision, typename TFwdTrack, typename TFwdTracks, typename TMFTTracks, typename TMFTTracksCov>
858+
void findBestMatchPerMCHMID(TCollision const& collision, TFwdTrack const& fwdtrack, TFwdTracks const& fwdtracks, TMFTTracks const& mfttracks, TMFTTracksCov const& mftCovs)
822859
{
823860
if (fwdtrack.trackType() != o2::aod::fwdtrack::ForwardTrackTypeEnum::MuonStandaloneTrack) {
824861
return;
@@ -827,6 +864,8 @@ struct matchingMFT {
827864
return;
828865
}
829866

867+
auto mfttracks_per_collision = mfttracks.sliceBy(perCollision_MFT, collision.globalIndex());
868+
830869
std::tuple<int, int, int> tupleIds_at_min_chi2mftmch;
831870
float min_chi2MatchMCHMFT = 1e+10;
832871
auto muons_per_MCHMID = fwdtracks.sliceBy(fwdtracksPerMCHTrack, fwdtrack.globalIndex());
@@ -843,6 +882,9 @@ struct matchingMFT {
843882
float dcaXY_Matched = std::sqrt(dcaX_Matched * dcaX_Matched + dcaY_Matched * dcaY_Matched);
844883
float pDCA = fwdtrack.p() * dcaXY_Matched;
845884

885+
o2::dataformats::GlobalFwdTrack muonAtMP = propagateMuon(fwdtrack, fwdtrack, collision, propagationPoint::kToMatchingPlane, matchingZ, mBz, mZShift); // propagated to matching plane
886+
float phiMCHMIDatMP = RecoDecay::constrainAngle(muonAtMP.getPhi(), 0, 1U);
887+
846888
for (const auto& muon_tmp : muons_per_MCHMID) {
847889
if (muon_tmp.trackType() == o2::aod::fwdtrack::ForwardTrackTypeEnum::GlobalMuonTrack) {
848890
auto tupleId = std::make_tuple(muon_tmp.globalIndex(), muon_tmp.matchMCHTrackId(), muon_tmp.matchMFTTrackId());
@@ -885,6 +927,44 @@ struct matchingMFT {
885927
float dcaY = propmuonAtPV.getY() - collision.posY();
886928
float dcaXY = std::sqrt(dcaX * dcaX + dcaY * dcaY);
887929

930+
if constexpr (withMFTCov) {
931+
if (useMLmatching) {
932+
matchedCandidate candidate;
933+
candidate.multFT0C = collision.multFT0C();
934+
candidate.multMFT = static_cast<float>(mfttracks_per_collision.size());
935+
candidate.chi2MatchMCHMFT = muon_tmp.chi2MatchMCHMFT();
936+
937+
auto mfttrackcov = mftCovs.rawIteratorAt(map_mfttrackcovs[mfttrack.globalIndex()]);
938+
o2::track::TrackParCovFwd mftsaAtMP = getTrackParCovFwdShift(mfttrack, mZShift, mfttrackcov); // values at innermost update
939+
mftsaAtMP.propagateToZhelix(matchingZ, mBz); // propagated to matching plane
940+
float phiMFTatMP = RecoDecay::constrainAngle(mftsaAtMP.getPhi(), 0, 1U);
941+
942+
candidate.rSigned1Pt = mftsaAtMP.getInvQPt() / muonAtMP.getInvQPt();
943+
candidate.dEta = mftsaAtMP.getEta() - muonAtMP.getEta();
944+
candidate.dPhi = RecoDecay::constrainAngle(phiMFTatMP - phiMCHMIDatMP, -o2::constants::math::PIHalf, 1U);
945+
candidate.dX = mftsaAtMP.getX() - muonAtMP.getX();
946+
candidate.dY = mftsaAtMP.getY() - muonAtMP.getY();
947+
948+
candidate.sigmaTglMCHMID = std::sqrt(muonAtMP.getSigma2Tanl());
949+
candidate.sigmaPhiMCHMID = std::sqrt(muonAtMP.getSigma2Phi());
950+
candidate.sigmaTglMFT = std::sqrt(mftsaAtMP.getSigma2Tanl());
951+
candidate.sigmaPhiMFT = std::sqrt(mftsaAtMP.getSigma2Phi());
952+
953+
std::vector<float> inputFeatures = mlResponseFwdTrack.getInputFeatures(candidate);
954+
float binningFeature = mlResponseFwdTrack.getBinningFeature(candidate);
955+
int pbin = lower_bound(binsMl.value.begin(), binsMl.value.end(), binningFeature) - binsMl.value.begin() - 1;
956+
if (pbin < 0) {
957+
pbin = 0;
958+
} else if (static_cast<int>(binsMl.value.size()) - 2 < pbin) {
959+
pbin = static_cast<int>(binsMl.value.size()) - 2;
960+
}
961+
float probaEl = mlResponseFwdTrack.getModelOutput(inputFeatures, pbin)[1]; // 0: wrong, 1:correct
962+
if (probaEl < cutsMl.value[pbin]) {
963+
continue;
964+
}
965+
}
966+
}
967+
888968
if (isPrimary) {
889969
if (isMatched) {
890970
fRegistry.fill(HIST("MFTMCHMID/primary/correct/hdR_Chi2MatchMCHMFT"), muon_tmp.chi2MatchMCHMFT(), dr);
@@ -1092,7 +1172,7 @@ struct matchingMFT {
10921172
initCCDB(bc);
10931173
auto fwdtracks_per_coll = fwdtracks.sliceBy(perCollision, collision.globalIndex());
10941174
for (const auto& fwdtrack : fwdtracks_per_coll) {
1095-
findBestMatchPerMCHMID(collision, fwdtrack, fwdtracks, mfttracks);
1175+
findBestMatchPerMCHMID<false>(collision, fwdtrack, fwdtracks, mfttracks, nullptr);
10961176
} // end of fwdtrack loop
10971177
} // end of collision loop
10981178

@@ -1150,7 +1230,7 @@ struct matchingMFT {
11501230
auto fwdtrackIdsThisCollision = fwdtrackIndices.sliceBy(fwdtrackIndicesPerCollision, collision.globalIndex());
11511231
for (const auto& fwdtrackId : fwdtrackIdsThisCollision) {
11521232
auto fwdtrack = fwdtrackId.template fwdtrack_as<MyFwdTracks>();
1153-
findBestMatchPerMCHMID(collision, fwdtrack, fwdtracks, mfttracks);
1233+
findBestMatchPerMCHMID<false>(collision, fwdtrack, fwdtracks, mfttracks, nullptr);
11541234
} // end of fwdtrack loop
11551235
} // end of collision loop
11561236

@@ -1213,7 +1293,7 @@ struct matchingMFT {
12131293
auto fwdtrackIdsThisCollision = fwdtrackIndices.sliceBy(fwdtrackIndicesPerCollision, collision.globalIndex());
12141294
for (const auto& fwdtrackId : fwdtrackIdsThisCollision) {
12151295
auto fwdtrack = fwdtrackId.template fwdtrack_as<MyFwdTracks>();
1216-
findBestMatchPerMCHMID(collision, fwdtrack, fwdtracks, mfttracks);
1296+
findBestMatchPerMCHMID<true>(collision, fwdtrack, fwdtracks, mfttracks, mftCovs);
12171297
} // end of fwdtrack loop
12181298
} // end of collision loop
12191299

0 commit comments

Comments
 (0)