// desc simple implementation of sparse matrix // maintainer hugoyu #ifndef __sparse_matrix_h__ #define __sparse_matrix_h__ #include #include #include "common.h" template class sparse_matrix { public: constexpr sparse_matrix(uint32 row, uint32 col) { assert(is_valid_size(row, col)); m_row = row; m_col = col; } constexpr sparse_matrix(const sparse_matrix& other) { m_row = other.m_row; m_col = other.m_col; m_element_buffer = other.m_element_buffer; } constexpr sparse_matrix(sparse_matrix&& other) { m_row = other.m_row; m_col = other.m_col; m_element_buffer.swap(other.m_element_buffer); } constexpr sparse_matrix& operator =(const sparse_matrix& other) { if (this != &other) { assert(is_valid_size(row, col)); m_row = row; m_col = col; m_element_buffer = other.m_element_buffer; } } constexpr sparse_matrix& operator =(sparse_matrix&& other) { if (this != &other) { m_row = other.m_row; m_col = other.m_col; m_element_buffer.swap(other.m_element_buffer); } } constexpr uint32 row() const { return m_row; } constexpr uint32 col() const { return m_col; } constexpr const T& operator ()(uint32 row, uint32 col) const { assert(is_valid_index(row, col)); auto iter = m_element_buffer.find(gen_element_key(row, col)); if (iter != m_element_buffer.end()) { return iter->second; } return T(); } constexpr T& operator ()(uint32 row, uint32 col) { assert(is_valid_index(row, col)); return m_element_buffer[gen_element_key(row, col)]; } sparse_matrix operator *(const T& right) const { sparse_matrix temp(m_row, m_col); for (auto& element : m_element_buffer) { /* uint32 row = 0; uint32 col = 0; extract_element_key(iter.first, row, col); temp(row, col) = iter.second * right; */ temp.m_element_buffer[element.first] = element.second * right; } return temp; } sparse_matrix& operator *=(const T& right) { for (auto& element : m_element_buffer) { element.second *= right; } return *this; } sparse_matrix operator +(const sparse_matrix& right) const { assert(row() == right.row() && col() == right.col()); sparse_matrix temp(m_row, m_col); /* for (auto& element : m_element_buffer) { //uint32 row = 0; //uint32 col = 0; //extract_element_key(element.first, row, col); //temp(row, col) = element.second; temp.m_element_buffer[element.first] = element.second; } */ temp.m_element_buffer = m_element_buffer; for (auto& element : right.m_element_buffer) { /* uint32 row = 0; uint32 col = 0; extract_element_key(iter.first, row, col); temp(row, col) = (*this)(row, col) + iter.second; */ /* auto val = T(); auto left_iter = m_element_buffer.find(element.first); if (left_iter != m_element_buffer.end()) { val = left_iter->second; } temp.m_element_buffer[element.first] = val + element.second; */ temp.m_element_buffer[element.first] += element.second; } return temp; } sparse_matrix& operator +=(const sparse_matrix& right) { assert(row() == right.row() && col() == right.col()); for (auto& element : right.m_element_buffer) { /* uint32 row = 0; uint32 col = 0; extract_element_key(iter.first, row, col); (*this)(row, col) += iter.second; */ /* auto val = T(); auto left_iter = m_element_buffer.find(element.first); if (left_iter != m_element_buffer.end()) { val = left_iter->second; } m_element_buffer[element.first] = val + element.second; */ m_element_buffer[element.first] += element.second; } return *this; } sparse_matrix operator -(const sparse_matrix& right) const { assert(row() == right.row() && col() == right.col()); sparse_matrix temp(m_row, m_col); /* for (auto& element : m_element_buffer) { uint32 row = 0; uint32 col = 0; extract_element_key(element.first, row, col); temp(row, col) = element.second; } */ temp.m_element_buffer = m_element_buffer; for (auto& element : right.m_element_buffer) { /* uint32 row = 0; uint32 col = 0; extract_element_key(iter.first, row, col); temp(row, col) = (*this)(row, col) - iter.second; */ /* auto val = T(); auto left_iter = m_element_buffer.find(element.first); if (left_iter != m_element_buffer.end()) { val = left_iter->second; } temp.m_element_buffer[element.first] = val - element.second; */ temp.m_element_buffer[element.first] -= element.second; } return temp; } sparse_matrix operator -=(const sparse_matrix& right) { assert(row() == right.m_row && col() == right.m_col); for (auto& element : right.m_element_buffer) { /* uint32 row = 0; uint32 col = 0; extract_element_key(iter.first, row, col); (*this)(row, col) -= iter.second; */ /* auto val = T(); auto left_iter = m_element_buffer.find(element.first); if (left_iter != m_element_buffer.end()) { val = left_iter->second; } m_element_buffer[element.first] = val - element.second; */ m_element_buffer[element.first] -= element.second; } return *this; } sparse_matrix operator *(const sparse_matrix& right) const { assert(col() == right.m_row); sparse_matrix temp(m_row, m_col); auto row = m_row; auto col = right.m_col; auto inn = m_col; for (uint32 i = 0; i < row; ++i) { for (uint32 j = 0; j < col; ++j) { auto val = T(); for (uint32 k = 0; k < inn; ++k) { val += (*this)(i, k) * right(k, j); } temp(i, j) = val; } } return temp; } constexpr sparse_matrix& operator *=(const sparse_matrix& right) { assert(col() == right.row()); // NOTE if right is *this, we can optimize space ? auto temp = (*this) * right; m_row = temp.m_row; m_col = temp.m_col; // simple swap, or we can just use move constructor m_element_buffer.swap(temp.m_element_buffer); return *this; } std::string to_string() { std::string string_buffer; string_buffer.reserve(m_row * m_col * 8); for (uint32 i = 0; i < m_row; ++i) { for (uint32 j = 0; j < m_col; ++j) { //string_buffer.append(this->operator()(i, j)); string_buffer.append(std::to_string((*this)(i, j))); string_buffer.append(", "); } string_buffer.append("\n"); } return string_buffer; } private: constexpr bool is_valid_size(uint32 row, uint32 col) const noexcept { return row > 0 && col > 0; } constexpr bool is_valid_index(uint32 row, uint32 col) const noexcept { return row < m_row && col < m_col; } constexpr static uint64 gen_element_key(uint32 row, uint32 col) noexcept { return ((uint64)row << 32) | ((uint64)col); } constexpr static void extract_element_key(uint64 key, uint32& row, uint32& col) noexcept { row = (uint32)((key >> 32) & 0xFFFFFFFF); col = (uint32)(key & 0xFFFFFFFF); } private: uint32 m_row{ 0 }; uint32 m_col{ 0 }; Container m_element_buffer; }; template constexpr sparse_matrix operator *(const T& left, const sparse_matrix& right) { return right * left; } #endif