2015-06-04 12 views
6

Ho bisogno di spostare un array 3D da un vettore 3D di spostamento per un algoritmo. A partire da ora sto usando questo metodo (Admitedly molto brutta):Numpy roll in più dimensioni

shiftedArray = np.roll(np.roll(np.roll(arrayToShift, shift[0], axis=0) 
            , shift[1], axis=1), 
          shift[2], axis=2) 

che funziona, ma significa che sto chiamando 3 rotoli! (58% del mio tempo algoritmo è trascorso in questi, secondo il mio profiling)

Dalla documentazione di Numpy.roll:

Parametri:
spostamento: int

asse: int, opzionale

Nessuna menzione del parametro di tipo array ... Quindi non posso avere un rolling multidimensionale?

ho pensato che avrei potuto solo chiamare un tale tipo di funzione (suona come un Numpy cosa da fare):

np.roll(arrayToShift,3DshiftVector,axis=(0,1,2)) 

Forse con una versione piatta del mio schieramento rimodellato? ma come faccio a calcolare il vettore di spostamento? e questo cambiamento è davvero lo stesso?

Sono sorpreso di trovare alcuna soluzione facile per questo, come ho pensato che questo sarebbe una cosa abbastanza comune da fare (ok, non che comune, ma ...)

Quindi, come possiamo --relativamente-- spostare in modo efficiente un array ND da un vettore N-Dimensional?

risposta

3

Penso scipy.ndimage.interpolation.shift farà quello si vuole, dalla docs

turno: float o la sequenza, opzionale

lo spostamento lungo gli assi. Se un galleggiante, lo spostamento è lo stesso per ciascun asse. Se una sequenza, lo spostamento dovrebbe contenere un valore per ciascun asse.

Il che significa che è possibile effettuare le seguenti operazioni,

from scipy.ndimage.interpolation import shift 
import numpy as np 

arrayToShift = np.reshape([i for i in range(27)],(3,3,3)) 

print('Before shift') 
print(arrayToShift) 

shiftVector = (1,2,3) 
shiftedarray = shift(arrayToShift,shift=shiftVector,mode='wrap') 

print('After shift') 
print(shiftedarray) 

che produce,

Before shift 
[[[ 0 1 2] 
    [ 3 4 5] 
    [ 6 7 8]] 

[[ 9 10 11] 
    [12 13 14] 
    [15 16 17]] 

[[18 19 20] 
    [21 22 23] 
    [24 25 26]]] 
After shift 
[[[16 17 16] 
    [13 14 13] 
    [10 11 10]] 

[[ 7 8 7] 
    [ 4 5 4] 
    [ 1 2 1]] 

[[16 17 16] 
    [13 14 13] 
    [10 11 10]]] 
+0

Bello! Immagino di essere ossessionato dal rotolo di Numpy e di cercare su Google quello. Solo quando ho scritto questa domanda SO l'ho indicato come uno "spostamento", che mi avrebbe portato direttamente a questa funzione! Un po 'di modifica, dato che usa spline per turni non arrotondati (ma io uso puramente interi), ho dovuto specificare 'order = 0' (altrimenti questo metodo ha richiesto 100 volte più del mio vecchio metodo!= p) – Jiby

+0

Grande, felice che sia anche efficiente. Recentemente ho trovato shift scipy solo quando ho risposto a un'altra domanda (http://stackoverflow.com/questions/30399534/shift-elements-in-a-numpy-array/30534478#30534478). Fortran ha una funzionalità di spostamento intrinseco quindi immagino che scipy usi questo, motivo per cui può essere molto più veloce di numpy. Indovina l'ordine = 0' evita le spline ... –

+0

Lo spostamento 'scipy' usa l'interpolazione spline. Quindi sta calcolando nuovi valori, non solo spostando quelli esistenti. – hpaulj

0

Credo che roll sia lento perché l'array arrotolato non può essere espresso come una vista dei dati originali come può essere un'operazione di ridimensionamento o di risagoma. Quindi i dati vengono copiati ogni volta. Per sfondo vedi: https://scipy-lectures.github.io/advanced/advanced_numpy/#life-of-ndarray

