pktools 2.6.7
Processing Kernel for geospatial data
Classes | Public Member Functions | Protected Types | Protected Member Functions | Protected Attributes | List of all members
Solver Class Reference
Inheritance diagram for Solver:
Inheritance graph
[legend]
Collaboration diagram for Solver:
Collaboration graph
[legend]

Classes

struct  SolutionInfo
 

Public Member Functions

void Solve (int l, const QMatrix &Q, const double *p_, const schar *y_, double *alpha_, double Cp, double Cn, double eps, SolutionInfo *si, int shrinking, bool verbose=false)
 

Protected Types

enum  { LOWER_BOUND , UPPER_BOUND , FREE }
 

Protected Member Functions

double get_C (int i)
 
void update_alpha_status (int i)
 
bool is_upper_bound (int i)
 
bool is_lower_bound (int i)
 
bool is_free (int i)
 
void swap_index (int i, int j)
 
void reconstruct_gradient ()
 
virtual int select_working_set (int &i, int &j)
 
virtual double calculate_rho ()
 
virtual void do_shrinking ()
 

Protected Attributes

int active_size
 
schar * y
 
double * G
 
char * alpha_status
 
double * alpha
 
const QMatrix * Q
 
const double * QD
 
double eps
 
double Cp
 
double Cn
 
double * p
 
int * active_set
 
double * G_bar
 
int l
 
bool unshrink
 

Detailed Description

Definition at line 395 of file svm.cpp.

Member Enumeration Documentation

◆ anonymous enum

anonymous enum
protected

Definition at line 415 of file svm.cpp.

415{ LOWER_BOUND, UPPER_BOUND, FREE };

Constructor & Destructor Documentation

◆ Solver()

Solver::Solver ( )
inline

Definition at line 397 of file svm.cpp.

397{};

◆ ~Solver()

virtual Solver::~Solver ( )
inline virtual

Definition at line 398 of file svm.cpp.

398{};

Member Function Documentation

◆ calculate_rho()

double Solver::calculate_rho ( )
protected virtual

Definition at line 973 of file svm.cpp.

974{
975 double r;
976 int nr_free = 0;
977 double ub = INF, lb = -INF, sum_free = 0;
978 for(int i=0;i<active_size;i++)
979 {
980 double yG = y[i]*G[i];
981
982 if(is_upper_bound(i))
983 {
984 if(y[i]==-1)
985 ub = min(ub,yG);
986 else
987 lb = max(lb,yG);
988 }
989 else if(is_lower_bound(i))
990 {
991 if(y[i]==+1)
992 ub = min(ub,yG);
993 else
994 lb = max(lb,yG);
995 }
996 else
997 {
998 ++nr_free;
999 sum_free += yG;
1000 }
1001 }
1002
1003 if(nr_free>0)
1004 r = sum_free/nr_free;
1005 else
1006 r = (ub+lb)/2;
1007
1008 return r;
1009}

◆ do_shrinking()

void Solver::do_shrinking ( )
protected virtual

Definition at line 912 of file svm.cpp.

913{
914 int i;
915 double Gmax1 = -INF; // max { -y_i * grad(f)_i | i in I_up(\alpha) }
916 double Gmax2 = -INF; // max { y_i * grad(f)_i | i in I_low(\alpha) }
917
918 // find maximal violating pair first
919 for(i=0;i<active_size;i++)
920 {
921 if(y[i]==+1)
922 {
923 if(!is_upper_bound(i))
924 {
925 if(-G[i] >= Gmax1)
926 Gmax1 = -G[i];
927 }
928 if(!is_lower_bound(i))
929 {
930 if(G[i] >= Gmax2)
931 Gmax2 = G[i];
932 }
933 }
934 else
935 {
936 if(!is_upper_bound(i))
937 {
938 if(-G[i] >= Gmax2)
939 Gmax2 = -G[i];
940 }
941 if(!is_lower_bound(i))
942 {
943 if(G[i] >= Gmax1)
944 Gmax1 = G[i];
945 }
946 }
947 }
948
949 if(unshrink == false && Gmax1 + Gmax2 <= eps*10)
950 {
951 unshrink = true;
952 reconstruct_gradient();
953 active_size = l;
954 info("*");
955 }
956
957 for(i=0;i<active_size;i++)
958 if (be_shrunk(i, Gmax1, Gmax2))
959 {
960 active_size--;
961 while (active_size > i)
962 {
963 if (!be_shrunk(active_size, Gmax1, Gmax2))
964 {
965 swap_index(i,active_size);
966 break;
967 }
968 active_size--;
969 }
970 }
971}

