2014-05-19 10 views
6

`Ciao a tutti.Sottoclasse di numpy ndarray non funziona come previsto

Ho trovato che si verifica uno strano comportamento durante la sottoclasse di ndarray.

import numpy as np 

class fooarray(np.ndarray): 
    def __new__(cls, input_array, *args, **kwargs): 
     obj = np.asarray(input_array).view(cls) 
     return obj 

    def __init__(self, *args, **kwargs): 
     return 

    def __array_finalize__(self, obj): 
     return 

a=fooarray(np.random.randn(3,5)) 
b=np.random.randn(3,5) 

a_sum=np.sum(a,axis=0,keepdims=True) 
b_sum=np.sum(b,axis=0, keepdims=True) 

print a_sum.ndim #1 
print b_sum.ndim #2 

Come avete visto, l'argomento keepdims non funziona per il mio sottoclasse fooarray. Ha perso uno dei suoi assi. Come posso evitare questo problema? O più in generale, come posso creare una sottoclasse di numpy ndarray correttamente?

+1

Una delle possibili soluzioni è utilizzare 'a.sum'. – emeth

risposta

4

np.sum può accettare una varietà di oggetti come input: non solo ndarrays, ma anche elenchi, generatori, np.matrix s, ad esempio. Il parametro keepdims ovviamente non ha senso per elenchi o generatori. Inoltre, non è appropriato per le istanze np.matrix, dal momento che np.matrix s hanno sempre 2 dimensioni. Se si guarda alla firma richiesta di np.matrix.sum si vede che il suo metodo sum non ha nessun parametro keepdims:

Definition: np.matrix.sum(self, axis=None, dtype=None, out=None) 

Così alcune sottoclassi di ndarray possono avere sum metodi che non hanno un parametro keepdims. Questa è una sfortunata violazione dello Liskov substitution principle e dell'origine del trabocchetto che hai incontrato.

Ora se si guarda a the source code for np.sum, si vede che è una funzione di delega che tenta di determinare cosa fare in base al tipo del primo argomento.

Se il tipo del primo argomento non è ndarray, il parametro keepdims viene eliminato. Lo fa perché passare il parametro keepdims a np.matrix.sum solleverebbe un'eccezione.

Quindi, a causa np.sum sta cercando di fare la delegazione nel modo più generale, non fare alcuna ipotesi su ciò che gli argomenti di una sottoclasse di ndarray possono assumere, fa cadere il parametro keepdims quando viene passato un fooarray.

La soluzione alternativa è di non utilizzare np.sum, ma chiamare a.sum. Questo è più diretto in ogni caso, dal momento che np.sum è semplicemente una funzione di delega.

import numpy as np 


class fooarray(np.ndarray): 
    def __new__(cls, input_array, *args, **kwargs): 
     obj = np.asarray(input_array, *args, **kwargs).view(cls) 
     return obj 

a = fooarray(np.random.randn(3, 5)) 
b = np.random.randn(3, 5) 

a_sum = a.sum(axis=0, keepdims=True) 
b_sum = np.sum(b, axis=0, keepdims=True) 

print(a_sum.ndim) # 2 
print(b_sum.ndim) # 2 
+0

Completamente rispondo alla mia domanda. Grazie =] –

2

Elaborare un po 'sul commento di @ mskimm, se si dà un'occhiata al relativo parte del codice sorgente di NumPy, core/fromnumeric.py, è chiaro il motivo per cui a.sum(..., keepdims=True) opere, mentre np.sum(a, ..., keepdims=True) non lo fa:

def sum(a, axis=None, dtype=None, out=None, keepdims=False): 
    ... 
    if isinstance(a, _gentype): 
     res = _sum_(a) 
     if out is not None: 
      out[...] = res 
      return out 
     return res 
    elif type(a) is not mu.ndarray: 
     try: 
      sum = a.sum 
     except AttributeError: 
      return _methods._sum(a, axis=axis, dtype=dtype, 
           out=out, keepdims=keepdims) 
     # NOTE: Dropping the keepdims parameters here... 
     return sum(axis=axis, dtype=dtype, out=out) 
    else: 
     return _methods._sum(a, axis=axis, dtype=dtype, 
          out=out, keepdims=keepdims) 
    ... 

Dal momento che hai sottoclassato np.ndarray, type(a) è fooarray, non mu.ndarray, così si finisce in questa linea:

# NOTE: Dropping the keepdims parameters here... 
return sum(axis=axis, dtype=dtype, out=out) 

L'argomento keepdims parola chiave è una funzione di relativa ndarrays, e non è implementata per alcune altre classi di matrice simile come np.matrix o np.ma.masked_array che hanno anche un metodo .sum(), quindi, perché tale parametro attualmente viene scartato per non ndarray S.

Problemi correlati