2014-04-29 9 views
9

Ho notato un comportamento incoerente in numpy.dot quando sono coinvolti gli zeri nan e zeri.Numpy.dot bug? Comportamento NaN incoerente

Qualcuno può dare un senso a questo? è un insetto? È specifico per la funzione dot?

Sto usando numpy v1.6.1, 64 bit, in esecuzione su linux (anche testato su v1.6.2). Ho anche provato su v1.8.0 su Windows 32bit (quindi non posso dire se le differenze sono dovute alla versione o al sistema operativo o all'arco).

from numpy import * 
0*nan, nan*0 
=> (nan, nan) # makes sense 

#1 
a = array([[0]]) 
b = array([[nan]]) 
dot(a, b) 
=> array([[ nan]]) # OK 

#2 -- adding a value to b. the first value in the result is 
#  not expected to be affected. 
a = array([[0]]) 
b = array([[nan, 1]]) 
dot(a, b) 
=> array([[ 0., 0.]]) # EXPECTED : array([[ nan, 0.]]) 
# (also happens in 1.6.2 and 1.8.0) 
# Also, as @Bill noted, a*b works as expected, but not dot(a,b) 

#3 -- changing a from 0 to 1, the first value in the result is 
#  not expected to be affected. 
a = array([[1]]) 
b = array([[nan, 1]]) 
dot(a, b) 
=> array([[ nan, 1.]]) # OK 

#4 -- changing shape of a, changes nan in result 
a = array([[0],[0]]) 
b = array([[ nan, 1.]]) 
dot(a, b) 
=> array([[ 0., 0.], [ 0., 0.]]) # EXPECTED : array([[ nan, 0.], [ nan, 0.]]) 
# (works as expected in 1.6.2 and 1.8.0) 

Caso # 4 sembra funzionare correttamente in v1.6.2 e v1.8.0, ma non Caso # 2 ...


EDIT: @seberg ha sottolineato che questo è un problema Blas ecco le informazioni circa l'installazione Blas ho trovato eseguendo from numpy.distutils.system_info import get_info; get_info('blas_opt'):

1.6.1 linux 64bit 
/usr/lib/python2.7/dist-packages/numpy/distutils/system_info.py:1423: UserWarning: 
    Atlas (http://math-atlas.sourceforge.net/) libraries not found. 
    Directories to search for the libraries can be specified in the 
    numpy/distutils/site.cfg file (section [atlas]) or by setting 
    the ATLAS environment variable. 
    warnings.warn(AtlasNotFoundError.__doc__) 
{'libraries': ['blas'], 'library_dirs': ['/usr/lib'], 'language': 'f77', 'define_macros': [('NO_ATLAS_INFO', 1)]} 

1.8.0 windows 32bit (anaconda) 
c:\Anaconda\Lib\site-packages\numpy\distutils\system_info.py:1534: UserWarning: 
    Blas (http://www.netlib.org/blas/) sources not found. 
    Directories to search for the sources can be specified in the 
    numpy/distutils/site.cfg file (section [blas_src]) or by setting 
    the BLAS_SRC environment variable. 
warnings.warn(BlasSrcNotFoundError.__doc__) 
{} 

(io personalmente non so che cosa fare di esso)

+1

È interessante per il caso 2, 'a * b' fornisce il risultato desiderato ma non' np.dot (a, b) '. – wflynny

+3

Il risultato del punto dipende dalla libreria blas in uso. Per esempio, sto vedendo lo stesso con openblas (ma non con atlante), quindi o questo non è specificato, o un bug nella libreria blas. La moltiplicazione non è correlata in realtà ... – seberg

+2

Hmm, prova 'da numpy.distutils.system_info import get_info; get_info ('blas_opt') ' – seberg

risposta

3

Penso che, come suggerito da Seberg, si tratti di un problema con la libreria BLAS utilizzata. Se si osserva come viene implementato numpy.dot here e here, è possibile trovare una chiamata a cblas_dgemm() per il caso matrice matriciale-volte-matrice.

Questo programma C, che riproduce alcuni dei vostri esempi, fornisce lo stesso risultato quando si utilizza il BLAS "normale" e la risposta corretta quando si utilizza ATLAS.

#include <stdio.h> 
#include <math.h> 

#include "cblas.h" 

void onebyone(double a11, double b11, double expectc11) 
{ 
    enum CBLAS_ORDER order=CblasRowMajor; 
    enum CBLAS_TRANSPOSE transA=CblasNoTrans; 
    enum CBLAS_TRANSPOSE transB=CblasNoTrans; 
    int M=1; 
    int N=1; 
    int K=1; 
    double alpha=1.0; 
    double A[1]={a11}; 
    int lda=1; 
    double B[1]={b11}; 
    int ldb=1; 
    double beta=0.0; 
    double C[1]; 
    int ldc=1; 

    cblas_dgemm(order, transA, transB, 
       M, N, K, 
       alpha,A,lda, 
       B, ldb, 
       beta, C, ldc); 

    printf("dot([ %.18g],[%.18g]) -> [%.18g]; expected [%.18g]\n",a11,b11,C[0],expectc11); 
} 

void onebytwo(double a11, double b11, double b12, 
       double expectc11, double expectc12) 
{ 
    enum CBLAS_ORDER order=CblasRowMajor; 
    enum CBLAS_TRANSPOSE transA=CblasNoTrans; 
    enum CBLAS_TRANSPOSE transB=CblasNoTrans; 
    int M=1; 
    int N=2; 
    int K=1; 
    double alpha=1.0; 
    double A[]={a11}; 
    int lda=1; 
    double B[2]={b11,b12}; 
    int ldb=2; 
    double beta=0.0; 
    double C[2]; 
    int ldc=2; 

    cblas_dgemm(order, transA, transB, 
       M, N, K, 
       alpha,A,lda, 
       B, ldb, 
       beta, C, ldc); 

    printf("dot([ %.18g],[%.18g, %.18g]) -> [%.18g, %.18g]; expected [%.18g, %.18g]\n", 
     a11,b11,b12,C[0],C[1],expectc11,expectc12); 
} 

int 
main() 
{ 
    onebyone(0, 0, 0); 
    onebyone(2, 3, 6); 
    onebyone(NAN, 0, NAN); 
    onebyone(0, NAN, NAN); 
    onebytwo(0, 0,0, 0,0); 
    onebytwo(2, 3,5, 6,10); 
    onebytwo(0, NAN,0, NAN,0); 
    onebytwo(NAN, 0,0, NAN,NAN); 
    return 0; 
} 

Uscita con BLAS:

dot([ 0],[0]) -> [0]; expected [0] 
dot([ 2],[3]) -> [6]; expected [6] 
dot([ nan],[0]) -> [nan]; expected [nan] 
dot([ 0],[nan]) -> [0]; expected [nan] 
dot([ 0],[0, 0]) -> [0, 0]; expected [0, 0] 
dot([ 2],[3, 5]) -> [6, 10]; expected [6, 10] 
dot([ 0],[nan, 0]) -> [0, 0]; expected [nan, 0] 
dot([ nan],[0, 0]) -> [nan, nan]; expected [nan, nan] 

Uscita con ATLAS:

dot([ 0],[0]) -> [0]; expected [0] 
dot([ 2],[3]) -> [6]; expected [6] 
dot([ nan],[0]) -> [nan]; expected [nan] 
dot([ 0],[nan]) -> [nan]; expected [nan] 
dot([ 0],[0, 0]) -> [0, 0]; expected [0, 0] 
dot([ 2],[3, 5]) -> [6, 10]; expected [6, 10] 
dot([ 0],[nan, 0]) -> [nan, 0]; expected [nan, 0] 
dot([ nan],[0, 0]) -> [nan, nan]; expected [nan, nan] 

BLAS sembra aver previsto comportamento quando il primo operando ha un NaN, e il male quando il primo operando è zero e il secondo ha un NaN.

In ogni caso, non penso che questo bug sia nel livello di Numpy; è in BLAS. Sembra che sia possibile risolvere il problema utilizzando invece ATLAS.

Sopra generato su Ubuntu 14.04, utilizzando gcc, BLAS e ATLAS forniti da Ubuntu.

Problemi correlati