2012-05-29 14 views
13

Sto cercando di utilizzare i binding Java per libsvm:libsvm implementazione Java

http://www.csie.ntu.edu.tw/~cjlin/libsvm/ 

ho implementato un esempio 'banale', che è facilmente separabili linearmente in y. I dati sono definiti come:

double[][] train = new double[1000][]; 
double[][] test = new double[10][]; 

for (int i = 0; i < train.length; i++){ 
    if (i+1 > (train.length/2)){  // 50% positive 
     double[] vals = {1,0,i+i}; 
     train[i] = vals; 
    } else { 
     double[] vals = {0,0,i-i-i-2}; // 50% negative 
     train[i] = vals; 
    }   
} 

Dove la prima 'caratteristica' è la classe e il set di allenamento è definito in modo simile.

per addestrare il modello:

private svm_model svmTrain() { 
    svm_problem prob = new svm_problem(); 
    int dataCount = train.length; 
    prob.y = new double[dataCount]; 
    prob.l = dataCount; 
    prob.x = new svm_node[dataCount][];  

    for (int i = 0; i < dataCount; i++){    
     double[] features = train[i]; 
     prob.x[i] = new svm_node[features.length-1]; 
     for (int j = 1; j < features.length; j++){ 
      svm_node node = new svm_node(); 
      node.index = j; 
      node.value = features[j]; 
      prob.x[i][j-1] = node; 
     }   
     prob.y[i] = features[0]; 
    }    

    svm_parameter param = new svm_parameter(); 
    param.probability = 1; 
    param.gamma = 0.5; 
    param.nu = 0.5; 
    param.C = 1; 
    param.svm_type = svm_parameter.C_SVC; 
    param.kernel_type = svm_parameter.LINEAR;  
    param.cache_size = 20000; 
    param.eps = 0.001;  

    svm_model model = svm.svm_train(prob, param); 

    return model; 
} 

Quindi per valutare il modello che uso:

public int evaluate(double[] features) { 
    svm_node node = new svm_node(); 
    for (int i = 1; i < features.length; i++){ 
     node.index = i; 
     node.value = features[i]; 
    } 
    svm_node[] nodes = new svm_node[1]; 
    nodes[0] = node; 

    int totalClasses = 2;  
    int[] labels = new int[totalClasses]; 
    svm.svm_get_labels(_model,labels); 

    double[] prob_estimates = new double[totalClasses]; 
    double v = svm.svm_predict_probability(_model, nodes, prob_estimates); 

    for (int i = 0; i < totalClasses; i++){ 
     System.out.print("(" + labels[i] + ":" + prob_estimates[i] + ")"); 
    } 
    System.out.println("(Actual:" + features[0] + " Prediction:" + v + ")");    

    return (int)v; 
} 

Dove la matrice passata è un punto dal set di test.

I risultati sono sempre tornando classe 0. Con il risultato esatto dell'essere:

(0:0.9882998314585194)(1:0.011700168541480586)(Actual:0.0 Prediction:0.0) 
(0:0.9883952943701599)(1:0.011604705629839989)(Actual:0.0 Prediction:0.0) 
(0:0.9884899803606306)(1:0.011510019639369528)(Actual:0.0 Prediction:0.0) 
(0:0.9885838957058696)(1:0.011416104294130458)(Actual:0.0 Prediction:0.0) 
(0:0.9886770466322342)(1:0.011322953367765776)(Actual:0.0 Prediction:0.0) 
(0:0.9870913229268679)(1:0.012908677073132284)(Actual:1.0 Prediction:0.0) 
(0:0.9868781382588805)(1:0.013121861741119505)(Actual:1.0 Prediction:0.0) 
(0:0.986661444476744)(1:0.013338555523255982)(Actual:1.0 Prediction:0.0) 
(0:0.9864411843906802)(1:0.013558815609319848)(Actual:1.0 Prediction:0.0) 
(0:0.9862172999068877)(1:0.013782700093112332)(Actual:1.0 Prediction:0.0) 

qualcuno può spiegare perché questo classificatore non sta funzionando? C'è un passaggio che ho incasinato, o un passaggio che mi manca?

Grazie

risposta

13

mi sembra che il metodo di valutazione sia errato. Dovrebbe essere qualcosa di simile:

