Skip to content

Commit

Permalink
[software] Add flag to fold q16 MMSE
Browse files Browse the repository at this point in the history
  • Loading branch information
mbertuletti committed Jan 6, 2025
1 parent 30b3458 commit 0d730ae
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 36 deletions.
81 changes: 59 additions & 22 deletions software/apps/baremetal/mimo_mmse_q16/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,29 @@ Parameters and defines
PARALLEL: When defined benchmark parallel MIMO-MMSE.
SINGLE: When defined benchmark single-core MIMO-MMSE.
FOLD: When defined 1 fold matrices in memory.
*/

int16_t l1_H[2 * N_TX * N_RX * N_ITR]
__attribute__((aligned(BANKING_FACTOR * NUM_CORES * sizeof(int32_t)),
section(".l1_prio")));
#define FOLD (1)
#define PARALLEL

#if FOLD
#define NUM_ROW (1 + ((N_ITR * N_TX - 1) / NUM_BANKS))
#define NUM_COL (NUM_BANKS / N_TX)

int16_t l1_G[2 * N_TX * NUM_BANKS * NUM_ROW]
__attribute__((aligned(sizeof(int32_t)), section(".l1_prio")));
int16_t l1_L[2 * N_TX * NUM_BANKS * NUM_ROW]
__attribute__((aligned(sizeof(int32_t)), section(".l1_prio")));
#else
int16_t l1_G[2 * N_TX * N_TX * N_ITR]
__attribute__((aligned(BANKING_FACTOR * NUM_CORES * sizeof(int32_t)),
section(".l1_prio")));
__attribute__((aligned(sizeof(int32_t)), section(".l1_prio")));
int16_t l1_L[2 * N_TX * N_TX * N_ITR]
__attribute__((aligned(BANKING_FACTOR * NUM_CORES * sizeof(int32_t)),
section(".l1_prio")));
__attribute__((aligned(sizeof(int32_t)), section(".l1_prio")));
#endif

