use std::cmp;
use crate::utilities::dot_product;
use crate::{gemm::gemm, utilities::daxpy_update};
#[cfg(feature = "profiling")]
use crate::profiling;
/// Tile edge length used by the blocked loops in `syrk`: both the `n` and the
/// `k` dimensions are walked in steps of `NB`, with the final tile clipped to
/// whatever remains.
const NB: usize = 32;
/// Blocked symmetric rank-k update of one triangle of `C`.
///
/// With column-major storage (element `(i, j)` of `C` lives at `c[i + j * ldc]`):
///
/// * `trans == 'N'`: `C := alpha * A * Aᵀ + beta * C`, where `A` is `n x k`.
/// * any other `trans` (e.g. `'T'` or `'C'`): `C := alpha * Aᵀ * A + beta * C`,
///   where `A` is `k x n`.
///
/// Only the triangle selected by `uplo` is read or written: the upper triangle
/// when `uplo == 'U'`, the lower triangle for any other value. The flags are
/// compared case-sensitively.
///
/// When `beta == 0.0` the triangle is overwritten with zeros rather than
/// multiplied, so its previous contents are never read.
///
/// The update is tiled with edge length `NB`: diagonal tiles go through
/// `syrk_kernel`, off-diagonal tiles through `gemm` (assumed to follow the
/// usual BLAS GEMM argument order and to compute `C := alpha*op(A)*op(B) + beta*C`).
///
/// # Safety
///
/// * `a` must be valid for reads of every element addressed above: offsets
///   `i + p * lda` (`i < n`, `p < k`) when `trans == 'N'`, otherwise
///   `p + i * lda`.
/// * `c` must be valid for reads and writes of every element `i + j * ldc` in
///   the selected triangle (`i, j < n`).
/// * The memory behind `c` must not overlap the memory behind `a`.
#[allow(unsafe_op_in_unsafe_fn, clippy::missing_safety_doc, clippy::too_many_arguments)]
pub unsafe fn syrk(
    uplo: char,
    trans: char,
    n: usize,
    k: usize,
    alpha: f64,
    a: *const f64,
    lda: usize,
    beta: f64,
    c: *mut f64,
    ldc: usize,
) {
    #[cfg(feature = "profiling")]
    let _timer = profiling::ScopedTimer::new("SYRK");

    // Nothing to do: empty matrix, or the update term vanishes and C is kept as is.
    if n == 0 || ((alpha == 0.0 || k == 0) && beta == 1.0) {
        return;
    }

    // Decode both flags exactly once so that every code path below agrees on
    // their meaning. Anything that is not 'N' means "transposed", which is the
    // same rule `syrk_kernel` applies.
    let upper = uplo == 'U';
    let transposed = trans != 'N';

    // Pre-scale the referenced triangle by beta, column by column so that the
    // accesses are contiguous for both triangles.
    if beta != 1.0 {
        for j in 0..n {
            let rows = if upper { 0..j + 1 } else { j..n };
            for i in rows {
                let elem = c.add(i + j * ldc);
                if beta == 0.0 {
                    // Assign instead of multiplying so stale NaN/Inf values are cleared.
                    *elem = 0.0;
                } else {
                    *elem *= beta;
                }
            }
        }
    }

    if alpha == 0.0 {
        return;
    }

    let kernel_uplo = if upper { 'U' } else { 'L' };
    let kernel_trans = if transposed { 'T' } else { 'N' };
    // Off-diagonal tiles are op(A_i) * op(A_j)ᵀ expressed as a GEMM.
    let (gemm_ta, gemm_tb) = if transposed { ('T', 'N') } else { ('N', 'T') };

    // Offset of the A panel belonging to matrix index `idx` and depth `p`.
    let a_offset = |idx: usize, p: usize| {
        if transposed {
            p + idx * lda
        } else {
            idx + p * lda
        }
    };

    for p in (0..k).step_by(NB) {
        let pb = cmp::min(k - p, NB);
        for j in (0..n).step_by(NB) {
            let jb = cmp::min(n - j, NB);
            let a_j = a.add(a_offset(j, p));

            // Diagonal tile: only its own triangle is updated.
            syrk_kernel(
                kernel_uplo,
                kernel_trans,
                jb,
                pb,
                alpha,
                a_j,
                lda,
                1.0,
                c.add(j + j * ldc),
                ldc,
            );

            // Full tiles of block column `j`: the rows above the diagonal tile
            // for the upper triangle, the rows below it for the lower one.
            let (row_start, row_end) = if upper { (0, j) } else { (j + jb, n) };
            for i in (row_start..row_end).step_by(NB) {
                let ib = cmp::min(row_end - i, NB);
                gemm(
                    gemm_ta,
                    gemm_tb,
                    ib,
                    jb,
                    pb,
                    alpha,
                    a.add(a_offset(i, p)),
                    lda,
                    a_j,
                    lda,
                    1.0,
                    c.add(i + j * ldc),
                    ldc,
                );
            }
        }
    }
}
/// Unblocked rank-k update of a single `n x n` diagonal tile.
///
/// Adds `alpha * A * Aᵀ` (`trans == 'N'`, `A` addressed as `a[row + depth * lda]`)
/// or `alpha * Aᵀ * A` (any other `trans`, `A` addressed as `a[depth + col * lda]`)
/// into the triangle of `C` selected by `uplo` (`'U'` = upper, anything else =
/// lower). `C` is addressed as `c[row + col * ldc]`.
///
/// `_beta` is accepted for signature symmetry but never read: the kernel
/// always accumulates into the existing contents of `C`.
///
/// NOTE(review): the `'N'` path relies on `daxpy_update(len, s, x, 1, y, 1)`
/// performing `y += s * x` over `len` elements, and the other path on
/// `dot_product(len, x, 1, y, 1)` returning the inner product — confirm
/// against `crate::utilities`.
#[allow(unsafe_op_in_unsafe_fn, clippy::missing_safety_doc, clippy::too_many_arguments)]
unsafe fn syrk_kernel(
    uplo: char,
    trans: char,
    n: usize,
    k: usize,
    alpha: f64,
    a: *const f64,
    lda: usize,
    _beta: f64,
    c: *mut f64,
    ldc: usize,
) {
    let upper = uplo == 'U';

    for col in 0..n {
        // Half-open row range of column `col` that lies inside the triangle.
        let (row_begin, row_end) = if upper { (0, col + 1) } else { (col, n) };

        if trans == 'N' {
            // Column-oriented form: one scaled column slice of A per depth index.
            for depth in 0..k {
                let scale = alpha * *a.add(col + depth * lda);
                if scale != 0.0 {
                    daxpy_update(
                        row_end - row_begin,
                        scale,
                        a.add(row_begin + depth * lda),
                        1,
                        c.add(row_begin + col * ldc),
                        1,
                    );
                }
            }
        } else {
            // Dot-product form: each entry is the inner product of two columns of A.
            for row in row_begin..row_end {
                let inner = dot_product(k, a.add(row * lda), 1, a.add(col * lda), 1);
                *c.add(row + col * ldc) += alpha * inner;
            }
        }
    }
}