/home/runner/work/kynema/kynema/kynema/src/solver/linear_solver/dss_solve_cudss.hpp Source File

Kynema API: /home/runner/work/kynema/kynema/kynema/src/solver/linear_solver/dss_solve_cudss.hpp Source File
Kynema API
A flexible multibody structural dynamics code for wind turbines
Loading...
Searching...
No Matches
dss_solve_cudss.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <Kokkos_Core.hpp>
4#include <cudss.h>
5
6#include "dss_algorithm.hpp"
8
9namespace kynema::dss {
10template <typename CrsMatrixType, typename MultiVectorType>
11struct SolveFunction<Handle<Algorithm::CUDSS>, CrsMatrixType, MultiVectorType> {
12 static void solve(
13 Handle<Algorithm::CUDSS>& dss_handle, CrsMatrixType& A, MultiVectorType& b,
14 MultiVectorType& x
15 ) {
16 const auto num_rows = A.numRows();
17 const auto num_cols = A.numCols();
18 const auto num_non_zero = static_cast<int>(A.nnz());
19
20 auto* values = A.values.data();
21 auto* row_ptrs = A.graph.row_map.data();
22 auto* col_inds = A.graph.entries.data();
23
24 auto& handle = dss_handle.get_handle();
25 auto& config = dss_handle.get_config();
26 auto& data = dss_handle.get_data();
27
28 cudssMatrix_t A_cudss;
29 cudssMatrix_t x_cudss;
30 cudssMatrix_t b_cudss;
31
32 cudssMatrixCreateCsr(
33 &A_cudss, num_rows, num_cols, num_non_zero, const_cast<int*>(row_ptrs), nullptr,
34 col_inds, values, CUDA_R_32I, CUDA_R_64F, CUDSS_MTYPE_GENERAL, CUDSS_MVIEW_FULL,
35 CUDSS_BASE_ZERO
36 );
37 cudssMatrixCreateDn(
38 &b_cudss, num_cols, 1, num_cols, b.data(), CUDA_R_64F, CUDSS_LAYOUT_COL_MAJOR
39 );
40 cudssMatrixCreateDn(
41 &x_cudss, num_rows, 1, num_rows, x.data(), CUDA_R_64F, CUDSS_LAYOUT_COL_MAJOR
42 );
43
44 cudssExecute(handle, CUDSS_PHASE_SOLVE, config, data, A_cudss, x_cudss, b_cudss);
45
46 cudssMatrixDestroy(A_cudss);
47 cudssMatrixDestroy(b_cudss);
48 cudssMatrixDestroy(x_cudss);
49 }
50};
51
52} // namespace kynema::dss
Definition dss_handle.hpp:10
Definition dss_algorithm.hpp:4
Algorithm
Definition dss_algorithm.hpp:6
static void solve(Handle< Algorithm::CUDSS > &dss_handle, CrsMatrixType &A, MultiVectorType &b, MultiVectorType &x)
Definition dss_solve_cudss.hpp:12
Definition dss_solve.hpp:8