int16_t l1_H[2 * N_TX * N_RX * N_ITR]
__attribute__((aligned(sizeof(int32_t)), section(".l1_prio")));
int16_t l1_S[2 * N_TX * N_ITR]
__attribute__((aligned(sizeof(int32_t)), section(".l1_prio")));
int16_t l1_y[2 * N_RX * N_ITR]
Expand All @@ -51,12 +62,14 @@ int main() {
uint32_t core_id = mempool_get_core_id();
uint32_t num_cores = mempool_get_core_count();
mempool_barrier_init(core_id); // Initialize barrier and synchronize
uint32_t time_init, time_end;

/* Initialize matrices */
if (core_id == 0) {
dma_memcpy_blocking(l1_H, l2_H, N_TX * N_RX * N_ITR * sizeof(int32_t));
dma_memcpy_blocking(l1_y, l2_y, N_RX * N_ITR * sizeof(int32_t));
dma_memcpy_blocking(l1_S, l2_S, N_TX * N_ITR * sizeof(int32_t));
printf("Data transferred\n");
}
mempool_barrier(num_cores);

Expand All @@ -65,13 +78,18 @@ int main() {

if (core_id == 0) {
mempool_start_benchmark();
mempool_hermitian_q16vecs((v2s *)l1_H, (v2s *)l1_G, (v2s *)l1_Sigma, N_RX,
N_TX);
mempool_MVP_conjtransp_q16vecs((v2s *)l1_H, (v2s *)l1_y, (v2s *)y2, N_RX,
N_TX, 0);
mempool_cholesky_q16vecs(l1_G, l1_L, N_TX);
mempool_Ltrisol_q16vecs(l1_L, y2, y3, N_TX, 0);
mempool_Ltrisol_q16vecs(l1_L, y3, l1_x, N_TX, 1);
time_init = mempool_get_timer();
v2s *PtrH = (v2s *)l1_H;
v2s *PtrG = (v2s *)l1_G;
v2s *PtrS = (v2s *)l1_Sigma;
v2s *Ptry = (v2s *)l1_y;
v2s *Ptry2 = (v2s *)y2;
mempool_hermitian_q16vecs(PtrH, PtrG, PtrS, N_RX, N_TX);
mempool_MVP_conjtransp_q16vecs(PtrH, Ptry, Ptry2, N_RX, N_TX, FOLD);
mempool_cholesky_q16vecs(l1_G, l1_L, N_TX, FOLD);
mempool_Ltrisol_q16vecs(l1_L, y2, y3, N_TX, 0, FOLD);
mempool_Ltrisol_q16vecs(l1_L, y3, l1_x, N_TX, 1, FOLD);
time_end = mempool_get_timer();
mempool_stop_benchmark();
}
mempool_barrier(num_cores);
Expand All @@ -81,30 +99,49 @@ int main() {
#ifdef PARALLEL

mempool_start_benchmark();
time_init = mempool_get_timer();
for (uint32_t itr = core_id; itr < N_ITR; itr += num_cores) {

int16_t *PtrH = l1_H + itr * (2 * N_TX * N_RX);
int16_t *Ptry = l1_y + itr * (2 * N_RX);
int16_t *PtrSigma = l1_S + itr * (2 * N_TX);

int16_t *PtrS = l1_S + itr * (2 * N_TX);

#if FOLD
int16_t *PtrG = l1_G + (itr / NUM_COL) * (2 * N_TX * NUM_BANKS) +
(itr % NUM_COL) * (2 * N_TX);
int16_t *PtrL = l1_L + (itr / NUM_COL) * (2 * N_TX * NUM_BANKS) +
(itr % NUM_COL) * (2 * N_TX);
int16_t *Ptry2 =
y2 + (itr / NUM_COL) * (2 * NUM_BANKS) + (itr % NUM_COL) * (2 * N_TX);
int16_t *Ptry3 =
y3 + (itr / NUM_COL) * (2 * NUM_BANKS) + (itr % NUM_COL) * (2 * N_TX);
int16_t *Ptrx = l1_x + itr * (2 * N_TX);
#else
int16_t *PtrG = l1_G + itr * (2 * N_TX * N_TX);
int16_t *PtrL = l1_L + itr * (2 * N_TX * N_TX);
int16_t *Ptry2 = y2 + itr * (2 * N_TX);
int16_t *Ptry3 = y3 + itr * (2 * N_TX);
int16_t *Ptrx = l1_x + itr * (2 * N_TX);
#endif

mempool_hermitian_q16vecs((v2s *)PtrH, (v2s *)PtrG, (v2s *)PtrSigma, N_RX,
mempool_hermitian_q16vecs((v2s *)PtrH, (v2s *)PtrG, (v2s *)PtrS, N_RX,
N_TX);
mempool_MVP_conjtransp_q16vecs((v2s *)PtrH, (v2s *)Ptry, (v2s *)Ptry2, N_RX,
N_TX, 0);
mempool_cholesky_q16vecs(PtrG, PtrL, N_TX);
mempool_Ltrisol_q16vecs(PtrL, Ptry2, Ptry3, N_TX, 0);
mempool_Ltrisol_q16vecs(PtrL, Ptry3, Ptrx, N_TX, 1);
N_TX, FOLD);
mempool_cholesky_q16vecs(PtrG, PtrL, N_TX, FOLD);
mempool_Ltrisol_q16vecs(PtrL, Ptry2, Ptry3, N_TX, 0, FOLD);
mempool_Ltrisol_q16vecs(PtrL, Ptry3, Ptrx, N_TX, 1, FOLD);
}
mempool_log_barrier(2, core_id);
mempool_barrier(num_cores);
time_end = mempool_get_timer();
mempool_stop_benchmark();

#endif

if (core_id == 0) {
printf("Runtime: %d\n", time_end - time_init);
}
mempool_barrier(num_cores);

return 0;
}
24 changes: 13 additions & 11 deletions software/kernels/baremetal/mempool_cholesky_q16s.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
@param[in] n dimension of the input data
@return none
*/
void mempool_cholesky_q16vecs(int16_t *pSrc, int16_t *pL, const uint32_t n) {
void mempool_cholesky_q16vecs(int16_t *pSrc, int16_t *pL, const uint32_t n,
const uint32_t folded) {

uint32_t i, j, k;
int32_t sum; // Sum for elements on diagonal (real)
int32_t diag; // Diagonal element (real)
int32_t as, bs; // Sum for elements on rows (complex)
int32_t ap, bp; // Pivot elements (complex)
uint32_t i, j, k;
const uint32_t offset = folded ? NUM_BANKS : n;

v2s ab = (v2s){0, 0};
v2s cd = (v2s){0, 0};
Expand All @@ -33,30 +35,30 @@ void mempool_cholesky_q16vecs(int16_t *pSrc, int16_t *pL, const uint32_t n) {

// Elements on diagonal (input matrix is positive-definite)
sum = 0;
diag = (int32_t)pSrc[2 * (j * n + j)];
diag = (int32_t)pSrc[2 * (j * offset + j)];
for (k = 0; k < j; k++) {
ab = *(v2s *)&pL[2 * (j * n + k)];
ab = *(v2s *)&pL[2 * (j * offset + k)];
asm volatile("pv.dotsp.h %[sum], %[ab], %[ab];"
"srai %[sum], %[sum], 0x8;"
"p.clip %[sum], %[sum], 0x16;"
: [sum] "+&r"(sum)
: [ab] "r"(ab)
:);
}
pL[2U * (j * n + j)] = (int16_t)mempool_sqrt_q32s(diag - sum, 16);
pL[2U * (j * offset + j)] = (int16_t)mempool_sqrt_q32s(diag - sum, 16);

// Elements on rows
for (i = j + 1; i < n; i++) {
ap = (int32_t)pSrc[2 * (i * n + j)]; // Pivot
bp = (int32_t)pSrc[2 * (i * n + j) + 1]; // Pivot
diag = (int32_t)pL[2 * (j * n + j)]; // Diag
ap = (int32_t)pSrc[2 * (i * offset + j)]; // Pivot
bp = (int32_t)pSrc[2 * (i * offset + j) + 1]; // Pivot
diag = (int32_t)pL[2 * (j * offset + j)]; // Diag

as = 0;
bs = 0;
// Sum -> s = s + (ac + bd) + j*(bc - ad)
for (k = 0; k < j; k++) {
ab = *(v2s *)&pL[2U * (i * n + k)];
cd = *(v2s *)&pL[2U * (j * n + k)];
ab = *(v2s *)&pL[2U * (i * offset + k)];
cd = *(v2s *)&pL[2U * (j * offset + k)];
const uint32_t shuffle_mask = 0x00020003;
asm volatile(
// s = s + (ac + bd) + j(bc - ad)
Expand All @@ -81,7 +83,7 @@ void mempool_cholesky_q16vecs(int16_t *pSrc, int16_t *pL, const uint32_t n) {
: [ap] "+&r"(ap), [bp] "+&r"(bp), [res] "+&r"(res)
: [as] "r"(as), [bs] "r"(bs), [diag] "r"(diag)
:);
(*(v2s *)&pL[2 * (i * n + j)]) = res;
(*(v2s *)&pL[2 * (i * offset + j)]) = res;
}
}
return;
Expand Down
9 changes: 6 additions & 3 deletions software/kernels/baremetal/mempool_linearsolver_q16s.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,29 @@
*/

void mempool_Ltrisol_q16vecs(int16_t *pL, int16_t *y, int16_t *x,
const uint32_t n, const uint32_t transposed) {
const uint32_t n, const uint32_t transposed,
const uint32_t folded) {

uint32_t i, j;
int32_t as, bs, diag;
v2s ab, cd;
v2s res = (v2s){0, 0};
v2s ndc = (v2s){0, 0};
const uint32_t offset = folded ? NUM_BANKS : n;

// Solve for each variable x[i] in loop
for (i = 0; i < n; i++) {
uint32_t ridx = transposed ? (n - i - 1) : i;
diag = pL[2U * (ridx + ridx)];
diag = pL[2U * (ridx * offset + ridx)];
// Initialize the sums
as = 0;
bs = 0;
// Use the previously solved variables to compute the sum
for (j = 0; j < i; j++) {

uint32_t cidx = transposed ? (n - j - 1) : j;
if (!transposed) {
ab = *(v2s *)&pL[2U * (ridx * n + cidx)];
ab = *(v2s *)&pL[2U * (ridx * offset + cidx)];
} else {
ab = *(v2s *)&pL[2U * (cidx * n + ridx)];
}
Expand Down

0 comments on commit 0d730ae

Please sign in to comment.