Skip to content

Commit 260cd0d

Browse files
committed
Add ACTS tracker
1 parent f65cbf7 commit 260cd0d

File tree

3 files changed

+531
-1
lines changed

3 files changed

+531
-1
lines changed

Detectors/Upgrades/ALICE3/TRK/reconstruction/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ o2_add_library(TRKReconstruction
1818
SOURCES src/TimeFrame.cxx
1919
src/Clusterer.cxx
2020
$<$<BOOL:${Acts_FOUND}>:src/ClustererACTS.cxx>
21+
$<$<BOOL:${Acts_FOUND}>:src/TrackerACTS.cxx>
2122
PUBLIC_LINK_LIBRARIES
2223
O2::ITStracking
2324
O2::GPUCommon
@@ -45,7 +46,8 @@ set(dictHeaders include/TRKReconstruction/TimeFrame.h
4546
include/TRKReconstruction/Clusterer.h)
4647

4748
if(Acts_FOUND)
48-
list(APPEND dictHeaders include/TRKReconstruction/ClustererACTS.h)
49+
list(APPEND dictHeaders include/TRKReconstruction/ClustererACTS.h
50+
include/TRKReconstruction/TrackerACTS.h)
4951
endif()
5052

5153
o2_target_root_dictionary(TRKReconstruction
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2+
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3+
// All rights not expressly granted are reserved.
4+
//
5+
// This software is distributed under the terms of the GNU General Public
6+
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7+
//
8+
// In applying this license CERN does not waive the privileges and immunities
9+
// granted to it by virtue of its status as an Intergovernmental Organization
10+
// or submit itself to any jurisdiction.
11+
///
12+
/// \file TrackerACTS.h
13+
/// \brief
14+
///
15+
16+
#ifndef ALICE3_INCLUDE_TRACKERACTS_H_
17+
#define ALICE3_INCLUDE_TRACKERACTS_H_
18+
19+
#include <array>
20+
#include <chrono>
21+
#include <cmath>
22+
#include <fstream>
23+
#include <iomanip>
24+
#include <iosfwd>
25+
#include <memory>
26+
#include <string_view>
27+
#include <utility>
28+
#include <sstream>
29+
30+
#include <oneapi/tbb/task_arena.h>
31+
32+
#include "ITStracking/Configuration.h"
33+
#include "CommonConstants/MathConstants.h"
34+
#include "ITStracking/Definitions.h"
35+
#include "ITStracking/MathUtils.h"
36+
#include "ITStracking/TimeFrame.h"
37+
#include "ITStracking/TrackerTraits.h"
38+
#include "ITStracking/Road.h"
39+
#include "ITStracking/BoundedAllocator.h"
40+
41+
#include "DataFormatsITS/TrackITS.h"
42+
#include "SimulationDataFormat/MCCompLabel.h"
43+
44+
namespace o2
45+
{
46+
47+
namespace gpu
48+
{
49+
class GPUChainITS;
50+
}
51+
namespace its
52+
{
53+
54+
template <int nLayers>
55+
class Tracker
56+
{
57+
using LogFunc = std::function<void(const std::string& s)>;
58+
59+
public:
60+
Tracker(TrackerTraits<nLayers>* traits);
61+
62+
void adoptTimeFrame(TimeFrame<nLayers>& tf);
63+
64+
void clustersToTracks(
65+
const LogFunc& = [](const std::string& s) { std::cout << s << '\n'; },
66+
const LogFunc& = [](const std::string& s) { std::cerr << s << '\n'; });
67+
68+
void setParameters(const std::vector<TrackingParameters>& p) { mTrkParams = p; }
69+
void setMemoryPool(std::shared_ptr<BoundedMemoryResource> pool) { mMemoryPool = pool; }
70+
std::vector<TrackingParameters>& getParameters() { return mTrkParams; }
71+
void setBz(float bz) { mTraits->setBz(bz); }
72+
bool isMatLUT() const { return mTraits->isMatLUT(); }
73+
void setNThreads(int n, std::shared_ptr<tbb::task_arena>& arena) { mTraits->setNThreads(n, arena); }
74+
void printSummary() const;
75+
void computeTracksMClabels();
76+
77+
private:
78+
void initialiseTimeFrame(int iteration) { mTraits->initialiseTimeFrame(iteration); }
79+
void computeTracklets(int iteration, int iROFslice, int iVertex) { mTraits->computeLayerTracklets(iteration, iROFslice, iVertex); }
80+
void computeCells(int iteration) { mTraits->computeLayerCells(iteration); }
81+
void findCellsNeighbours(int iteration) { mTraits->findCellsNeighbours(iteration); }
82+
void findRoads(int iteration) { mTraits->findRoads(iteration); }
83+
void findShortPrimaries() { mTraits->findShortPrimaries(); }
84+
void extendTracks(int iteration) { mTraits->extendTracks(iteration); }
85+
86+
// MC interaction
87+
void computeRoadsMClabels();
88+
void rectifyClusterIndices();
89+
90+
template <typename... T, typename... F>
91+
float evaluateTask(void (Tracker::*task)(T...), std::string_view taskName, int iteration, LogFunc logger, F&&... args);
92+
93+
TrackerTraits<nLayers>* mTraits = nullptr; /// Observer pointer, not owned by this class
94+
TimeFrame<nLayers>* mTimeFrame = nullptr; /// Observer pointer, not owned by this class
95+
96+
std::vector<TrackingParameters> mTrkParams;
97+
o2::gpu::GPUChainITS* mRecoChain = nullptr;
98+
99+
unsigned int mNumberOfDroppedTFs{0};
100+
unsigned int mTimeFrameCounter{0};
101+
double mTotalTime{0};
102+
std::shared_ptr<BoundedMemoryResource> mMemoryPool;
103+
104+
enum State {
105+
TFInit = 0,
106+
Trackleting,
107+
Celling,
108+
Neighbouring,
109+
Roading,
110+
NStates,
111+
};
112+
State mCurState{TFInit};
113+
static constexpr std::array<const char*, NStates> StateNames{"TimeFrame initialisation", "Tracklet finding", "Cell finding", "Neighbour finding", "Road finding"};
114+
};
115+
116+
template <int nLayers>
117+
template <typename... T, typename... F>
118+
float Tracker<nLayers>::evaluateTask(void (Tracker<nLayers>::*task)(T...), std::string_view taskName, int iteration, LogFunc logger, F&&... args)
119+
{
120+
float diff{0.f};
121+
122+
if constexpr (constants::DoTimeBenchmarks) {
123+
auto start = std::chrono::high_resolution_clock::now();
124+
(this->*task)(std::forward<F>(args)...);
125+
auto end = std::chrono::high_resolution_clock::now();
126+
127+
std::chrono::duration<double, std::milli> diff_t{end - start};
128+
diff = diff_t.count();
129+
130+
std::stringstream sstream;
131+
if (taskName.empty()) {
132+
sstream << diff << "\t";
133+
} else {
134+
sstream << std::setw(2) << " - " << taskName << " completed in: " << diff << " ms";
135+
}
136+
logger(sstream.str());
137+
138+
if (mTrkParams[0].SaveTimeBenchmarks) {
139+
std::string taskNameStr(taskName);
140+
std::transform(taskNameStr.begin(), taskNameStr.end(), taskNameStr.begin(),
141+
[](unsigned char c) { return std::tolower(c); });
142+
std::replace(taskNameStr.begin(), taskNameStr.end(), ' ', '_');
143+
if (std::ofstream file{"its_time_benchmarks.txt", std::ios::app}) {
144+
file << "trk:" << iteration << '\t' << taskNameStr << '\t' << diff << '\n';
145+
}
146+
}
147+
148+
} else {
149+
(this->*task)(std::forward<F>(args)...);
150+
}
151+
152+
return diff;
153+
}
154+
155+
} // namespace its
156+
} // namespace o2
157+
158+
#endif /* ALICE3_INCLUDE_TRACKER_H_ */

0 commit comments

Comments
 (0)