◆ get_C()

double Solver::get_C ( int  i)
inline protected

Definition at line 428 of file svm.cpp.

429 {
430 return (y[i] > 0)? Cp : Cn;
431 }

◆ is_free()

bool Solver::is_free ( int  i)
inline protected

Definition at line 442 of file svm.cpp.

442{ return alpha_status[i] == FREE; }

◆ is_lower_bound()

bool Solver::is_lower_bound ( int  i)
inline protected

Definition at line 441 of file svm.cpp.

441{ return alpha_status[i] == LOWER_BOUND; }

◆ is_upper_bound()

bool Solver::is_upper_bound ( int  i)
inline protected

Definition at line 440 of file svm.cpp.

440{ return alpha_status[i] == UPPER_BOUND; }

◆ reconstruct_gradient()

void Solver::reconstruct_gradient ( )
protected

Definition at line 464 of file svm.cpp.

465{
466 // reconstruct inactive elements of G from G_bar and free variables
467
468 if(active_size == l) return;
469
470 int i,j;
471 int nr_free = 0;
472
473 for(j=active_size;j<l;j++)
474 G[j] = G_bar[j] + p[j];
475
476 for(j=0;j<active_size;j++)
477 if(is_free(j))
478 nr_free++;
479
480 if(2*nr_free < active_size)
481 info("\nWARNING: using -h 0 may be faster\n");
482
483 if (nr_free*l > 2*active_size*(l-active_size))
484 {
485 for(i=active_size;i<l;i++)
486 {
487 const Qfloat *Q_i = Q->get_Q(i,active_size);
488 for(j=0;j<active_size;j++)
489 if(is_free(j))
490 G[i] += alpha[j] * Q_i[j];
491 }
492 }
493 else
494 {
495 for(i=0;i<active_size;i++)
496 if(is_free(i))
497 {
498 const Qfloat *Q_i = Q->get_Q(i,l);
499 double alpha_i = alpha[i];
500 for(j=active_size;j<l;j++)
501 G[j] += alpha_i * Q_i[j];
502 }
503 }
504}

◆ select_working_set()

int Solver::select_working_set ( int &  i,
int &  j 
)
protected virtual

Definition at line 793 of file svm.cpp.

794{
795 // return i,j such that
796 // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
797 // j: minimizes the decrease of obj value
798 // (if quadratic coefficeint <= 0, replace it with tau)
799 // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
800
801 double Gmax = -INF;
802 double Gmax2 = -INF;
803 int Gmax_idx = -1;
804 int Gmin_idx = -1;
805 double obj_diff_min = INF;
806
807 for(int t=0;t<active_size;t++)
808 if(y[t]==+1)
809 {
810 if(!is_upper_bound(t))
811 if(-G[t] >= Gmax)
812 {
813 Gmax = -G[t];
814 Gmax_idx = t;
815 }
816 }
817 else
818 {
819 if(!is_lower_bound(t))
820 if(G[t] >= Gmax)
821 {
822 Gmax = G[t];
823 Gmax_idx = t;
824 }
825 }
826
827 int i = Gmax_idx;
828 const Qfloat *Q_i = NULL;
829 if(i != -1) // NULL Q_i not accessed: Gmax=-INF if i=-1
830 Q_i = Q->get_Q(i,active_size);
831
832 for(int j=0;j<active_size;j++)
833 {
834 if(y[j]==+1)
835 {
836 if (!is_lower_bound(j))
837 {
838 double grad_diff=Gmax+G[j];
839 if (G[j] >= Gmax2)
840 Gmax2 = G[j];
841 if (grad_diff > 0)
842 {
843 double obj_diff;
844 double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j];
845 if (quad_coef > 0)
846 obj_diff = -(grad_diff*grad_diff)/quad_coef;
847 else
848 obj_diff = -(grad_diff*grad_diff)/TAU;
849
850 if (obj_diff <= obj_diff_min)
851 {
852 Gmin_idx=j;
853 obj_diff_min = obj_diff;
854 }
855 }
856 }
857 }
858 else
859 {
860 if (!is_upper_bound(j))
861 {
862 double grad_diff= Gmax-G[j];
863 if (-G[j] >= Gmax2)
864 Gmax2 = -G[j];
865 if (grad_diff > 0)
866 {
867 double obj_diff;
868 double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j];
869 if (quad_coef > 0)
870 obj_diff = -(grad_diff*grad_diff)/quad_coef;
871 else
872 obj_diff = -(grad_diff*grad_diff)/TAU;
873
874 if (obj_diff <= obj_diff_min)
875 {
876 Gmin_idx=j;
877 obj_diff_min = obj_diff;
878 }
879 }
880 }
881 }
882 }
883
884 if(Gmax+Gmax2 < eps)
885 return 1;
886
887 out_i = Gmax_idx;
888 out_j = Gmin_idx;
889 return 0;
890}

