2016-03-07 10 views
9

Ho quattro tensori multidimensionali v[i,j,k], a[i,s,l], w[j,s,t,m], x[k,t,n] in Numpy, e sto cercando di calcolare il tensore z[l,m,n] data da:riduzione efficiente dei molteplici tensori in Python

z[l,m,n] = sum_{i,j,k,s,t} v[i,j,k] * a[i,s,l] * w[j,s,t,m] * x[k,t,n]

Tutti i tensori sono relativamente piccolo (diciamo meno di 32k elementi in totale), tuttavia ho bisogno di eseguire questo calcolo molte volte, quindi vorrei che la funzione avesse il minor overhead possibile.

ho cercato di attuarlo utilizzando numpy.einsum in questo modo:

z = np.einsum('ijk,isl,jstm,ktn', v, a, w, x) 

ma era molto lento. Ho provato anche la seguente sequenza di chiamate numpy.tensordot:

z = np.zeros((a.shape[-1],w.shape[-1],x.shape[-1])) 
for s in range(a.shape[1]): 
    for t in range(x.shape[1]): 
    res = np.tensordot(v, a[:,s,:], (0,0)) 
    res = np.tensordot(res, w[:,s,t,:], (0,0)) 
    z += np.tensordot(res, x[:,s,:], (0,0)) 

all'interno di un doppio ciclo for per sommare s e t (sia s e t sono molto piccole, così che non è troppo di un problema). Questo ha funzionato molto meglio, ma non è ancora veloce come non vorrei. Penso che questo potrebbe essere dovuto a tutte le operazioni che tensordot deve eseguire internamente prima di prendere il prodotto reale (ad esempio, permutare gli assi).

Mi chiedevo se esiste un modo più efficiente per implementare questo tipo di operazioni in Numpy. Inoltre non mi dispiacerebbe implementare questa parte in Cython, ma non sono sicuro di quale sarebbe l'algoritmo giusto da usare.

+0

Potresti condividere le implementazioni di 'numpy.einsum' e' numpy.tensordot'? – Divakar

+0

@Divakar: sicuro. L'implementazione di 'einsum' è semplicemente' z = np.einsum ('ijk, isl, jstm, ktn', v, a, w, x) '. L'implementazione 'tensordot' è ' res = np.tensordot (v, a, (0,0)) res = np.tensordot (res, w, (0,0)) res = np.tensordot (res , x, (0,0)) ' – Alessandro

+0

Si prega di aggiungere tali implementazioni alla domanda utilizzando il pulsante" modifica "sotto la domanda. – Divakar

risposta

4

Utilizzando np.tensordot in alcune parti, è possibile vettorizzare le cose in questo modo -

# Perform "np.einsum('ijk,isl->jksl', v, a)" 
p1 = np.tensordot(v,a,axes=([0],[0]))   # shape = jksl 

# Perform "np.einsum('jksl,jstm->kltm', p1, w)" 
p2 = np.tensordot(p1,w,axes=([0,2],[0,1])) # shape = kltm 

# Perform "np.einsum('kltm,ktn->lmn', p2, w)" 
z = np.tensordot(p2,x,axes=([0,2],[0,1]))  # shape = lmn 

prova Runtime e verificare uscita -

In [15]: def einsum_based(v, a, w, x): 
    ...:  return np.einsum('ijk,isl,jstm,ktn', v, a, w, x) # (l,m,n) 
    ...: 
    ...: def vectorized_tdot(v, a, w, x): 
    ...:  p1 = np.tensordot(v,a,axes=([0],[0]))  # shape = jksl 
    ...:  p2 = np.tensordot(p1,w,axes=([0,2],[0,1])) # shape = kltm 
    ...:  return np.tensordot(p2,x,axes=([0,2],[0,1])) # shape = lmn 
    ...: 

Caso # 1:

In [16]: # Input params 
    ...: i,j,k,l,m,n = 10,10,10,10,10,10 
    ...: s,t = 3,3 # As problem states : "both s and t are very small". 
    ...: 
    ...: # Input arrays 
    ...: v = np.random.rand(i,j,k) 
    ...: a = np.random.rand(i,s,l) 
    ...: w = np.random.rand(j,s,t,m) 
    ...: x = np.random.rand(k,t,n) 
    ...: 

In [17]: np.allclose(einsum_based(v, a, w, x),vectorized_tdot(v, a, w, x)) 
Out[17]: True 

In [18]: %timeit einsum_based(v,a,w,x) 
10 loops, best of 3: 129 ms per loop 

In [19]: %timeit vectorized_tdot(v,a,w,x) 
1000 loops, best of 3: 397 µs per loop 

Caso # 2 (Dati più grandi):

In [20]: # Input params 
    ...: i,j,k,l,m,n = 15,15,15,15,15,15 
    ...: s,t = 3,3 # As problem states : "both s and t are very small". 
    ...: 
    ...: # Input arrays 
    ...: v = np.random.rand(i,j,k) 
    ...: a = np.random.rand(i,s,l) 
    ...: w = np.random.rand(j,s,t,m) 
    ...: x = np.random.rand(k,t,n) 
    ...: 

In [21]: np.allclose(einsum_based(v, a, w, x),vectorized_tdot(v, a, w, x)) 
Out[21]: True 

In [22]: %timeit einsum_based(v,a,w,x) 
1 loops, best of 3: 1.35 s per loop 

In [23]: %timeit vectorized_tdot(v,a,w,x) 
1000 loops, best of 3: 1.52 ms per loop 
+0

Grazie! Ho provato ad implementare la tua versione vettoriale nel mio programma, ed è ~ 20% più veloce della mia versione che usa il ciclo double for. – Alessandro

+0

@Alessandro Grazie per aver segnalato i miglioramenti delle prestazioni! – Divakar

Problemi correlati