커뮤니티
Python/TensorFlow
제목:    tf.expand_dims
  750   김윤중

>>> b=np.expand_dims(a,-1); a.shape;a;b.shape;b
(2, 3)
array([[1, 2, 3],
       [4, 5, 6]])


(2, 3, 1)
array([ [ [1],
           [2],
           [3]],
         [ [4],
           [5],
           [6] ] ])

>>> b=np.expand_dims(a,0); a.shape;a;b.shape;b
(2, 3)
array([[1, 2, 3],
       [4, 5, 6]])


(1, 2, 3)
array([ [ [1, 2, 3],
           [4, 5, 6] ] ])

>>> b=np.expand_dims(a,1); a.shape;a;b.shape;b
(2, 3)
array([[1, 2, 3],
       [4, 5, 6]])

 


(2, 1, 3)
array([ [ [1, 2, 3] ],
         [ [4, 5, 6] ] ])

>>> b=np.expand_dims(a,2); a.shape;a;b.shape;b
(2, 3)
array([[1, 2, 3],
       [4, 5, 6]])


(2, 3, 1)
array([ [ [1],
           [2],
           [3] ],
         [ [4],
           [5],
           [6] ] ])

 

>>> print(inputs.eval())
# inputs: shape (2, 3, 4) = (B, T, D)
[[[1 2 3 4]
  [1 2 3 4]
  [1 2 3 4]]

 [[1 2 3 4]
  [1 2 3 4]
  [1 2 3 4]]]

 

     
al  
shape(2,3)
[[11 12 13]
 [21 22 23]]
al1=tf.expand_dims(al,-1) 
shape(2,3,1)
[[[11]
  [12]
  [13]]
 [[21]
  [22]
  [23]]]

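>>> al2=inputs*tf.expand_dims(al,-1); al2; print(al2.eval())
# al expanded to (2, 3, 1) broadcasts against inputs (2, 3, 4): each row inputs[b, t, :] is scaled by al[b, t]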
<tf.Tensor 'mul_5:0' shape=(2, 3, 4) dtype=int32>
[[[11 22 33 44]
  [12 24 36 48]
  [13 26 39 52]]

 [[21 42 63 84]
  [22 44 66 88]
  [23 46 69 92]]]

   
>>> sum=tf.reduce_sum(al2,1); sum; print(sum.eval())
<tf.Tensor 'Sum_2:0' shape=(2, 4) dtype=int32>   # (2, 4) = (B, D): summed over T
[[ 36  72 108 144]
 [ 66 132 198 264]]
>>> sum=tf.reduce_sum(al2,2); sum; print(sum.eval())
<tf.Tensor 'Sum_4:0' shape=(2, 3) dtype=int32>   # (2, 3) = (B, T): summed over D
[[110 120 130]
 [210 220 230]]

>>> al1=tf.expand_dims(al,1); al1; print(al1.eval())
<tf.Tensor 'ExpandDims_28:0' shape=(2, 1, 3) dtype=int32>
[[[11 12 13]]

 [[21 22 23]]]

   
al  shape(2,3)
[[11 12 13]
 [21 22 23]]
al1=tf.expand_dims(al,-1) 
shape(2,3,1)
[[[11]
  [12]
  [13]]
 [[21]
  [22]
  [23]]]
>>> al1=tf.expand_dims(al,0); al1; print(al1.eval())
<tf.Tensor 'ExpandDims_29:0' shape=(1, 2, 3) dtype=int32>
[[[11 12 13]
  [21 22 23]]]

>>> al1=tf.expand_dims(al,1); al1; print(al1.eval())
<tf.Tensor 'ExpandDims_30:0' shape=(2, 1, 3) dtype=int32>
[[[11 12 13]]

 [[21 22 23]]]

>>> al1=tf.expand_dims(al,2); print(al1.eval())   # shape (2, 3, 1)
[[[11]
  [12]
  [13]]

 [[21]
  [22]
  [23]]]