◆ Solve()

void Solver::Solve ( int  l,
const QMatrix & Q,
const double *  p_,
const schar *  y_,
double *  alpha_,
double  Cp,
double  Cn,
double  eps,
SolutionInfo * si,
int  shrinking,
bool  verbose = false 
)

Definition at line 506 of file svm.cpp.

// Solve: optimise the dual variables for the problem described by Q, p_, y_
// and the box bounds Cp (used when y[i] > 0) / Cn (otherwise), starting from
// alpha_, and write the result back into alpha_.
//   l         number of variables (length of p_, y_, alpha_)
//   Q         kernel access: rows via get_Q(), QD via get_QD()
//             (presumably the diagonal of Q -- confirm in QMatrix)
//   eps       stored in this->eps; stopping tolerance for select_working_set()
//   si        receives rho, obj, upper_bound_p and upper_bound_n
//   shrinking nonzero enables periodic do_shrinking()
//   verbose   gates the "." / "*" progress marks and the final summary line
// p_, y_ and alpha_ are cloned into member arrays; those and every other work
// array are released with delete[] before returning. Q and QD are borrowed.
510{
511 this->l = l;
512 this->Q = &Q;
513 QD=Q.get_QD();
514 clone(p, p_,l);
515 clone(y, y_,l);
516 clone(alpha,alpha_,l);
517 this->Cp = Cp;
518 this->Cn = Cn;
519 this->eps = eps;
520 unshrink = false;
521
522 // initialize alpha_status
523 {
524 alpha_status = new char[l];
525 for(int i=0;i<l;i++)
526 update_alpha_status(i);
527 }
528
529 // initialize active set (for shrinking)
530 {
531 active_set = new int[l];
532 for(int i=0;i<l;i++)
533 active_set[i] = i;
534 active_size = l;
535 }
536
537 // initialize gradient
// G = p + Q*alpha, skipping variables at the lower bound; G_bar accumulates
// C_i * Q_i for every variable at its upper bound, which lets
// reconstruct_gradient() rebuild shrunk entries as G_bar + p + (free terms).
538 {
539 G = new double[l];
540 G_bar = new double[l];
541 int i;
542 for(i=0;i<l;i++)
543 {
544 G[i] = p[i];
545 G_bar[i] = 0;
546 }
547 for(i=0;i<l;i++)
548 if(!is_lower_bound(i))
549 {
550 const Qfloat *Q_i = Q.get_Q(i,l);
551 double alpha_i = alpha[i];
552 int j;
553 for(j=0;j<l;j++)
554 G[j] += alpha_i*Q_i[j];
555 if(is_upper_bound(i))
556 for(j=0;j<l;j++)
557 G_bar[j] += get_C(i) * Q_i[j];
558 }
559 }
560
561 // optimization step
562
// Iteration cap: at least 10,000,000, or 100*l when larger (saturating at
// INT_MAX so 100*l cannot overflow). counter fires shrinking and the
// progress mark every min(l,1000) iterations.
563 int iter = 0;
564 int max_iter = max(10000000, l>INT_MAX/100 ? INT_MAX : 100*l);
565 int counter = min(l,1000)+1;
566
567 while(iter < max_iter)
568 {
569 // show progress and do shrinking
570
571 if(--counter == 0)
572 {
573 counter = min(l,1000);
574 if(shrinking) do_shrinking();
575 if(verbose)//pk
576 info(".");
577 }
578
// Pick the working pair. A nonzero return means no step is taken on the
// current active set, so the full gradient is rebuilt and the check is
// repeated once on all l variables before leaving the loop.
579 int i,j;
580 if(select_working_set(i,j)!=0)
581 {
582 // reconstruct the whole gradient
583 reconstruct_gradient();
584 // reset active set size and check
585 active_size = l;
586 if(verbose)//pk
587 info("*");
588 if(select_working_set(i,j)!=0)
589 break;
590 else
591 counter = 1; // do shrinking next iteration
592 }
593
594 ++iter;
595
596 // update alpha[i] and alpha[j], handle bounds carefully
// NOTE(review): Q_i and Q_j come from two successive get_Q() calls; this
// relies on the second call not invalidating the buffer returned by the
// first -- confirm against the QMatrix/cache implementation.
597
598 const Qfloat *Q_i = Q.get_Q(i,active_size);
599 const Qfloat *Q_j = Q.get_Q(j,active_size);
600
601 double C_i = get_C(i);
602 double C_j = get_C(j);
603
604 double old_alpha_i = alpha[i];
605 double old_alpha_j = alpha[j];
606
// y[i] != y[j]: the step keeps diff = alpha[i] - alpha[j] constant.
// y[i] == y[j]: the step keeps sum = alpha[i] + alpha[j] constant.
// In both cases the pair is then clipped back into [0, C_i] x [0, C_j];
// a non-positive quad_coef is replaced by TAU.
607 if(y[i]!=y[j])
608 {
609 double quad_coef = QD[i]+QD[j]+2*Q_i[j];
610 if (quad_coef <= 0)
611 quad_coef = TAU;
612 double delta = (-G[i]-G[j])/quad_coef;
613 double diff = alpha[i] - alpha[j];
614 alpha[i] += delta;
615 alpha[j] += delta;
616
617 if(diff > 0)
618 {
619 if(alpha[j] < 0)
620 {
621 alpha[j] = 0;
622 alpha[i] = diff;
623 }
624 }
625 else
626 {
627 if(alpha[i] < 0)
628 {
629 alpha[i] = 0;
630 alpha[j] = -diff;
631 }
632 }
633 if(diff > C_i - C_j)
634 {
635 if(alpha[i] > C_i)
636 {
637 alpha[i] = C_i;
638 alpha[j] = C_i - diff;
639 }
640 }
641 else
642 {
643 if(alpha[j] > C_j)
644 {
645 alpha[j] = C_j;
646 alpha[i] = C_j + diff;
647 }
648 }
649 }
650 else
651 {
652 double quad_coef = QD[i]+QD[j]-2*Q_i[j];
653 if (quad_coef <= 0)
654 quad_coef = TAU;
655 double delta = (G[i]-G[j])/quad_coef;
656 double sum = alpha[i] + alpha[j];
657 alpha[i] -= delta;
658 alpha[j] += delta;
659
660 if(sum > C_i)
661 {
662 if(alpha[i] > C_i)
663 {
664 alpha[i] = C_i;
665 alpha[j] = sum - C_i;
666 }
667 }
668 else
669 {
670 if(alpha[j] < 0)
671 {
672 alpha[j] = 0;
673 alpha[i] = sum;
674 }
675 }
676 if(sum > C_j)
677 {
678 if(alpha[j] > C_j)
679 {
680 alpha[j] = C_j;
681 alpha[i] = sum - C_j;
682 }
683 }
684 else
685 {
686 if(alpha[i] < 0)
687 {
688 alpha[i] = 0;
689 alpha[j] = sum;
690 }
691 }
692 }
693
694 // update G
// Only the active entries of G are updated here; inactive entries are
// rebuilt by reconstruct_gradient() when needed.
695
696 double delta_alpha_i = alpha[i] - old_alpha_i;
697 double delta_alpha_j = alpha[j] - old_alpha_j;
698
699 for(int k=0;k<active_size;k++)
700 {
701 G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
702 }
703
704 // update alpha_status and G_bar
// G_bar only changes when a variable enters or leaves its upper bound; the
// full column (length l) is fetched for that update.
705
706 {
707 bool ui = is_upper_bound(i);
708 bool uj = is_upper_bound(j);
709 update_alpha_status(i);
710 update_alpha_status(j);
711 int k;
712 if(ui != is_upper_bound(i))
713 {
714 Q_i = Q.get_Q(i,l);
715 if(ui)
716 for(k=0;k<l;k++)
717 G_bar[k] -= C_i * Q_i[k];
718 else
719 for(k=0;k<l;k++)
720 G_bar[k] += C_i * Q_i[k];
721 }
722
723 if(uj != is_upper_bound(j))
724 {
725 Q_j = Q.get_Q(j,l);
726 if(uj)
727 for(k=0;k<l;k++)
728 G_bar[k] -= C_j * Q_j[k];
729 else
730 for(k=0;k<l;k++)
731 G_bar[k] += C_j * Q_j[k];
732 }
733 }
734 }
735
// Iteration cap reached: make G complete so that obj below covers all l
// variables.
// NOTE(review): unlike the progress marks, this warning is not gated by
// verbose, and the message has no trailing newline.
736 if(iter >= max_iter)
737 {
738 if(active_size < l)
739 {
740 // reconstruct the whole gradient to calculate objective value
741 reconstruct_gradient();
742 active_size = l;
743 if(verbose)//pk
744 info("*");
745 }
746 info("\nWARNING: reaching max number of iterations");
747 }
748
749 // calculate rho
750
751 si->rho = calculate_rho();
752
753 // calculate objective value
754 {
755 double v = 0;
756 int i;
757 for(i=0;i<l;i++)
758 v += alpha[i] * (G[i] + p[i]);
759
760 si->obj = v/2;
761 }
762
// swap_index() permutes the work arrays during shrinking; active_set maps
// each position back to the caller's original index.
763 // put back the solution
764 {
765 for(int i=0;i<l;i++)
766 alpha_[active_set[i]] = alpha[i];
767 }
768
769 // juggle everything back
770 /*{
771 for(int i=0;i<l;i++)
772 while(active_set[i] != i)
773 swap_index(i,active_set[i]);
774 // or Q.swap_index(i,active_set[i]);
775 }*/
776
777 si->upper_bound_p = Cp;
778 si->upper_bound_n = Cn;
779
780 if(verbose)//pk
781 info("\noptimization finished, #iter = %d\n",iter);
782
// Release the per-call work arrays (allocated above or by clone()).
783 delete[] p;
784 delete[] y;
785 delete[] alpha;
786 delete[] alpha_status;
787 delete[] active_set;
788 delete[] G;
789 delete[] G_bar;
790}

