2015-12-18 31 views
9

I tre tensioni, A, B and C in tensorflow, A e B sono entrambi di forma (m, n, r), C è un tensore binario di forma (m, n, 1).Come trasmettere esplicitamente un tensore per abbinare la forma di un altro in tensorflow?

Desidero selezionare elementi da A o B in base al valore di C. Lo strumento è ovvio tf.select, tuttavia, che non ha a trasmettere la semantica, quindi ho bisogno di trasmettere prima esplicitamente C per la stessa forma di A e B.

Questo sarebbe il mio primo tentativo di come fare questo, ma doesn Mi piace mescolare un tensore (tf.shape(A)[2]) nell'elenco delle forme.

import tensorflow as tf 
A = tf.random_normal([20, 100, 10]) 
B = tf.random_normal([20, 100, 10]) 
C = tf.random_normal([20, 100, 1]) 
C = tf.greater_equal(C, tf.zeros_like(C)) 

C = tf.tile(C, [1,1,tf.shape(A)[2]]) 
D = tf.select(C, A, B) 

Qual è l'approccio corretto qui?

+2

Un hack che funziona: posso usare la semantica di trasmissione delle * * moltiplicare e moltiplicare per un tensore quelli in tal modo: 'Expander = tf.ones_like (B)', quindi 'C = Expander * C' – wxs

risposta

9

MODIFICA: In tutte le versioni di TensorFlow da 0.12rc0, il codice nella domanda funziona direttamente. TensorFlow imposterà automaticamente tensori e numeri Python in un argomento tensoriale. La soluzione seguente che utilizza tf.pack() è necessaria solo nelle versioni precedenti a 0.12rc0. Si noti che tf.pack() è stato rinominato in tf.stack() in TensorFlow 1.0.


La soluzione è molto vicina al lavoro. È necessario sostituire la riga:

C = tf.tile(C, [1,1,tf.shape(C)[2]]) 

... con il seguente:

C = tf.tile(C, tf.pack([1, 1, tf.shape(A)[2]])) 

(Il motivo per il problema è che tensorflow non convertire implicitamente un elenco di tensori e letterali Python in un tensore. tf.pack() prende una lista dei tensori, quindi converte ciascuno degli elementi a suo ingresso (1, 1, e tf.shape(C)[2]) per un tensore. Poiché ogni elemento è uno scalare, il risultato sarà un vettore.)

+1

Penso che tu abbia un extra' ['e manchi un') ', ma poi ottengo un errore un po 'criptico quando eseguo * * la sessione tf:' InvalidArgumentError: Inputs all'operazione Select_13 di il tipo Select deve avere le stesse dimensioni e forma. Input 0: dim {size: 20} dim {size: 100} dim {size: 1}! = Input 1: dim {size: 20} dim {size: 100} dim {size: 10} ' – wxs

+0

Buon punto, e Ho aggiornato la risposta - inoltre, l'argomento di 'tf.shape()' avrebbe dovuto essere 'A' (o' B'). Questo funziona per me - quale errore stai vedendo? – mrry

+0

Sì, ora è risolto :) non ho notato il parametro errato su 'tf.shape()'. Grazie! – wxs

2

Qui 'Sa trucco sporco:

import tensorflow as tf 

def broadcast(tensor, shape): 
    return tensor + tf.zeros(shape, dtype=tensor.dtype) 

A = tf.random_normal([20, 100, 10]) 
B = tf.random_normal([20, 100, 10]) 
C = tf.random_normal([20, 100, 1]) 

C = broadcast(C, A.shape) 
D = tf.select(C, A, B) 
0
import tensorflow as tf 

def broadcast(tensor, shape): 
    """Broadcasts ``x`` to have shape ``shape``. 
                    | 
    Uses ``tf.Assert`` statements to ensure that the broadcast is 
    valid. 

    First calculates the number of missing dimensions in 
    ``tf.shape(x)`` and left-pads the shape of ``x`` with that many 
    ones. Then identifies the dimensions of ``x`` that require 
    tiling and tiles those dimensions appropriately. 

    Args: 
     x (tf.Tensor): The tensor to broadcast. 
     shape (Union[tf.TensorShape, tf.Tensor, Sequence[int]]): 
      The shape to broadcast to. 

    Returns: 
     tf.Tensor: ``x``, reshaped and tiled to have shape ``shape``. 

    """ 
    with tf.name_scope('broadcast') as scope: 
     shape_x = tf.shape(x) 
     rank_x = tf.shape(shape0)[0] 
     shape_t = tf.convert_to_tensor(shape, preferred_dtype=tf.int32) 
     rank_t = tf.shape(shape1)[0] 

     with tf.control_dependencies([tf.Assert(
      rank_t >= rank_x, 
      ['len(shape) must be >= tf.rank(x)', shape_x, shape_t], 
      summarize=255 
     )]): 
      missing_dims = tf.ones(tf.stack([rank_t - rank_x], 0), tf.int32) 

     shape_x_ = tf.concat([missing_dims, shape_x], 0) 
     should_tile = tf.equal(shape_x_, 1) 

     with tf.control_dependencies([tf.Assert(
      tf.reduce_all(tf.logical_or(tf.equal(shape_x_, shape_t), should_tile), 
      ['cannot broadcast shapes', shape_x, shape_t], 
      summarize=255 
     )]): 
      multiples = tf.where(should_tile, shape_t, tf.ones_like(shape_t)) 
      out = tf.tile(tf.reshape(x, shape_x_), multiples, name=scope) 

     try: 
      out.set_shape(shape) 
     except: 
      pass 

     return out 

A = tf.random_normal([20, 100, 10]) 
B = tf.random_normal([20, 100, 10]) 
C = tf.random_normal([20, 100, 1]) 

C = broadcast(C, A.shape) 
D = tf.select(C, A, B) 
Problemi correlati