OpenLab-NLP committed
Commit f80b20b · verified · 1 Parent(s): c60c507

Update V2.py

Files changed (1)
  1. V2.py +29 -22
V2.py CHANGED
@@ -157,57 +157,64 @@ class HyperConv1D(layers.Layer):
 
     def call(self, x, training=None):
         x_in = x
-        x_dtype = x.dtype
-
-        # 1) input projection
-        x_proj = self.input_proj(x)
+        x_dtype = x.dtype  # save the input dtype
+
+        # 1) input projection
+        x_proj = self.input_proj(x)  # (B, L, D)
+
         B = tf.shape(x_proj)[0]
         L = tf.shape(x_proj)[1]
         D = self.d_model
         pad = (self.k - 1) // 2
 
-        # ------------------------------
-        # 2) DynamicConv local mixing
-        # ------------------------------
-        kernels = self.kernel_generator(self.dynamic_dense(x_proj))  # (B, L, k)
+        # ------------------------------
+        # 2) DynamicConv local mixing
+        # ------------------------------
+        # cast the generated kernels to x_proj's dtype
+        kernels = self.kernel_generator(self.dynamic_dense(x_proj))
+        kernels = tf.cast(kernels, x_proj.dtype)
         kernels = tf.nn.softmax(kernels, axis=-1)
 
+        # padding & patch extraction
         x_pad = tf.pad(x_proj, [[0,0],[pad,pad],[0,0]])
         x_pad_4d = tf.expand_dims(x_pad, axis=1)  # (B,1,L+k-1,D)
         patches = tf.image.extract_patches(
-            images=x_pad_4d,
-            sizes=[1,1,self.k,1],
-            strides=[1,1,1,1],
-            rates=[1,1,1,1],
-            padding='VALID'
-        )
+            images=x_pad_4d,
+            sizes=[1,1,self.k,1],
+            strides=[1,1,1,1],
+            rates=[1,1,1,1],
+            padding='VALID'
+        )
         patches = tf.reshape(patches, [B, L, self.k, D])
+
+        # match the kernels' shape to the patches
         kernels_exp = tf.expand_dims(kernels, axis=-1)
         out_local = tf.reduce_sum(patches * kernels_exp, axis=2)  # (B,L,D)
         out_local = self.dynamic_proj(out_local)
 
-        # ------------------------------
-        # 3) Hyper scaling
-        # ------------------------------
+        # ------------------------------
+        # 3) Hyper scaling
+        # ------------------------------
         h = self.hyper(x_proj)
         global_z = self.attn_pool(h)
         global_z = tf.nn.softmax(global_z, axis=1)
         global_z = tf.reduce_sum(h * global_z, axis=1)
 
         scale = tf.expand_dims(tf.nn.sigmoid(self.scale_dense(global_z)), 1)
+        scale = tf.cast(scale, x_proj.dtype)  # match dtype
         out_local = out_local * scale
 
-        # ------------------------------
-        # 4) Residual + SiLU + LayerNorm
-        # ------------------------------
+        # ------------------------------
+        # 4) Residual + SiLU + LayerNorm
+        # ------------------------------
         out = x_proj + out_local
         out = tf.nn.silu(out)
         out = self.norm(out)
         out = self.dropout(out, training=training)
 
         return tf.cast(out, x_dtype)
-
+
+
 
 class L2NormLayer(layers.Layer):
     def __init__(self, axis=1, epsilon=1e-10, **kwargs):
         super().__init__(**kwargs)
@@ -223,7 +230,7 @@ class SentenceEncoder(Model):
         self.embed = layers.Embedding(vocab_size, embed_dim)
         self.pos_embed = layers.Embedding(input_dim=max_len, output_dim=embed_dim)
         self.dropout = layers.Dropout(dropout_rate)
-        self.blocks = [HyperConv1D(d_model=embed_dim, k=7, mem_size=128, hyper_dim=256) for _ in range(4)]
+        self.blocks = [HyperConv1D(d_model=embed_dim, k=7, hyper_dim=256) for _ in range(4)]
         self.attn_pool = layers.Dense(1)
         self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype=tf.float32)
         self.latent = layers.Dense(latent_dim, activation=None)
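For reference, the DynamicConv local mixing in step 2 is unchanged in substance on both sides of the hunk: each timestep builds its own softmax-normalised length-k kernel, and tf.image.extract_patches gathers the k-neighbourhood that kernel is applied to. Below is a minimal standalone sketch of that step, with stand-in shapes (B=2, L=5, D=4, k=3) and random tensors in place of the learned projections; dynamic_conv_mix is an illustrative name, not one from V2.py:

```python
import tensorflow as tf

def dynamic_conv_mix(x_proj, kernels, k):
    """One DynamicConv step: every position mixes its k-neighbourhood
    with its own softmax-normalised kernel."""
    B = tf.shape(x_proj)[0]
    L = tf.shape(x_proj)[1]
    D = x_proj.shape[-1]
    pad = (k - 1) // 2
    kernels = tf.nn.softmax(kernels, axis=-1)              # (B, L, k)
    x_pad = tf.pad(x_proj, [[0, 0], [pad, pad], [0, 0]])   # (B, L+k-1, D)
    x_pad_4d = tf.expand_dims(x_pad, axis=1)               # (B, 1, L+k-1, D)
    patches = tf.image.extract_patches(
        images=x_pad_4d,
        sizes=[1, 1, k, 1],
        strides=[1, 1, 1, 1],
        rates=[1, 1, 1, 1],
        padding='VALID')                                   # (B, 1, L, k*D)
    patches = tf.reshape(patches, [B, L, k, D])
    return tf.reduce_sum(patches * tf.expand_dims(kernels, -1), axis=2)

x = tf.random.normal([2, 5, 4])   # stand-in for x_proj
w = tf.random.normal([2, 5, 3])   # stand-in for the generated kernels
print(dynamic_conv_mix(x, w, k=3).shape)   # (2, 5, 4)
```

The reshape to (B, L, k, D) works because extract_patches flattens each patch position-major, channel-minor, which is exactly the layout the weighted sum over axis 2 expects.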
 
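Step 3 (hyper scaling) is also compact enough to read in isolation: per-token hyper features are attention-pooled into one global vector, which drives a per-channel sigmoid gate broadcast over the whole sequence. A sketch with freshly built stand-ins for self.hyper, self.attn_pool, and self.scale_dense (their widths here are arbitrary):

```python
import tensorflow as tf
from tensorflow.keras import layers

D = 4
hyper = layers.Dense(8)        # stand-in for self.hyper (width arbitrary)
attn_pool = layers.Dense(1)    # stand-in for self.attn_pool
scale_dense = layers.Dense(D)  # stand-in for self.scale_dense

x_proj = tf.random.normal([2, 5, D])            # (B, L, D)

h = hyper(x_proj)                               # (B, L, H) hyper features
w = tf.nn.softmax(attn_pool(h), axis=1)         # (B, L, 1) weights over time
global_z = tf.reduce_sum(h * w, axis=1)         # (B, H) pooled summary
scale = tf.nn.sigmoid(scale_dense(global_z))    # (B, D) gate in (0, 1)
scale = tf.expand_dims(scale, 1)                # (B, 1, D) broadcast over L
print((x_proj * scale).shape)                   # (2, 5, 4)
```

Because the gate is computed from a pooled summary, it is global over time but specific to each channel and each example.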
 
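The substance of the commit is the two added casts, tf.cast(kernels, x_proj.dtype) and tf.cast(scale, x_proj.dtype), alongside the existing tf.cast(out, x_dtype) on exit. Under a Keras mixed-precision policy, a sublayer can return float16 while a neighbouring tensor is still float32, and TensorFlow refuses to multiply the two. A sketch of that failure mode and the casting pattern; the mixed_float16 policy is an assumption for illustration, since the diff does not show how V2.py is configured:

```python
import tensorflow as tf
from tensorflow.keras import layers, mixed_precision

# Assumption: the model runs under mixed_float16 (not shown in the diff).
mixed_precision.set_global_policy('mixed_float16')

dense = layers.Dense(4)           # computes in float16 under the policy
x = tf.random.normal([2, 5, 4])   # raw tf op: stays float32
y = dense(x)                      # float16 output

# x * y would raise InvalidArgumentError: Mul requires matching dtypes.
# The commit's pattern: cast auxiliary tensors to the reference dtype
# before mixing, and cast the block's output back on the way out.
out = x * tf.cast(y, x.dtype)
print(out.dtype)                  # float32

mixed_precision.set_global_policy('float32')  # restore the default
```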