◆ swap_index()

void Solver::swap_index ( int  i,
int  j 
)
protected

Definition at line 452 of file svm.cpp.

453{
454 Q->swap_index(i,j);
455 swap(y[i],y[j]);
456 swap(G[i],G[j]);
457 swap(alpha_status[i],alpha_status[j]);
458 swap(alpha[i],alpha[j]);
459 swap(p[i],p[j]);
460 swap(active_set[i],active_set[j]);
461 swap(G_bar[i],G_bar[j]);
462}

◆ update_alpha_status()

void Solver::update_alpha_status ( int  i)
inline protected

Definition at line 432 of file svm.cpp.

433 {
434 if(alpha[i] >= get_C(i))
435 alpha_status[i] = UPPER_BOUND;
436 else if(alpha[i] <= 0)
437 alpha_status[i] = LOWER_BOUND;
438 else alpha_status[i] = FREE;
439 }

Member Data Documentation

◆ active_set

int* Solver::active_set
protected

Definition at line 423 of file svm.cpp.

◆ active_size

int Solver::active_size
protected

Definition at line 412 of file svm.cpp.

◆ alpha

double* Solver::alpha
protected

Definition at line 417 of file svm.cpp.

◆ alpha_status

char* Solver::alpha_status
protected

Definition at line 416 of file svm.cpp.

◆ Cn

double Solver::Cn
protected

Definition at line 421 of file svm.cpp.

◆ Cp

double Solver::Cp
protected

Definition at line 421 of file svm.cpp.

◆ eps

double Solver::eps
protected

Definition at line 420 of file svm.cpp.

◆ G

double* Solver::G
protected

Definition at line 414 of file svm.cpp.

◆ G_bar

double* Solver::G_bar
protected

Definition at line 424 of file svm.cpp.

◆ l

int Solver::l
protected

Definition at line 425 of file svm.cpp.

◆ p

double* Solver::p
protected

Definition at line 422 of file svm.cpp.

◆ Q

const QMatrix* Solver::Q
protected

Definition at line 418 of file svm.cpp.

◆ QD

const double* Solver::QD
protected

Definition at line 419 of file svm.cpp.

◆ unshrink

bool Solver::unshrink
protected

Definition at line 426 of file svm.cpp.

◆ y

schar* Solver::y
protected

Definition at line 413 of file svm.cpp.


The documentation for this class was generated from the following file: