So-Bogus
A C++ sparse block matrix library aimed at second-order cone problems
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Pages
MklBindings.hpp
1 /*
2  * This file is part of bogus, a C++ sparse block matrix library.
3  *
4  * Copyright 2013 Gilles Daviet <gdaviet@gmail.com>
5  *
6  * This Source Code Form is subject to the terms of the Mozilla Public
7  * License, v. 2.0. If a copy of the MPL was not distributed with this
8  * file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 */
10 
11 #ifndef BOGUS_BLOCK_MKL_BINDINGS_HPP
12 #define BOGUS_BLOCK_MKL_BINDINGS_HPP
13 
14 #include "SparseBlockMatrix.hpp"
15 #include "CompressedSparseBlockIndex.hpp"
16 
17 #include <mkl.h>
18 #include <mkl_spblas.h>
19 
20 // Creates a compile error with boost
21 #ifdef P4
22 #undef P4
23 #endif
24 
25 namespace bogus
26 {
27 
29 namespace mkl
30 {
31 
33 template< typename Scalar >
34 struct bindings {} ;
35 
36 template< >
37 struct bindings< double >
38 {
39  typedef double Scalar ;
40  static void bsrmv (char *transa, MKL_INT *m, MKL_INT *k, MKL_INT *lb, Scalar *alpha, char *matdescra, Scalar *val, MKL_INT *indx, MKL_INT *pntrb, MKL_INT *pntre, Scalar *x, Scalar *beta, Scalar *y)
41  {
42  mkl_dbsrmv( transa, m, k, lb, alpha,
43  matdescra, val, indx, pntrb, pntre,
44  x, beta, y ) ;
45  }
46 } ;
47 
48 /* FIXME sbsrmv no supported on older MKL versions ; do version detection
49 template< >
50 struct bindings< float >
51 {
52  typedef float Scalar ;
53  static void bsrmv (char *transa, MKL_INT *m, MKL_INT *k, MKL_INT *lb, Scalar *alpha, char *matdescra, Scalar *val, MKL_INT *indx, MKL_INT *pntrb, MKL_INT *pntre, Scalar *x, Scalar *beta, Scalar *y)
54  {
55  mkl_sbsrmv( transa, m, k, lb, alpha,
56  matdescra, val, indx, pntrb, pntre,
57  x, beta, y ) ;
58  }
59 } ; */
60 
61 
62 template< typename Scalar >
63 void bsrmv(
64  bool Symmetric, bool Transpose, MKL_INT Dimension,
65  MKL_INT m, MKL_INT k, const std::size_t offset,
66  const MKL_INT *rowIndex, const MKL_INT *columns, const Scalar *data,
67  const Scalar *rhs, int rhsCols, Scalar *res, Scalar alpha, Scalar beta )
68 {
69 
70  char matdescra[4] = { Symmetric ? 'S' : 'G', 'L', 'N', 'C'} ;
71 
72  char transa = Transpose ? 'T' : 'N' ;
73  MKL_INT lb = Dimension ;
74  Scalar *x = const_cast< Scalar* > ( rhs ) ;
75  Scalar *y = res ;
76 
77  MKL_INT* pntrb = const_cast< MKL_INT* >( rowIndex+offset ) ;
78  MKL_INT* pntre = const_cast< MKL_INT* >( rowIndex+offset+1 ) ;
79  MKL_INT* indx = const_cast< MKL_INT* >( columns ) ;
80 
81  Scalar *a = const_cast< Scalar* > ( data ) + pntrb[0] * lb * lb ;
82 
83  for( int i = 0 ; i < rhsCols ; ++i )
84  {
85  bindings< Scalar >::bsrmv( &transa, &m, &k, &lb, &alpha,
86  matdescra, a, indx, pntrb, pntre,
87  x + i*Dimension*k, &beta, y + i*Dimension*k ) ;
88  }
89 
90 }
91 
92 template < typename BlockPtr, typename BlockType >
93 static void rowmv( const SparseBlockIndex< true, MKL_INT, BlockPtr >& index,
94  const BlockType* data, MKL_INT row,
95  const typename BlockTraits< BlockType >::Scalar *rhs, int rhsCols,
96  typename BlockTraits< BlockType >::Scalar *res )
97 {
98  const MKL_INT *rowIndexOrig = index.rowIndex() ;
99  const MKL_INT offset = rowIndexOrig[ row ] ;
100  const MKL_INT rowIndex[2] = {0, rowIndexOrig[ row+1 ] - offset} ;
101  const MKL_INT *columns = index.columns() + offset ;
102 
103  mkl::bsrmv< typename BlockTraits< BlockType >::Scalar >
104  ( false, false, BlockTraits< BlockType >::RowsAtCompileTime,
105  1, index.innerSize(), 0,
106  rowIndex, columns,
107  data_pointer( data[offset + index.base] ),
108  rhs, rhsCols, res, 1, 1 ) ;
109 }
110 
111 } //namespace mkl
112 
113 template <>
114 struct SparseBlockMatrixOpProxy< true, true, double, MKL_INT >
115 {
116  typedef double Scalar ;
117 
118  template < bool Transpose, typename Derived, typename RhsT, typename ResT >
119  static void multiply( const SparseBlockMatrixBase< Derived >& matrix, const RhsT& rhs, ResT& res,
120  Scalar alpha, Scalar beta )
121  {
122  typedef BlockMatrixTraits< Derived > Traits ;
123 
124  mkl::bsrmv< Scalar >
125  ( Traits::is_symmetric, Transpose, Derived::RowsPerBlock,
126  matrix.rowsOfBlocks(), matrix.colsOfBlocks(), 0,
127  matrix.majorIndex().rowIndex(), matrix.majorIndex().columns(),
128  data_pointer( matrix.data()[0] ),
129  rhs.data(), rhs.cols(), res.data(), alpha, beta ) ;
130  }
131 
132  template < typename Derived, typename RhsT, typename ResT >
133  static void splitRowMultiply( const SparseBlockMatrixBase< Derived >& matrix, typename Derived::Index row, const RhsT& rhs, ResT& res )
134  {
135  typedef BlockMatrixTraits< Derived > Traits ;
136 
137  if( Traits::is_symmetric && !matrix.transposeIndex().valid )
138  {
139  SparseBlockSplitRowMultiplier< Traits::is_symmetric, !Traits::is_col_major >
140  ::splitRowMultiply( matrix, row, rhs, res ) ;
141 
142  return ;
143  }
144 
145  mkl::rowmv( matrix.majorIndex(), matrix.data(), row,
146  rhs.data(), rhs.cols(), res.data() ) ;
147 
148  if( Traits::is_symmetric )
149  {
150  mkl::rowmv( matrix.transposeIndex(), matrix.data(), row,
151  rhs.data(), rhs.cols(), res.data() ) ;
152  }
153 
154  // Remove diagonal block if it exist
155  const typename Traits::BlockPtr diagPtr = matrix.diagonalBlockPtr( row ) ;
156  if( diagPtr != matrix.InvalidBlockPtr )
157  {
158 
159  const Segmenter< BlockDims< typename Derived::BlockType, false >::Cols, const RhsT, typename Derived::Index >
160  segmenter( rhs, matrix.colOffsets() ) ;
161  res -= matrix.block( diagPtr ) * segmenter[row] ;
162  }
163 
164  }
165 } ;
166 
167 
168 } //namespace bogus
169 
170 #endif
Definition: Traits.hpp:19
Wrapper over scalar-specific mkl calls.
Definition: MklBindings.hpp:34
Base class for Transpose views of a BlockObjectBase.
Definition: Expressions.hpp:22
static const BlockPtr InvalidBlockPtr
Return value of blockPtr( Index, Index ) for non-existing block.
Definition: BlockMatrixBase.hpp:38
Access to segment of a vector corresponding to a given block-row.
Definition: Access.hpp:133
BlockPtr diagonalBlockPtr(Index row) const
Returns a BlockPtr to the block at (row, row), or InvalidBlockPtr if it does not exist.
const Index * colOffsets() const
Returns an array containing the first index of each column.
Definition: SparseBlockMatrixBase.hpp:484
const BlockType * data() const
Access to blocks data as a raw pointer.
Definition: BlockMatrixBase.hpp:91
Base class for SparseBlockMatrix.
Definition: SparseBlockMatrixBase.hpp:36