diff --git a/offloading-cacher/main.cpp b/offloading-cacher/main.cpp index 443b00b..8193f5a 100644 --- a/offloading-cacher/main.cpp +++ b/offloading-cacher/main.cpp @@ -7,6 +7,8 @@ #include "cache.hpp" +static constexpr size_t SIZE_64_MIB = 64 * 1024 * 1024; + dsacache::Cache CACHE; void InitCache(const std::string& device) { @@ -15,25 +17,47 @@ void InitCache(const std::string& device) { return numa_dst_node; }; - auto copy_policy = [](const int numa_dst_node, const int numa_src_node) { - return std::vector{ numa_src_node, numa_dst_node }; + auto copy_policy = [](const int numa_dst_node, const int numa_src_node, const size_t data_size) { + return std::vector{ numa_dst_node }; }; CACHE.Init(cache_policy,copy_policy); } else if (device == "xeonmax") { auto cache_policy = [](const int numa_dst_node, const int numa_src_node, const size_t data_size) { + // xeon max is configured to have hbm on node ids that are +8 + return numa_dst_node < 8 ? numa_dst_node + 8 : numa_dst_node; }; - auto copy_policy = [](const int numa_dst_node, const int numa_src_node) { - const bool same_socket = ((numa_dst_node ^ numa_src_node) & 4) == 0; - if (same_socket) { - const bool socket_number = numa_dst_node >> 2; - if (socket_number == 0) return std::vector{ 0, 1, 2, 3 }; - else return std::vector{ 4, 5, 6, 7 }; + auto copy_policy = [](const int numa_dst_node, const int numa_src_node, const size_t data_size) { + if (data_size < SIZE_64_MIB) { + // if the data size is small then the copy will just be carried + // out by the destination node which does not require setting numa + // thread affinity as the selected dsa engine is already the one + // present on the calling thread + + return std::vector{ (numa_dst_node >= 8 ? numa_dst_node - 8 : numa_dst_node) }; + } + else { + // for sufficiently large data, smart copy is used which will utilize + // all four engines for intra-socket copy operations and cross copy on + // the source and destination nodes for inter-socket copy + + const bool same_socket = ((numa_dst_node ^ numa_src_node) & 4) == 0; + + if (same_socket) { + const bool socket_number = numa_dst_node >> 2; + if (socket_number == 0) return std::vector{ 0, 1, 2, 3 }; + else return std::vector{ 4, 5, 6, 7 }; + } + else { + return std::vector{ + (numa_src_node >= 8 ? numa_src_node - 8 : numa_src_node), + (numa_dst_node >= 8 ? numa_dst_node - 8 : numa_dst_node) + }; + } } - else return std::vector{ numa_src_node, numa_dst_node }; }; CACHE.Init(cache_policy,copy_policy);