This contains my bachelor's thesis and associated TeX files, code snippets, and maybe more. Topic: Data Movement in Heterogeneous Memories with Intel Data Streaming Accelerator


#pragma once

#include <iostream>
#include <unordered_map>
#include <shared_mutex>
#include <mutex>
#include <memory>

#include <sched.h>
#include <numa.h>
#include <numaif.h>

#include <dml/dml.hpp>

namespace dml {
    inline const std::string StatusCodeToString(const dml::status_code code) {
        switch (code) {
            case dml::status_code::ok: return "ok";
            case dml::status_code::false_predicate: return "false predicate";
            case dml::status_code::partial_completion: return "partial completion";
            case dml::status_code::nullptr_error: return "nullptr error";
            case dml::status_code::bad_size: return "bad size";
            case dml::status_code::bad_length: return "bad length";
            case dml::status_code::inconsistent_size: return "inconsistent size";
            case dml::status_code::dualcast_bad_padding: return "dualcast bad padding";
            case dml::status_code::bad_alignment: return "bad alignment";
            case dml::status_code::buffers_overlapping: return "buffers overlapping";
            case dml::status_code::delta_delta_empty: return "delta delta empty";
            case dml::status_code::batch_overflow: return "batch overflow";
            case dml::status_code::execution_failed: return "execution failed";
            case dml::status_code::unsupported_operation: return "unsupported operation";
            case dml::status_code::queue_busy: return "queue busy";
            case dml::status_code::error: return "unknown error";
            case dml::status_code::config_error: return "config error";
            default: return "unhandled error";
        }
    }
}
namespace dsacache {
    inline bool CheckFlag(const uint64_t value, const uint64_t flag) {
        return (value & flag) != 0;
    }

    inline uint64_t UnsetFlag(const uint64_t value, const uint64_t flag) {
        return value & (~flag);
    }

    inline uint64_t SetFlag(const uint64_t value, const uint64_t flag) {
        return value | flag;
    }

    constexpr uint64_t FLAG_WAIT_WEAK = 0b1ULL << 63;
    constexpr uint64_t FLAG_HANDLE_PF = 0b1ULL << 62;
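
    // Illustrative usage sketch (not part of the original interface; the variable
    // names are placeholders): these helpers build the option bitmask that is later
    // handed to Cache::SetFlags() and is then inherited by every CacheData entry
    // the cache creates, e.g.
    //
    //   uint64_t flags = 0;
    //   flags = SetFlag(flags, FLAG_HANDLE_PF);   // have the copy engine block on page faults
    //   flags = SetFlag(flags, FLAG_WAIT_WEAK);   // let WaitOnCompletion() return early if unfinished
    //   if (CheckFlag(flags, FLAG_WAIT_WEAK)) {
    //       flags = UnsetFlag(flags, FLAG_WAIT_WEAK);
    //   }
    //   cache.SetFlags(flags);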
    class Cache;

    /*
    * Class Description:
    * Holds all required information on one cache entry and is used
    * both internally by the Cache and externally by the user.
    *
    * Important Usage Notes:
    * The pointer is only updated in WaitOnCompletion() which
    * therefore must be called by the user at some point in order
    * to use the cached data. Using this class as T for
    * std::shared_ptr<T> is not recommended as references are
    * already counted internally.
    *
    * Cache Lifetime:
    * As long as the instance is referenced, the pointer it stores
    * is guaranteed to be either nullptr or pointing to a valid copy.
    *
    * Implementation Detail:
    * Performs self-reference counting with a shared atomic integer.
    * Creating a copy therefore increases the reference count and the
    * destructor decreases it. If the last copy is destroyed, the
    * underlying data is freed and all shared variables are deleted.
    *
    * Notes on Thread Safety:
    * Class is thread safe in any possible state and performs
    * reference counting and deallocation itself entirely atomically.
    */
    class CacheData {
    public:
        using dml_handler = dml::handler<dml::mem_copy_operation, std::allocator<uint8_t>>;

    private:
        static constexpr uint64_t maxptr = 0xffff'ffff'ffff'ffff;

        // set to false if we do not own the cache pointer
        bool delete_ = false;

        // data source and size of the block
        uint8_t* src_;
        size_t size_;

        // global reference counting object
        std::atomic<int32_t>* active_;

        // global cache-location pointer
        std::atomic<uint8_t*>* cache_;

        // object-local incomplete cache location pointer
        // contract: only access when in sole possession of the handler
        uint8_t** incomplete_cache_;

        // flags inherited from parent cache
        uint64_t flags_ = 0;

        // dml handler pointer which is used
        // to wait on caching task completion
        std::atomic<dml_handler*>* handler_;

        // deallocates the global cache-location
        // and invalidates it
        void Deallocate();

        size_t GetSize() const { return size_; }
        uint8_t* GetSource() const { return src_; }
        int32_t GetRefCount() const { return active_->load(); }
        void SetTaskHandlerAndCache(uint8_t* cache, dml_handler* handler);

        // initializes the class after which it is thread safe
        // but may only be destroyed safely after setting handlers
        void Init();

        friend Cache;

    public:
        CacheData(uint8_t* data, const size_t size);
        CacheData(const CacheData& other);
        ~CacheData();

        // waits on completion of caching operations
        // for this task and is safe to be called in
        // any state of the object
        void WaitOnCompletion();

        // returns the cache data location for this
        // instance which is valid as long as the
        // instance is alive - !!! this may also
        // yield a nullptr !!!
        uint8_t* GetDataLocation() const { return cache_->load(); }

        void SetFlags(const uint64_t flags) { flags_ = flags; }
        uint64_t GetFlags() const { return flags_; }
    };
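
    // Illustrative usage sketch (assumes a previously initialized Cache instance
    // named "cache"; "source" and "size" are placeholders): a CacheData object
    // obtained from Cache::Access() must be waited on before its pointer becomes
    // valid, and it may still yield nullptr, in which case the caller falls back
    // to the original data.
    //
    //   std::unique_ptr<dsacache::CacheData> entry = cache.Access(source, size);
    //   entry->WaitOnCompletion();
    //   uint8_t* location = entry->GetDataLocation();
    //   if (location == nullptr) location = source;   // caching unavailable, use the original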
    /*
    * Class Description:
    * Class will handle access to data through internal copies.
    * These are obtained via work submission to the Intel DSA which takes
    * care of asynchronously duplicating the data. The user will define
    * where these copies lie and which system nodes will perform the copy.
    * This is done through policy functions set during initialization.
    *
    * Placement Policy:
    * The Placement Policy Function decides on which node a particular
    * entry is to be placed, given the current executing node and the
    * data source node and data size. This in turn means that for one
    * datum, multiple cached copies may exist at one time.
    *
    * Cache Lifetime:
    * When accessing the cache, a CacheData-object will be returned.
    * As long as this object lives, the pointer which it holds is
    * guaranteed to be either nullptr or a valid copy. When destroyed
    * the entry is marked for deletion which is only carried out
    * when system memory pressure drives an automated cache flush.
    *
    * Restrictions:
    * - Overlapping pointers may lead to undefined behaviour during
    *   manual cache invalidation; avoid manual invalidation if you
    *   intend to use such pointers
    * - Cache invalidation may only be performed manually and gives
    *   no ordering guarantees. It is therefore the user's responsibility
    *   to ensure that results after invalidation have been generated
    *   using the latest state of data. The cache is best suited
    *   to static data.
    *
    * Notes on Thread Safety:
    * - Cache is completely thread-safe after initialization
    * - CacheData-class will handle deallocation of data itself by
    *   performing self-reference-counting atomically and only
    *   deallocating if the last reference is destroyed
    * - The internal cache state has one lock which is either
    *   acquired shared for reading the state (upon accessing an already
    *   cached element) or unique (accessing a new element, flushing, invalidating)
    * - Waiting on copy completion is done over an atomic-wait in copies
    *   of the original CacheData-instance
    * - Overall this class may experience performance issues due to the use
    *   of locking (in any configuration), lock contention (worsens with higher
    *   core count, node count and utilization) and atomics (worse in the same
    *   situations as lock contention)
    *
    * Improving Performance:
    * When data is never shared between threads or memory size for the cache is
    * not an issue, you may consider having one Cache-instance per thread, removing
    * the lock in Cache and modifying the reference counting and waiting mechanisms
    * of CacheData accordingly (although this is high effort and will yield little,
    * as the atomics are not shared among cores/nodes).
    * Otherwise, one Cache-instance per node could also be considered. This will allow
    * the placement policy function to be barebones and reduces the lock contention and
    * synchronization impact of the atomic variables.
    */
    class Cache {
    public:
        // cache policy is defined as a type here to allow flexible usage of the cacher
        // given a numa destination node (where the data will be needed), the numa source
        // node (current location of the data) and the data size, this function should
        // return the optimal cache placement
        // dst node and returned value can differ if the system, for example, has HBM
        // attached to node n that is accessible under a different node id m
        typedef int (CachePolicy)(const int numa_dst_node, const int numa_src_node, const size_t data_size);

        // copy policy specifies the copy-executing nodes for a given task
        // which allows flexibility in assignment for optimizing raw throughput
        // or choosing a conservative usage policy
        typedef std::vector<int> (CopyPolicy)(const int numa_dst_node, const int numa_src_node, const size_t data_size);
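
        // Illustrative policy sketch (not part of the original header; names and
        // choices are placeholders): a minimal pair of policies caches data on the
        // node that requests it and lets the destination node perform the copy, e.g.
        //
        //   int CachePolicySameNode(const int numa_dst_node, const int numa_src_node, const size_t data_size) {
        //       return numa_dst_node;             // place the copy where it is needed
        //   }
        //   std::vector<int> CopyPolicySingleNode(const int numa_dst_node, const int numa_src_node, const size_t data_size) {
        //       return { numa_dst_node };         // only the destination node executes the copy
        //   }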
    private:
        // flags to store configuration options
        uint64_t flags_ = 0;

        // map from [dst-numa-node,map2]
        // map2 from [data-ptr,cache-structure]
        struct LockedNodeCacheState {
            std::shared_mutex cache_mutex_;
            std::unordered_map<uint8_t*, CacheData> node_cache_state_;
        };
        std::unordered_map<uint8_t, LockedNodeCacheState*> cache_state_;

        CachePolicy* cache_policy_function_ = nullptr;
        CopyPolicy* copy_policy_function_ = nullptr;

        // function used to submit a copy task on a specific node to the dml
        // engine on that node - will change the current thread's node assignment
        // to achieve this, so take care to restore it
        dml::handler<dml::mem_copy_operation, std::allocator<uint8_t>> ExecuteCopy(
            const uint8_t* src, uint8_t* dst, const size_t size, const int node
        ) const;

        // allocates the required memory on the destination node,
        // then submits the task to the dml library for processing
        // and attaches the handler to the cache data structure
        void SubmitTask(CacheData* task, const int dst_node, const int src_node);

        // queries the policy functions for the given data and size
        // to obtain the destination cache node, also returns the data's
        // source node for further usage
        // output may depend on the calling thread's node assignment
        // as this is set as the "optimal placement" node
        void GetCacheNode(uint8_t* src, const size_t size, int* OUT_DST_NODE, int* OUT_SRC_NODE) const;

        // allocates memory of size "size" on the numa node "node"
        // and returns nullptr if this is not possible, also may
        // try to flush the cache of the requested node to
        // alleviate the encountered shortage
        uint8_t* AllocOnNode(const size_t size, const int node);

        // checks whether the cache contains an entry for
        // the given data in the given memory node and
        // returns it, otherwise returns nullptr
        std::unique_ptr<CacheData> GetFromCache(uint8_t* src, const size_t size, const int dst_node);

    public:
        ~Cache();
        Cache() = default;
        Cache(const Cache& other) = delete;

        // initializes the cache with the two policy functions
        // only after this is it safe to use in a threaded environment
        void Init(CachePolicy* cache_policy_function, CopyPolicy* copy_policy_function);

        // function to perform data access through the cache
        std::unique_ptr<CacheData> Access(uint8_t* data, const size_t size);

        // flushes the cache of inactive entries
        // if node is -1 then the whole cache is
        // checked and otherwise the specified
        // node - no checks on node validity
        void Flush(const int node = -1);

        // forces out all entries from the
        // cache and therefore will also "forget"
        // still-in-use entries, these will still
        // be properly deleted, but the cache
        // will be fresh - use for testing
        void Clear();

        void Invalidate(uint8_t* data);

        void SetFlags(const uint64_t flags) { flags_ = flags; }
        uint64_t GetFlags() { return flags_; }
    };
}
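
// Illustrative setup sketch (assumes the policy functions sketched above; the
// variable names are placeholders): the cache must be initialized with both
// policies before it is used from multiple threads.
//
//   dsacache::Cache cache;
//   cache.Init(CachePolicySameNode, CopyPolicySingleNode);
//   std::unique_ptr<dsacache::CacheData> entry = cache.Access(source, size);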
inline void dsacache::Cache::Clear() {
    for (auto& nc : cache_state_) {
        std::unique_lock<std::shared_mutex> lock(nc.second->cache_mutex_);
        nc.second->node_cache_state_.clear();
    }
}

inline void dsacache::Cache::Init(CachePolicy* cache_policy_function, CopyPolicy* copy_policy_function) {
    cache_policy_function_ = cache_policy_function;
    copy_policy_function_ = copy_policy_function;

    // initialize numa library
    numa_available();

    // obtain all available nodes
    // and those we may allocate
    // memory on
    const int nodes_max = numa_num_configured_nodes();
    const bitmask* valid_nodes = numa_get_mems_allowed();

    // prepare the cache state with entries
    // for all given nodes
    for (int node = 0; node < nodes_max; node++) {
        if (numa_bitmask_isbitset(valid_nodes, node)) {
            void* block = numa_alloc_onnode(sizeof(LockedNodeCacheState), node);
            auto* state = new(block) LockedNodeCacheState;
            cache_state_.insert({node, state});
        }
    }
}
inline std::unique_ptr<dsacache::CacheData> dsacache::Cache::Access(uint8_t* data, const size_t size) {
    // get destination numa node for the cache
    int dst_node = -1;
    int src_node = -1;
    GetCacheNode(data, size, &dst_node, &src_node);

    // TODO: at this point it could be beneficial to check whether
    // TODO: the given destination node is present as an entry
    // TODO: in the cache state to see if it is valid

    // check whether the data is already cached
    std::unique_ptr<CacheData> task = GetFromCache(data, size, dst_node);
    if (task != nullptr) {
        return task;
    }

    // at this point the requested data is not present in cache
    // and we create a caching task for it, copying our current flags
    task = std::make_unique<CacheData>(data, size);
    task->SetFlags(flags_);

    {
        LockedNodeCacheState* local_cache_state = cache_state_[dst_node];
        std::unique_lock<std::shared_mutex> lock(local_cache_state->cache_mutex_);

        const auto state = local_cache_state->node_cache_state_.emplace(task->GetSource(), *task);

        // if state.second is false then no insertion took place,
        // which means that, concurrently with this thread,
        // some other thread must have accessed the same
        // resource, in which case we return the other
        // thread's data cache structure
        if (!state.second) {
            return std::make_unique<CacheData>(state.first->second);
        }

        // initialize the task now for thread safety
        // as we are now sure that we will submit work
        // to it and will not delete it beforehand
        task->Init();
    }

    SubmitTask(task.get(), dst_node, src_node);

    return task;
}
inline uint8_t* dsacache::Cache::AllocOnNode(const size_t size, const int node) {
    // allocate data on this node and flush the unused parts of the
    // cache if the operation fails, then retry once
    // TODO: smarter flush strategy could keep some stuff cached

    // check currently free memory to see if the data fits
    long long int free_space = 0;
    numa_node_size64(node, &free_space);

    if (free_space < static_cast<long long int>(size)) {
        // dst node lacks memory space so we flush the cache for this
        // node hoping to free enough currently unused entries to make
        // the second allocation attempt successful
        Flush(node);

        // re-test by getting the free space and checking again
        numa_node_size64(node, &free_space);

        if (free_space < static_cast<long long int>(size)) {
            return nullptr;
        }
    }

    uint8_t* dst = reinterpret_cast<uint8_t*>(numa_alloc_onnode(size, node));
    if (dst == nullptr) {
        return nullptr;
    }

    return dst;
}
inline void dsacache::Cache::SubmitTask(CacheData* task, const int dst_node, const int src_node) {
    // stores the last node used by the local thread so we can achieve some
    // load balancing which locally might look like round robin, but considering
    // that one source thread may see different results for "executing_nodes" with
    // different sizes, and that multiple threads will submit, in reality we
    // achieve a "wild-west-style" load balance here
    static thread_local int last_node_index = -1;

    uint8_t* dst = AllocOnNode(task->GetSize(), dst_node);
    if (dst == nullptr) {
        return;
    }

    // query the copy policy function for the nodes available to use for the copy
    const std::vector<int> executing_nodes = copy_policy_function_(dst_node, src_node, task->GetSize());

    // use our load balancing method and determine the node for this task
    last_node_index = (last_node_index + 1) % executing_nodes.size();
    const int node = executing_nodes[last_node_index];

    // submit the copy and attach it to the task entry
    auto* handler = new CacheData::dml_handler();
    *handler = ExecuteCopy(task->GetSource(), dst, task->GetSize(), node);
    task->SetTaskHandlerAndCache(dst, handler);
}
inline dml::handler<dml::mem_copy_operation, std::allocator<uint8_t>> dsacache::Cache::ExecuteCopy(
    const uint8_t* src, uint8_t* dst, const size_t size, const int node
) const {
    dml::const_data_view srcv = dml::make_view(src, size);
    dml::data_view dstv = dml::make_view(dst, size);

    if (CheckFlag(flags_, FLAG_HANDLE_PF)) {
        return dml::submit<dml::hardware>(
            dml::mem_copy.block_on_fault(), srcv, dstv,
            dml::execution_interface<dml::hardware, std::allocator<uint8_t>>(), node
        );
    }
    else {
        return dml::submit<dml::hardware>(
            dml::mem_copy, srcv, dstv,
            dml::execution_interface<dml::hardware, std::allocator<uint8_t>>(), node
        );
    }
}
inline void dsacache::Cache::GetCacheNode(uint8_t* src, const size_t size, int* OUT_DST_NODE, int* OUT_SRC_NODE) const {
    // obtain the numa node of the current thread to determine where the data is needed
    const int current_cpu = sched_getcpu();
    const int current_node = numa_node_of_cpu(current_cpu);

    // obtain the node that the given data pointer is allocated on
    *OUT_SRC_NODE = -1;
    get_mempolicy(OUT_SRC_NODE, NULL, 0, (void*)src, MPOL_F_NODE | MPOL_F_ADDR);

    // query the cache policy function for the destination numa node
    *OUT_DST_NODE = cache_policy_function_(current_node, *OUT_SRC_NODE, size);
}
inline void dsacache::Cache::Flush(const int node) {
    // this lambda is used because below we have two code paths that
    // flush nodes, either one single node or all of them successively
    const auto FlushNode = [](std::unordered_map<uint8_t*, CacheData>& map) {
        // begin at the front of the map
        auto it = map.begin();

        // loop until we reach the end of the map
        while (it != map.end()) {
            // if the iterator points to an inactive element
            // then we may erase it
            if (it->second.GetRefCount() <= 1) {
                // erase the iterator from the map
                map.erase(it);

                // as the erasure invalidated our iterator
                // we must start at the beginning again
                it = map.begin();
            }
            else {
                // if the element is active just move on to the next one
                it++;
            }
        }
    };

    // we require an exclusive lock as we modify the cache state
    // node == -1 means that the cache on all nodes should be flushed
    if (node == -1) {
        for (auto& nc : cache_state_) {
            std::unique_lock<std::shared_mutex> lock(nc.second->cache_mutex_);
            FlushNode(nc.second->node_cache_state_);
        }
    }
    else {
        std::unique_lock<std::shared_mutex> lock(cache_state_[node]->cache_mutex_);
        FlushNode(cache_state_[node]->node_cache_state_);
    }
}
inline std::unique_ptr<dsacache::CacheData> dsacache::Cache::GetFromCache(uint8_t* src, const size_t size, const int dst_node) {
    // the best situation is if this data is already cached,
    // which we check with the cache locked for reading to prevent
    // another thread from marking the element we may find as
    // unused and clearing it
    LockedNodeCacheState* local_cache_state = cache_state_[dst_node];

    // lock the cache state in shared-mode because we only read
    std::shared_lock<std::shared_mutex> lock(local_cache_state->cache_mutex_);

    // search for the data in our cache state structure at the given node
    const auto search = local_cache_state->node_cache_state_.find(src);

    // if the data is in our structure we continue
    if (search != local_cache_state->node_cache_state_.end()) {
        // now check whether the sizes match
        if (search->second.GetSize() >= size) {
            // return a unique copy of the entry which uses the object
            // lifetime and destructor to safely handle deallocation
            return std::make_unique<CacheData>(search->second);
        }
        else {
            // if the sizes mismatch then we clear the current entry from cache,
            // which will cause its deletion only after the last possible outside
            // reference is also destroyed; erasing modifies the map, so the
            // shared lock must first be exchanged for an exclusive one
            lock.unlock();
            std::unique_lock<std::shared_mutex> unique_lock(local_cache_state->cache_mutex_);
            local_cache_state->node_cache_state_.erase(src);
        }
    }

    return nullptr;
}
inline void dsacache::Cache::Invalidate(uint8_t* data) {
    // as the cache is modified we must obtain a unique writer's lock
    // loop through all per-node caches available
    for (auto& node : cache_state_) {
        std::unique_lock<std::shared_mutex> lock(node.second->cache_mutex_);

        // search for an entry for the given data pointer
        auto search = node.second->node_cache_state_.find(data);
        if (search != node.second->node_cache_state_.end()) {
            // if the data is represented in-cache
            // then it will be erased to re-trigger
            // caching on next access
            node.second->node_cache_state_.erase(search);
        }
    }
}
inline dsacache::Cache::~Cache() {
    for (auto& node : cache_state_) {
        node.second->~LockedNodeCacheState();
        numa_free(reinterpret_cast<void*>(node.second), sizeof(LockedNodeCacheState));
    }
}
inline dsacache::CacheData::CacheData(uint8_t* data, const size_t size) {
    src_ = data;
    size_ = size;
    delete_ = false;
    active_ = new std::atomic<int32_t>(1);
    cache_ = new std::atomic<uint8_t*>(data);
    handler_ = new std::atomic<dml_handler*>(nullptr);
    incomplete_cache_ = new uint8_t*(nullptr);
}

inline dsacache::CacheData::CacheData(const dsacache::CacheData& other) {
    // we copy the pointer to the global atomic reference counter
    // and increase the number of active references
    active_ = other.active_;
    active_->fetch_add(1);

    src_ = other.src_;
    size_ = other.size_;
    cache_ = other.cache_;
    flags_ = other.flags_;
    incomplete_cache_ = other.incomplete_cache_;
    handler_ = other.handler_;
}
inline dsacache::CacheData::~CacheData() {
    // because fetch_sub returns the previously held value
    // we must subtract one locally to get the current value
    const int32_t v = active_->fetch_sub(1) - 1;

    // if the resulting value is zero
    // then we must execute proper deletion
    // as this was the last reference
    if (v == 0) {
        // on deletion we must ensure that all offloaded
        // operations have completed successfully
        WaitOnCompletion();

        // only then can we deallocate the memory
        Deallocate();

        delete active_;
        delete cache_;
        delete handler_;
        delete incomplete_cache_;
    }
}
inline void dsacache::CacheData::Deallocate() {
    // although deallocate should only be called from
    // a safe context to do so, it can not hurt to
    // defensively perform the operation atomically
    // and check for an incomplete cache if no deallocation
    // takes place for the retrieved local cache
    uint8_t* cache_local = cache_->exchange(nullptr);

    if (cache_local != nullptr && delete_) numa_free(cache_local, size_);
    else if (*incomplete_cache_ != nullptr) numa_free(*incomplete_cache_, size_);
}
inline void dsacache::CacheData::WaitOnCompletion() {
    // first check if waiting is even necessary as a valid
    // cache pointer signals that no waiting is to be performed
    if (cache_->load() != nullptr) {
        return;
    }

    // then wait until a handler is available
    handler_->wait(nullptr);

    // exchange the global handler pointer with a sentinel to obtain a local
    // copy - this signals that this thread is the sole owner and therefore
    // responsible for waiting on it. we can not set it to nullptr here but
    // set it to the maximum 64-bit value in order to prevent deadlocks from
    // the waiting construct above
    dml_handler* local_handler = handler_->exchange(reinterpret_cast<dml_handler*>(maxptr));

    // ensure that no other thread snatched the handler before us
    // and in case one did, wait for the cache pointer and then return
    if (local_handler == nullptr || local_handler == reinterpret_cast<dml_handler*>(maxptr)) {
        cache_->wait(nullptr);
        return;
    }

    // at this point we are responsible for waiting on the handler
    // and handling any error that comes through it gracefully
    if (CheckFlag(flags_, FLAG_WAIT_WEAK) && !local_handler->is_finished()) {
        handler_->store(local_handler);
        return;
    }

    // perform the wait
    auto result = local_handler->get();

    // at this point the handler has been waited for
    // and therefore may be decommissioned
    delete local_handler;

    // if the copy task failed we abort the whole task,
    // otherwise the cache will be set to valid now
    if (result.status != dml::status_code::ok) {
        cache_->store(src_);
        numa_free(*incomplete_cache_, size_);
        delete_ = false;
        *incomplete_cache_ = nullptr;
    }
    else {
        cache_->store(*incomplete_cache_);
    }

    // notify all waiting threads so they wake up quickly
    cache_->notify_all();
    handler_->notify_all();
}
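
// Illustrative sketch (not part of the original header; "entry" and "source" are
// placeholders): with FLAG_WAIT_WEAK set on the parent Cache, WaitOnCompletion()
// returns early when the copy has not finished, so a caller can poll and keep
// working on the original data in the meantime:
//
//   entry->WaitOnCompletion();                      // may return before the copy is done
//   uint8_t* location = entry->GetDataLocation();
//   if (location == nullptr) location = source;     // copy not finished yet, fall back to the original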
inline void dsacache::CacheData::SetTaskHandlerAndCache(uint8_t* cache, dml_handler* handler) {
    *incomplete_cache_ = cache;
    handler_->store(handler);
    handler_->notify_one();
}

inline void dsacache::CacheData::Init() {
    cache_->store(nullptr);
    delete_ = true;
}