00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef MTL_CHOLESKY_INCLUDE
00013 #define MTL_CHOLESKY_INCLUDE
00014
#include <cmath>

#include <boost/numeric/mtl/concept/collection.hpp>
#include <boost/numeric/mtl/recursion/matrix_recursator.hpp>
#include <boost/numeric/mtl/utility/glas_tag.hpp>
#include <boost/numeric/mtl/operation/dmat_dmat_mult.hpp>
#include <boost/numeric/mtl/operation/assign_mode.hpp>
#include <boost/numeric/mtl/matrix/transposed_view.hpp>
#include <boost/numeric/mtl/recursion/base_case_cast.hpp>
00022
00023 namespace mtl { namespace matrix {
00024
00025 namespace with_bracket {
00026
00027
00028
00029
00030
00031 template < typename Matrix >
00032 void cholesky_base (Matrix & matrix)
00033 {
00034 typedef typename Collection<Matrix>::size_type size_type;
00035 for (size_type k = 0; k < matrix.num_cols(); k++) {
00036 matrix[k][k] = sqrt (matrix[k][k]);
00037
00038 for (size_type i = k + 1; i < matrix.num_rows(); i++) {
00039 matrix[i][k] /= matrix[k][k];
00040 typename Collection<Matrix>::value_type d = matrix[i][k];
00041
00042 for (size_type j = k + 1; j <= i; j++)
00043 matrix[i][j] -= d * matrix[j][k];
00044 }
00045 }
00046 }
00047
00048
00049 template < typename MatrixSW, typename MatrixNW >
00050 void tri_solve_base(MatrixSW & SW, const MatrixNW & NW)
00051 {
00052 typedef typename Collection<MatrixSW>::size_type size_type;
00053 for (size_type k = 0; k < NW.num_rows (); k++) {
00054
00055 for (size_type i = 0; i < SW.num_rows (); i++) {
00056 SW[i][k] /= NW[k][k];
00057 typename MatrixSW::value_type d = SW[i][k];
00058
00059 for (size_type j = k + 1; j < SW.num_cols (); j++)
00060 SW[i][j] -= d * NW[j][k];
00061 }
00062 }
00063 }
00064
00065
00066
00067 template < typename MatrixSE, typename MatrixSW >
00068 void tri_schur_base(MatrixSE & SE, const MatrixSW & SW)
00069 {
00070 typedef typename Collection<MatrixSE>::size_type size_type;
00071 for (size_type k = 0; k < SW.num_cols (); k++)
00072
00073 for (size_type i = 0; i < SE.num_rows (); i++) {
00074 typename MatrixSW::value_type d = SW[i][k];
00075 for (size_type j = 0; j <= i; j++)
00076 SE[i][j] -= d * SW[j][k];
00077 }
00078 }
00079
00080
00081 template < typename MatrixNE, typename MatrixNW, typename MatrixSW >
00082 void schur_update_base(MatrixNE & NE, const MatrixNW & NW, const MatrixSW & SW)
00083 {
00084 typedef typename Collection<MatrixNE>::size_type size_type;
00085 for (size_type k = 0; k < NW.num_cols (); k++)
00086 for (size_type i = 0; i < NE.num_rows (); i++) {
00087 typename MatrixNW::value_type d = NW[i][k];
00088 for (size_type j = 0; j < NE.num_cols (); j++)
00089 NE[i][j] -= d * SW[j][k];
00090 }
00091 }
00092
00093
00094
00095
00096
00097
// Function objects forwarding to the element-access kernels of this namespace,
// so they can be passed as type parameters (e.g. to recursive_cholesky_visitor_t).

/// Forwards to cholesky_base(A).
struct cholesky_base_t
{
    template < typename Matrix >
    void operator() (Matrix & A) { cholesky_base(A); }
};

/// Forwards to tri_solve_base(SW, NW).
struct tri_solve_base_t
{
    template < typename MatrixSW, typename MatrixNW >
    void operator() (MatrixSW & sw, const MatrixNW & nw) { tri_solve_base(sw, nw); }
};

/// Forwards to tri_schur_base(SE, SW).
struct tri_schur_base_t
{
    template < typename MatrixSE, typename MatrixSW >
    void operator() (MatrixSE & se, const MatrixSW & sw) { tri_schur_base(se, sw); }
};

/// Forwards to schur_update_base(NE, NW, SW).
struct schur_update_base_t
{
    template < typename MatrixNE, typename MatrixNW, typename MatrixSW >
    void operator() (MatrixNE & ne, const MatrixNW & nw, const MatrixSW & sw)
    {
        schur_update_base(ne, nw, sw);
    }
};
00133
00134 }
00135
00136
00137 namespace with_iterator {
00138
00139
00140
00141
00142
00143
00144 template < typename Matrix >
00145 void cholesky_base (Matrix& matrix)
00146 {
00147 typedef typename Collection<Matrix>::size_type size_type;
00148
00149 using namespace glas::tag; using traits::range_generator;
00150 typedef tag::iter::all all_it;
00151
00152 typedef typename Collection<Matrix>::value_type value_type;
00153 typedef typename range_generator<col, Matrix>::type cur_type;
00154 typedef typename range_generator<all_it, cur_type>::type iter_type;
00155
00156 typedef typename range_generator<row, Matrix>::type rcur_type;
00157 typedef typename range_generator<all_it, rcur_type>::type riter_type;
00158
00159 size_type k= 0;
00160 for (cur_type kb= begin<col>(matrix), kend= end<col>(matrix); kb != kend; ++kb, ++k) {
00161
00162 iter_type ib= begin<all_it>(kb), iend= end<all_it>(kb);
00163 ib+= k;
00164
00165 value_type root= sqrt (*ib);
00166 *ib= root;
00167
00168 ++ib;
00169 rcur_type rb= begin<row>(matrix); rb+= k+1;
00170 for (size_type i= k + 1; ib != iend; ++ib, ++rb, ++i) {
00171 *ib = *ib / root;
00172 typename Collection<Matrix>::value_type d = *ib;
00173 riter_type it1= begin<all_it>(rb); it1+= k+1;
00174 riter_type it1end= begin<all_it>(rb); it1end+= i+1;
00175 iter_type it2= begin<all_it>(kb); it2+= k+1;
00176 for (; it1 != it1end; ++it1, ++it2)
00177 *it1 = *it1 - d * *it2;
00178 }
00179 }
00180 }
00181
00182
00183 template < typename MatrixSW, typename MatrixNW >
00184 void tri_solve_base(MatrixSW & SW, const MatrixNW & NW)
00185 {
00186 typedef typename Collection<MatrixSW>::size_type size_type;
00187
00188 using namespace glas::tag; using traits::range_generator;
00189 typedef tag::iter::all all_it;
00190 typedef tag::const_iter::all all_cit;
00191
00192 typedef typename range_generator<col, MatrixNW>::type ccur_type;
00193 typedef typename range_generator<all_cit, ccur_type>::type citer_type;
00194
00195 typedef typename range_generator<row, MatrixSW>::type rcur_type;
00196 typedef typename range_generator<all_it, rcur_type>::type riter_type;
00197
00198 for (size_type k = 0; k < NW.num_rows (); k++)
00199 for (size_type i = 0; i < SW.num_rows (); i++) {
00200
00201 typename MatrixSW::value_type d = SW[i][k] /= NW[k][k];
00202
00203 rcur_type sw_i= begin<row>(SW); sw_i+= i;
00204 riter_type it1= begin<all_it>(sw_i); it1+= k+1;
00205 riter_type it1end= end<all_it>(sw_i);
00206
00207 ccur_type nw_k= begin<col>(NW); nw_k+= k;
00208 citer_type it2= begin<all_cit>(nw_k); it2+= k+1;
00209
00210 for(; it1 != it1end; ++it1, ++it2)
00211 *it1 = *it1 - d * *it2;
00212 }
00213 }
00214
00215
00216
00217 template < typename MatrixSE, typename MatrixSW >
00218 void tri_schur_base(MatrixSE & SE, const MatrixSW & SW)
00219 {
00220 typedef typename Collection<MatrixSE>::size_type size_type;
00221
00222 using namespace glas::tag; using traits::range_generator;
00223 typedef tag::iter::all all_it;
00224 typedef tag::const_iter::all all_cit;
00225
00226 typedef typename range_generator<col, MatrixSW>::type ccur_type;
00227 typedef typename range_generator<all_cit, ccur_type>::type citer_type;
00228
00229 typedef typename range_generator<row, MatrixSE>::type rcur_type;
00230 typedef typename range_generator<all_it, rcur_type>::type riter_type;
00231
00232 for (size_type k = 0; k < SW.num_cols (); k++)
00233 for (size_type i = 0; i < SE.num_rows (); i++) {
00234 typename MatrixSW::value_type d = SW[i][k];
00235
00236 rcur_type se_i= begin<row>(SE); se_i+= i;
00237 riter_type it1= begin<all_it>(se_i);
00238 riter_type it1end= begin<all_it>(se_i); it1end+= i+1;
00239
00240 ccur_type sw_k= begin<col>(SW); sw_k+= k;
00241 citer_type it2= begin<all_cit>(sw_k);
00242
00243 for(; it1 != it1end; ++it1, ++it2)
00244 *it1 = *it1 - d * *it2;
00245 }
00246 }
00247
00248
00249 template < typename MatrixNE, typename MatrixNW, typename MatrixSW >
00250 void schur_update_base(MatrixNE & NE, const MatrixNW & NW, const MatrixSW & SW)
00251 {
00252 typedef typename Collection<MatrixNE>::size_type size_type;
00253
00254 using namespace glas::tag; using traits::range_generator;
00255 typedef tag::iter::all all_it;
00256 typedef tag::const_iter::all all_cit;
00257
00258 typedef typename range_generator<col, MatrixSW>::type ccur_type;
00259 typedef typename range_generator<all_cit, ccur_type>::type citer_type;
00260
00261 typedef typename range_generator<row, MatrixNE>::type rcur_type;
00262 typedef typename range_generator<all_it, rcur_type>::type riter_type;
00263
00264 for (size_type k = 0; k < NW.num_cols (); k++)
00265 for (size_type i = 0; i < NE.num_rows (); i++) {
00266 typename MatrixNW::value_type d = NW[i][k];
00267
00268 rcur_type ne_i= begin<row>(NE); ne_i+= i;
00269 riter_type it1= begin<all_it>(ne_i);
00270 riter_type it1end= end<all_it>(ne_i);
00271
00272 ccur_type sw_k= begin<col>(SW); sw_k+= k;
00273 citer_type it2= begin<all_cit>(sw_k);
00274
00275 for (size_type j = 0; j < NE.num_cols (); j++)
00276 NE[i][j] -= d * SW[j][k];
00277 }
00278 }
00279
00280
00281
00282
00283
00284
// Function objects forwarding to the cursor/iterator kernels of this namespace,
// so they can be passed as type parameters (e.g. to recursive_cholesky_visitor_t).

/// Forwards to cholesky_base(A).
struct cholesky_base_t
{
    template < typename Matrix >
    void operator() (Matrix & A) { cholesky_base(A); }
};

/// Forwards to tri_solve_base(SW, NW).
struct tri_solve_base_t
{
    template < typename MatrixSW, typename MatrixNW >
    void operator() (MatrixSW & sw, const MatrixNW & nw) { tri_solve_base(sw, nw); }
};

/// Forwards to tri_schur_base(SE, SW).
struct tri_schur_base_t
{
    template < typename MatrixSE, typename MatrixSW >
    void operator() (MatrixSE & se, const MatrixSW & sw) { tri_schur_base(se, sw); }
};

/// Forwards to schur_update_base(NE, NW, SW).
struct schur_update_base_t
{
    template < typename MatrixNE, typename MatrixNW, typename MatrixSW >
    void operator() (MatrixNE & ne, const MatrixNW & nw, const MatrixSW & sw)
    {
        schur_update_base(ne, nw, sw);
    }
};
00320
00321 }
00322
00323
00324
00325
00326
00327
00328
/// Compile-time configurable visitor for the recursive Cholesky algorithms.
///
/// Bundles a base-case test with the four base-case kernels.  Every call
/// value-initializes a fresh functor of the corresponding type and invokes it
/// with the given arguments; the visitor itself has no data members.
///
/// \tparam BaseTest      callable(recursator) -> bool, "is this a base case?"
/// \tparam CholeskyBase  callable(matrix)
/// \tparam TriSolveBase  callable(SW, NW)
/// \tparam TriSchur      callable(SE, SW)
/// \tparam SchurUpdate   callable(NE, NW, SW)
template <typename BaseTest, typename CholeskyBase, typename TriSolveBase, typename TriSchur, typename SchurUpdate>
struct recursive_cholesky_visitor_t
{
    typedef BaseTest base_test;

    /// Whether \p rec is small enough to be handed to a base-case kernel.
    template < typename Recursator >
    bool is_base(const Recursator& rec) const { return base_test()(rec); }

    /// Factorize a base-case diagonal block.
    template < typename Matrix >
    void cholesky_base(Matrix & A) const { CholeskyBase()(A); }

    /// Base-case triangular solve of \p sw against \p nw.
    template < typename MatrixSW, typename MatrixNW >
    void tri_solve_base(MatrixSW & sw, const MatrixNW & nw) const { TriSolveBase()(sw, nw); }

    /// Base-case triangular Schur update of \p se with \p sw.
    template < typename MatrixSE, typename MatrixSW >
    void tri_schur_base(MatrixSE & se, const MatrixSW & sw) const { TriSchur()(se, sw); }

    /// Base-case general Schur update of \p ne with \p nw and \p sw.
    template < typename MatrixNE, typename MatrixNW, typename MatrixSW >
    void schur_update_base(MatrixNE & ne, const MatrixNW & nw, const MatrixSW & sw) const
    {
        SchurUpdate()(ne, nw, sw);
    }
};
00364
00365
00366 namespace detail {
00367
00368
00369 template <typename MatrixMult>
00370 struct mult_schur_update_t
00371 {
00372 template < typename MatrixNE, typename MatrixNW, typename MatrixSW >
00373 void operator()(MatrixNE & NE, const MatrixNW & NW, const MatrixSW & SW)
00374 {
00375 transposed_view<MatrixSW> trans_sw(const_cast<MatrixSW&>(SW));
00376 MatrixMult()(NW, trans_sw, NE);
00377 }
00378 };
00379
00380 }
00381
00382
00383 namespace with_bracket {
00384 typedef recursive_cholesky_visitor_t<recursion::bound_test_static<64>, cholesky_base_t, tri_solve_base_t,
00385 tri_schur_base_t, schur_update_base_t >
00386 recursive_cholesky_base_visitor_t;
00387 }
00388
00389 namespace with_iterator {
00390 typedef recursive_cholesky_visitor_t<recursion::bound_test_static<64>,
00391 cholesky_base_t, tri_solve_base_t, tri_schur_base_t, schur_update_base_t>
00392 recursive_cholesky_base_visitor_t;
00393 }
00394
00395 typedef with_bracket::recursive_cholesky_base_visitor_t recursive_cholesky_default_visitor_t;
00396
00397
00398
00399
00400
00401
00402 namespace with_recursator {
00403
00404 template <typename Recursator, typename Visitor>
00405 void schur_update(Recursator E, Recursator W, Recursator N, Visitor vis)
00406 {
00407 using namespace recursion;
00408
00409 if (E.is_empty() || W.is_empty() || N.is_empty())
00410 return;
00411
00412 if (vis.is_base(E)) {
00413 typedef typename Visitor::base_test base_test;
00414 typedef typename base_case_matrix<typename Recursator::matrix_type, base_test>::type matrix_type;
00415
00416 matrix_type base_E(base_case_cast<base_test>(E.get_value())),
00417 base_W(base_case_cast<base_test>(W.get_value())),
00418 base_N(base_case_cast<base_test>(N.get_value()));
00419 vis.schur_update_base(base_E, base_W, base_N);
00420 } else{
00421 schur_update( E.north_east(),W.north_west() ,N.south_west() , vis);
00422 schur_update( E.north_east(), W.north_east(), N.south_east(), vis);
00423
00424 schur_update(E.north_west() , W.north_east(), N.north_east(), vis);
00425 schur_update(E.north_west() ,W.north_west() ,N.north_west() , vis);
00426
00427 schur_update(E.south_west() ,W.south_west() ,N.north_west() , vis);
00428 schur_update(E.south_west() , W.south_east(), N.north_east(), vis);
00429
00430 schur_update( E.south_east(), W.south_east(), N.south_east(), vis);
00431 schur_update( E.south_east(),W.south_west() ,N.south_west() , vis);
00432 }
00433 }
00434
00435
00436 template <typename Recursator, typename Visitor>
00437 void tri_solve(Recursator S, Recursator N, Visitor vis)
00438 {
00439 using namespace recursion;
00440
00441 if (S.is_empty())
00442 return;
00443
00444 if (vis.is_base(S)) {
00445 typedef typename Visitor::base_test base_test;
00446 typedef typename base_case_matrix<typename Recursator::matrix_type, base_test>::type matrix_type;
00447
00448 matrix_type base_S(base_case_cast<base_test>(S.get_value())),
00449 base_N(base_case_cast<base_test>(N.get_value()));
00450
00451 vis.tri_solve_base(base_S, base_N);
00452 } else{
00453
00454 tri_solve(S.north_west() ,N.north_west(), vis);
00455 schur_update( S.north_east(),S.north_west() ,N.south_west(), vis);
00456 tri_solve( S.north_east(), N.south_east(), vis);
00457 tri_solve(S.south_west() ,N.north_west() , vis);
00458 schur_update( S.south_east(),S.south_west() ,N.south_west(), vis);
00459 tri_solve( S.south_east(), N.south_east(), vis);
00460 }
00461 }
00462
00463
00464 template <typename Recursator, typename Visitor>
00465 void tri_schur(Recursator E, Recursator W, Visitor vis)
00466 {
00467 using namespace recursion;
00468
00469 if (E.is_empty() || W.is_empty())
00470 return;
00471
00472 if (vis.is_base(W)) {
00473 typedef typename Visitor::base_test base_test;
00474 typedef typename base_case_matrix<typename Recursator::matrix_type, base_test>::type matrix_type;
00475
00476 matrix_type base_E(base_case_cast<base_test>(E.get_value())),
00477 base_W(base_case_cast<base_test>(W.get_value()));
00478 vis.tri_schur_base(base_E, base_W);
00479 } else{
00480
00481 schur_update(E.south_west(), W.south_west(), W.north_west(), vis);
00482 schur_update(E.south_west(), W.south_east(), W.north_east(), vis);
00483 tri_schur( E.south_east() , W.south_east(), vis);
00484 tri_schur( E.south_east() ,W.south_west() , vis);
00485 tri_schur( E.north_west(), W.north_east(), vis);
00486 tri_schur( E.north_west(),W.north_west() , vis);
00487 }
00488 }
00489
00490
00491 template <typename Recursator, typename Visitor>
00492 void cholesky(Recursator recursator, Visitor vis)
00493 {
00494 using namespace recursion;
00495
00496 if (recursator.is_empty())
00497 return;
00498
00499 if (vis.is_base (recursator)){
00500 typedef typename Visitor::base_test base_test;
00501 typedef typename base_case_matrix<typename Recursator::matrix_type, base_test>::type matrix_type;
00502
00503 matrix_type base_matrix(base_case_cast<base_test>(recursator.get_value()));
00504 vis.cholesky_base (base_matrix);
00505 } else {
00506 cholesky(recursator.north_west(), vis);
00507 tri_solve( recursator.south_west(), recursator.north_west(), vis);
00508 tri_schur( recursator.south_east(), recursator.south_west(), vis);
00509 cholesky( recursator.south_east(), vis);
00510 }
00511 }
00512
00513 }
00514
00515
00516
00517 template <typename Backup= with_bracket::cholesky_base_t>
00518 struct recursive_cholesky_t
00519 {
00520 template <typename Matrix>
00521 void operator()(Matrix& matrix)
00522 {
00523 (*this)(matrix, recursive_cholesky_default_visitor_t());
00524 }
00525
00526 template <typename Matrix, typename Visitor>
00527 void operator()(Matrix& matrix, Visitor vis)
00528 {
00529 apply(matrix, vis, typename traits::category<Matrix>::type());
00530 }
00531
00532 private:
00533
00534 template <typename Matrix, typename Visitor>
00535 void apply(Matrix& matrix, Visitor, tag::universe)
00536 {
00537 Backup()(matrix);
00538 }
00539
00540
00541 template <typename Matrix, typename Visitor>
00542 void apply(Matrix& matrix, Visitor vis, tag::qsub_dividable)
00543 {
00544 matrix::recursator<Matrix> recursator(matrix);
00545 with_recursator::cholesky(recursator, vis);
00546 }
00547 };
00548
00549
00550 template <typename Matrix, typename Visitor>
00551 inline void recursive_cholesky(Matrix& matrix, Visitor vis)
00552 {
00553 recursive_cholesky_t<>()(matrix, vis);
00554 }
00555
00556 template <typename Matrix>
00557 inline void recursive_cholesky(Matrix& matrix)
00558 {
00559 recursive_cholesky(matrix, recursive_cholesky_default_visitor_t());
00560 }
00561
00562
00563
00564
00565
00566 template <typename Matrix>
00567 void fill_matrix_for_cholesky(Matrix& matrix)
00568 {
00569 typedef typename Collection<Matrix>::size_type size_type;
00570 typedef typename Collection<Matrix>::value_type value_type;
00571
00572 value_type x= 1.0;
00573 for (size_type i= 0; i < num_rows(matrix); i++)
00574 for (size_type j= 0; j <= i; j++)
00575 if (i != j) {
00576 matrix[i][j]= x; matrix[j][i]= x;
00577 x+= 1.0;
00578 }
00579
00580 for (size_type i= 0; i < num_rows(matrix); i++) {
00581 value_type rowsum= 0.0;
00582 for (size_type j=0; j<matrix.num_cols(); j++)
00583 if (i != j)
00584 rowsum += matrix[i][j];
00585 matrix[i][i]= rowsum * 2;
00586 }
00587 }
00588
00589
00590 }}
00591
00592
00593
00594
00595 #endif // MTL_CHOLESKY_INCLUDE