Che cosa può essere la pena di provare ad arrivare in prima pad vostro array (con la modalità 'wrap') e poi utilizzare le fette sulla matrice imbottita per ottenere shiftedArray: http://docs.scipy.org/doc/numpy/reference/generated/numpy.pad.html

+0

vero, e io avrei dovuto dire che ho iniziato con l'ipotesi che una (singola) copia dell'array fosse ok, ma rotolare tre volte di seguito significherebbe effettivamente 2 copie ridondanti, che voglio evitare. – Jiby

+0

Sì, questo metodo dovrebbe copiare una volta e quindi l'affettatura sarebbe solo una vista di quei dati. – YXD

+0

'np.roll' usa' arange' per costruire la tupla 'indexes', e poi un' take'. Quindi sta facendo l'indicizzazione avanzata più lenta. In teoria è possibile costruire la propria tupla di indicizzazione 3D e applicarla una sola volta. Ma potrebbe essere molto lavoro. – hpaulj

0

take in modalità wrap possono essere usate e penso che non cambia la matrice in memoria .

Ecco un'implementazione utilizzando @ ingressi di EdSmith:

arrayToShift = np.reshape([i for i in range(27)],(3,3,3)) 
shiftVector = np.array((1,2,3)) 
ind = 3-shiftVector  
np.take(np.take(np.take(arrayToShift,range(ind[0],ind[0]+3),axis=0,mode='wrap'),range(ind[1],ind[1]+3),axis=1,mode='wrap'),range(ind[2],ind[2]+3),axis=2,mode='wrap') 

che dà la stessa della OP:

np.roll(np.roll(np.roll(arrayToShift, shift[0], axis=0) , shift[1], axis=1),shift[2], axis=2) 

dà:

array([[[21, 22, 23], 
     [24, 25, 26], 
     [18, 19, 20]], 

     [[ 3, 4, 5], 
     [ 6, 7, 8], 
     [ 0, 1, 2]], 

     [[12, 13, 14], 
     [15, 16, 17], 
     [ 9, 10, 11]]]) 
3