public double evaluate(double[] features, svm_model model) 
{ 
    svm_node[] nodes = new svm_node[features.length-1]; 
    for (int i = 1; i < features.length; i++) 
    { 
     svm_node node = new svm_node(); 
     node.index = i; 
     node.value = features[i]; 

     nodes[i-1] = node; 
    } 

    int totalClasses = 2;  
    int[] labels = new int[totalClasses]; 
    svm.svm_get_labels(model,labels); 

    double[] prob_estimates = new double[totalClasses]; 
    double v = svm.svm_predict_probability(model, nodes, prob_estimates); 

    for (int i = 0; i < totalClasses; i++){ 
     System.out.print("(" + labels[i] + ":" + prob_estimates[i] + ")"); 
    } 
    System.out.println("(Actual:" + features[0] + " Prediction:" + v + ")");    

    return v; 
} 
+4

Puoi spiegare qual è l'errore nel codice domanda? Sto avendo problemi nel localizzare l'errore! :( – Daniel

2

Ecco una rielaborazione della esempio di cui sopra, che ho testato utilizzando i dati del seguente codice R: http://cbio.ensmp.fr/~jvert/svn/tutorials/practical/svmbasic/svmbasic_notes.pdf

import libsvm.*; 

public class libsvmTest { 

    public static void main(String [] args) { 

     double[][] xtrain = ... 
     double[][] xtest = ... 
     double[][] ytrain = ... 
     double[][] ytest = ... 

     svm_model m = svmTrain(xtrain,ytrain); 

     double[] ypred = svmPredict(xtest, m); 

     for (int i = 0; i < xtest.length; i++){ 
      System.out.println("(Actual:" + ytest[i][0] + " Prediction:" + ypred[i] + ")"); 
     } 

    } 

    static svm_model svmTrain(double[][] xtrain, double[][] ytrain) { 
     svm_problem prob = new svm_problem(); 
     int recordCount = xtrain.length; 
     int featureCount = xtrain[0].length; 
     prob.y = new double[recordCount]; 
     prob.l = recordCount; 
     prob.x = new svm_node[recordCount][featureCount];  

     for (int i = 0; i < recordCount; i++){    
      double[] features = xtrain[i]; 
      prob.x[i] = new svm_node[features.length]; 
      for (int j = 0; j < features.length; j++){ 
       svm_node node = new svm_node(); 
       node.index = j; 
       node.value = features[j]; 
       prob.x[i][j] = node; 
      }   
      prob.y[i] = ytrain[i][0]; 
     }    

     svm_parameter param = new svm_parameter(); 
     param.probability = 1; 
     param.gamma = 0.5; 
     param.nu = 0.5; 
     param.C = 100; 
     param.svm_type = svm_parameter.C_SVC; 
     param.kernel_type = svm_parameter.LINEAR;  
     param.cache_size = 20000; 
     param.eps = 0.001;  

     svm_model model = svm.svm_train(prob, param); 

     return model; 
    } 

    static double[] svmPredict(double[][] xtest, svm_model model) 
    { 

     double[] yPred = new double[xtest.length]; 

     for(int k = 0; k < xtest.length; k++){ 

     double[] fVector = xtest[k]; 

     svm_node[] nodes = new svm_node[fVector.length]; 
     for (int i = 0; i < fVector.length; i++) 
     { 
      svm_node node = new svm_node(); 
      node.index = i; 
      node.value = fVector[i]; 
      nodes[i] = node; 
     } 

     int totalClasses = 2;  
     int[] labels = new int[totalClasses]; 
     svm.svm_get_labels(model,labels); 

     double[] prob_estimates = new double[totalClasses]; 
     yPred[k] = svm.svm_predict_probability(model, nodes, prob_estimates); 

     } 

     return yPred; 
    } 


} 

Ecco l'output:

(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
+0

Grazie mille per il codice utile, perché hai usato param.probability = 1 ;? E in secondo luogo, sai come si può impostare il peso se si hanno classi sbilanciate? Voglio dire il peso con cui il parametro C ponderato – machinery

+0

Il valore di prob_estimates perda l'ambito quando si chiama svm.svm_predict_probability()? – user1040535

+0

Questo è semplicemente un post per facilitare l'avvio con LIBSVM, da cui l'utente deve determinare cosa funziona in base al problema. Per domande su questo argomento, ti suggerisco di visitare il sito dei maintaners di questo pacchetto: https://www.csie.ntu.edu.tw/~cjlin/libsvm/faq.html#/Q06:_Probability_outputs –