# 커뮤니티 (Community) > Python
# 제목 (Title): bidirectional rnn
# 1818   김윤중

import tensorflow as tf
import numpy as np

tf.reset_default_graph()

# Input data, shape (batch, max_time, feature_dim).
# max_time (here 4) must be >= the largest entry in X_lengths.
X = np.random.randn(2, 4, 3)
# Valid (unpadded) length of each sequence in the batch, shape (batch,).
X_lengths = [4, 3]

# Forward and backward directions need separate cell objects. In TF >= 1.2 an
# RNN cell is a Layer that builds its variables on first call and reuses them
# afterwards, so passing one instance as both cell_fw and cell_bw would make
# the two directions share a single set of LSTM weights.
cell_fw = tf.nn.rnn_cell.LSTMCell(num_units=3, state_is_tuple=True)
cell_bw = tf.nn.rnn_cell.LSTMCell(num_units=3, state_is_tuple=True)
cell = cell_fw  # keep the module-level name `cell` available for existing readers
outputs, states = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=cell_fw,
    cell_bw=cell_bw,
    dtype=tf.float64,
    inputs=X,
    sequence_length=X_lengths,
)

# `outputs` is (output_fw, output_bw); `states` is (state_fw, state_bw), each an
# LSTMStateTuple(c, h) because the cells were built with state_is_tuple=True.
state_fw, state_bw = states

# Per-timestep features from both directions, concatenated on the last axis:
# (batch, max_time, fw_units + bw_units).
outputs1 = tf.concat(outputs, 2)

# One summary vector per sequence: (batch, fw_units + bw_units).
# Taking outputs1[i, length - 1] is only right for the forward half. The
# backward RNN reads right-to-left, so at t = length - 1 it has seen just the
# last valid token; its whole-sequence output sits at t = 0. The final hidden
# states hold exactly these two vectors, because dynamic_rnn copies the state
# through padded steps past each sequence's length. This needs no per-example
# Python loop and is unaffected by the zero padding that outputs1[:, -1]
# would pick up.
outputs2 = tf.concat([state_fw.h, state_bw.h], axis=1)

# Evaluate the tensors of interest once. tf.contrib.learn.run_n is deprecated
# (tf.contrib does not exist in TF 2); a plain Session does the same job here:
# initialize the LSTM variables, then fetch the dict. `result` stays a
# one-element list of dicts so `result[0][...]` lookups keep working.
fetches = {"outputs": outputs, "outputs1": outputs1, "outputs2": outputs2}
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    result = [sess.run(fetches)]

def p(msg, t):
    """Print a label, the NumPy shape of ``t``, and ``t`` itself.

    The first line is ``"<msg> : <shape> "`` (note the trailing space), where
    the shape comes from ``np.asarray(t)``. The value ``t`` follows on the
    next line.

    Args:
        msg: Label printed before the shape.
        t: Any value ``np.asarray`` accepts (array, list, tuple of arrays, ...).
    """
    shape = np.asarray(t).shape
    header = "{} : {} ".format(msg, shape)
    print(header + "\n" + "{}".format(t))
# Dump the raw inputs, then each fetched tensor from the single session run.
p("data X:", X)
p("seg_lengths :", X_lengths)
for key in ("outputs", "outputs1", "outputs2"):
    p(key, result[0][key])

"""
data X: : (2, 4, 3)
[[[-2.7369186  -0.44013785 -0.67428595]
  [ 1.97298412 -0.59839211  0.06377079]
  [-0.44002746  0.56551496 -0.15333033]
  [ 0.95483825  1.33353933 -1.07743915]]

 [[-0.53162633 -0.50605474  0.82120397]
  [-2.72968905 -1.38788464  1.31444994]
  [-0.86094564  1.64185999 -1.22644838]
  [-0.66589069  0.95510109 -0.3838982 ]]]
seg_lengths : : (2,)
[4, 3]
outputs : (2, 2, 4, 3)
(array([[[-0.03376572,  0.06207141, -0.37092176],
        [ 0.36850623, -0.00673047, -0.0665919 ],
        [ 0.1498923 , -0.069967  , -0.09792052],
        [ 0.39896102, -0.2444771 ,  0.03818484]],

       [[-0.05301782,  0.02612809, -0.18653059],
        [-0.02161815,  0.05489208, -0.65950723],
        [-0.18981102, -0.08420968, -0.14672833],
        [ 0.        ,  0.        ,  0.        ]]]),
array([[[ 0.04857439, -0.02707952, -0.29398762],
        [ 0.43813989, -0.14918428,  0.09946015],
        [ 0.02675867, -0.15365579,  0.06839967],
        [ 0.20468815, -0.15008634,  0.07335179]],

       [[-0.08960411,  0.01435264, -0.35860821],
        [-0.02136613,  0.01834225, -0.41219225],
        [-0.10440344, -0.07238064,  0.11305396],
        [ 0.        ,  0.        ,  0.        ]]]))
outputs1 : (2, 4, 6)
[[[-0.03376572  0.06207141 -0.37092176  0.04857439 -0.02707952 -0.29398762]
  [ 0.36850623 -0.00673047 -0.0665919   0.43813989 -0.14918428  0.09946015]
  [ 0.1498923  -0.069967   -0.09792052  0.02675867 -0.15365579  0.06839967]
  [ 0.39896102 -0.2444771   0.03818484  0.20468815 -0.15008634  0.07335179]]

 [[-0.05301782  0.02612809 -0.18653059 -0.08960411  0.01435264 -0.35860821]
  [-0.02161815  0.05489208 -0.65950723 -0.02136613  0.01834225 -0.41219225]
  [-0.18981102 -0.08420968 -0.14672833 -0.10440344 -0.07238064  0.11305396]
  [ 0.          0.          0.          0.          0.          0.        ]]]
outputs2 : (2, 6)
[[ 0.39896102 -0.2444771   0.03818484  0.20468815 -0.15008634  0.07335179]
 [-0.18981102 -0.08420968 -0.14672833 -0.10440344 -0.07238064  0.11305396]]
Press any key to continue . . .
"""