In teoria, utilizzando scipy.ndimage.interpolation.shift come descritto da @ Smith dovrebbe funzionare, ma a causa di un bug aperto (https://github.com/scipy/scipy/issues/1323), non funziona t dare un risultato equivalente a più chiamate di np.roll.


UPDATE: "Multi-roll" capacità è stato aggiunto al numpy.roll in NumPy versione 1.12.0. Ecco un esempio bidimensionale, in cui il primo asse viene arrotolata una posizione e il secondo asse è arrotolato tre posizioni:

In [7]: x = np.arange(20).reshape(4,5) 

In [8]: x 
Out[8]: 
array([[ 0, 1, 2, 3, 4], 
     [ 5, 6, 7, 8, 9], 
     [10, 11, 12, 13, 14], 
     [15, 16, 17, 18, 19]]) 

In [9]: numpy.roll(x, [1, 3], axis=(0, 1)) 
Out[9]: 
array([[17, 18, 19, 15, 16], 
     [ 2, 3, 4, 0, 1], 
     [ 7, 8, 9, 5, 6], 
     [12, 13, 14, 10, 11]]) 

Questo rende il codice sotto obsoleto. Lascerò lì per i posteri.


Il codice seguente definisce una funzione che chiamo multiroll che fa quello che si vuole. Ecco un esempio in cui viene applicata ad una matrice di forma (500, 500, 500):

In [64]: x = np.random.randn(500, 500, 500) 

In [65]: shift = [10, 15, 20] 

utilizzare le chiamate multiple a np.roll per generare il risultato atteso:

In [66]: yroll3 = np.roll(np.roll(np.roll(x, shift[0], axis=0), shift[1], axis=1), shift[2], axis=2) 

Generare la matrice spostato utilizzando multiroll:

In [67]: ymulti = multiroll(x, shift) 

Verificare che abbiamo ottenuto il risultato atteso:

In [68]: np.all(yroll3 == ymulti) 
Out[68]: True 

Per una matrice di queste dimensioni, facendo tre chiamate al np.roll è quasi tre volte più lento di una chiamata a multiroll:

In [69]: %timeit yroll3 = np.roll(np.roll(np.roll(x, shift[0], axis=0), shift[1], axis=1), shift[2], axis=2) 
1 loops, best of 3: 1.34 s per loop 

In [70]: %timeit ymulti = multiroll(x, shift) 
1 loops, best of 3: 474 ms per loop 

Ecco la definizione di multiroll:

from itertools import product 
import numpy as np 


def multiroll(x, shift, axis=None): 
    """Roll an array along each axis. 

    Parameters 
    ---------- 
    x : array_like 
     Array to be rolled. 
    shift : sequence of int 
     Number of indices by which to shift each axis. 
    axis : sequence of int, optional 
     The axes to be rolled. If not given, all axes is assumed, and 
     len(shift) must equal the number of dimensions of x. 

    Returns 
    ------- 
    y : numpy array, with the same type and size as x 
     The rolled array. 

    Notes 
    ----- 
    The length of x along each axis must be positive. The function 
    does not handle arrays that have axes with length 0. 

    See Also 
    -------- 
    numpy.roll 

    Example 
    ------- 
    Here's a two-dimensional array: 

    >>> x = np.arange(20).reshape(4,5) 
    >>> x 
    array([[ 0, 1, 2, 3, 4], 
      [ 5, 6, 7, 8, 9], 
      [10, 11, 12, 13, 14], 
      [15, 16, 17, 18, 19]]) 

    Roll the first axis one step and the second axis three steps: 

    >>> multiroll(x, [1, 3]) 
    array([[17, 18, 19, 15, 16], 
      [ 2, 3, 4, 0, 1], 
      [ 7, 8, 9, 5, 6], 
      [12, 13, 14, 10, 11]]) 

    That's equivalent to: 

    >>> np.roll(np.roll(x, 1, axis=0), 3, axis=1) 
    array([[17, 18, 19, 15, 16], 
      [ 2, 3, 4, 0, 1], 
      [ 7, 8, 9, 5, 6], 
      [12, 13, 14, 10, 11]]) 

    Not all the axes must be rolled. The following uses 
    the `axis` argument to roll just the second axis: 

    >>> multiroll(x, [2], axis=[1]) 
    array([[ 3, 4, 0, 1, 2], 
      [ 8, 9, 5, 6, 7], 
      [13, 14, 10, 11, 12], 
      [18, 19, 15, 16, 17]]) 

    which is equivalent to: 

    >>> np.roll(x, 2, axis=1) 
    array([[ 3, 4, 0, 1, 2], 
      [ 8, 9, 5, 6, 7], 
      [13, 14, 10, 11, 12], 
      [18, 19, 15, 16, 17]]) 

    """ 
    x = np.asarray(x) 
    if axis is None: 
     if len(shift) != x.ndim: 
      raise ValueError("The array has %d axes, but len(shift) is only " 
          "%d. When 'axis' is not given, a shift must be " 
          "provided for all axes." % (x.ndim, len(shift))) 
     axis = range(x.ndim) 
    else: 
     # axis does not have to contain all the axes. Here we append the 
     # missing axes to axis, and for each missing axis, append 0 to shift. 
     missing_axes = set(range(x.ndim)) - set(axis) 
     num_missing = len(missing_axes) 
     axis = tuple(axis) + tuple(missing_axes) 
     shift = tuple(shift) + (0,)*num_missing 

    # Use mod to convert all shifts to be values between 0 and the length 
    # of the corresponding axis. 
    shift = [s % x.shape[ax] for s, ax in zip(shift, axis)] 

    # Reorder the values in shift to correspond to axes 0, 1, ..., x.ndim-1. 
    shift = np.take(shift, np.argsort(axis)) 

    # Create the output array, and copy the shifted blocks from x to y. 
    y = np.empty_like(x) 
    src_slices = [(slice(n-shft, n), slice(0, n-shft)) 
        for shft, n in zip(shift, x.shape)] 
    dst_slices = [(slice(0, shft), slice(shft, n)) 
        for shft, n in zip(shift, x.shape)] 
    src_blks = product(*src_slices) 
    dst_blks = product(*dst_slices) 
    for src_blk, dst_blk in zip(src_blks, dst_blks): 
     y[dst_blk] = x[src_blk] 

    return y 
Problemi correlati