Update V2.py

--- a/V2.py
+++ b/V2.py
@@ -157,57 +157,64 @@ class HyperConv1D(layers.Layer):
 
     def call(self, x, training=None):
         x_in = x
-        x_dtype = x.dtype
-
-        # 1) input projection
-        x_proj = self.input_proj(x)
+        x_dtype = x.dtype  # store the input dtype
 
+        # 1) input projection
+        x_proj = self.input_proj(x)  # (B, L, D)
+
         B = tf.shape(x_proj)[0]
         L = tf.shape(x_proj)[1]
         D = self.d_model
         pad = (self.k - 1) // 2
 
+        # ------------------------------
+        # 2) DynamicConv local mixing
+        # ------------------------------
+        # create the kernels, then match x_proj's dtype
+        kernels = self.kernel_generator(self.dynamic_dense(x_proj))
+        kernels = tf.cast(kernels, x_proj.dtype)
         kernels = tf.nn.softmax(kernels, axis=-1)
 
+        # padding & patch extraction
         x_pad = tf.pad(x_proj, [[0,0],[pad,pad],[0,0]])
         x_pad_4d = tf.expand_dims(x_pad, axis=1) # (B,1,L+k-1,D)
         patches = tf.image.extract_patches(
+            images=x_pad_4d,
+            sizes=[1,1,self.k,1],
+            strides=[1,1,1,1],
+            rates=[1,1,1,1],
+            padding='VALID'
+        )
         patches = tf.reshape(patches, [B, L, self.k, D])
+
+        # match the kernels' shape
         kernels_exp = tf.expand_dims(kernels, axis=-1)
         out_local = tf.reduce_sum(patches * kernels_exp, axis=2) # (B,L,D)
         out_local = self.dynamic_proj(out_local)
 
+        # ------------------------------
+        # 3) Hyper scaling
+        # ------------------------------
         h = self.hyper(x_proj)
         global_z = self.attn_pool(h)
         global_z = tf.nn.softmax(global_z, axis=1)
         global_z = tf.reduce_sum(h * global_z, axis=1)
 
         scale = tf.expand_dims(tf.nn.sigmoid(self.scale_dense(global_z)), 1)
+        scale = tf.cast(scale, x_proj.dtype)  # match dtype
         out_local = out_local * scale
 
+        # ------------------------------
+        # 4) Residual + SiLU + LayerNorm
+        # ------------------------------
         out = x_proj + out_local
         out = tf.nn.silu(out)
         out = self.norm(out)
         out = self.dropout(out, training=training)
 
         return tf.cast(out, x_dtype)
+
+
 class L2NormLayer(layers.Layer):
     def __init__(self, axis=1, epsilon=1e-10, **kwargs):
         super().__init__(**kwargs)
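
Note on what this hunk computes: the DynamicConv section derives one k-wide weight vector per position, extracts the matching k-wide window of x_proj via tf.pad + tf.image.extract_patches, and takes the weighted sum. The snippet below is a minimal standalone sketch of that local-mixing step only; it is not part of V2.py, and the function name, toy shapes, and random inputs are purely illustrative.

import tensorflow as tf

def demo_dynamic_local_mix(x_proj, kernels, k):
    # Pad the sequence, pull out a k-wide window at every position with
    # tf.image.extract_patches, then combine the window with per-position
    # softmax weights -- the same local mixing the hunk performs.
    B = tf.shape(x_proj)[0]
    L = tf.shape(x_proj)[1]
    D = tf.shape(x_proj)[2]
    pad = (k - 1) // 2

    kernels = tf.nn.softmax(kernels, axis=-1)                 # (B, L, k)
    x_pad = tf.pad(x_proj, [[0, 0], [pad, pad], [0, 0]])      # (B, L+k-1, D)
    x_pad_4d = tf.expand_dims(x_pad, axis=1)                  # (B, 1, L+k-1, D)
    patches = tf.image.extract_patches(
        images=x_pad_4d,
        sizes=[1, 1, k, 1],
        strides=[1, 1, 1, 1],
        rates=[1, 1, 1, 1],
        padding='VALID')                                      # (B, 1, L, k*D)
    patches = tf.reshape(patches, [B, L, k, D])
    return tf.reduce_sum(patches * tf.expand_dims(kernels, -1), axis=2)  # (B, L, D)

# Toy shapes: the output keeps the (B, L, D) shape of the projected input.
x = tf.random.normal([2, 10, 8])
w = tf.random.normal([2, 10, 7])   # unnormalized kernels, one k-vector per position
print(demo_dynamic_local_mix(x, w, k=7).shape)   # (2, 10, 8)

Extracting patches this way avoids materializing an explicit (L, L) mixing matrix; each position only ever sees its k neighbours.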
@@ -223,7 +230,7 @@ class SentenceEncoder(Model):
         self.embed = layers.Embedding(vocab_size, embed_dim)
         self.pos_embed = layers.Embedding(input_dim=max_len, output_dim=embed_dim)
         self.dropout = layers.Dropout(dropout_rate)
-        self.blocks = [HyperConv1D(d_model=embed_dim, k=7,
+        self.blocks = [HyperConv1D(d_model=embed_dim, k=7, hyper_dim=256) for _ in range(4)]
         self.attn_pool = layers.Dense(1)
         self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype=tf.float32)
         self.latent = layers.Dense(latent_dim, activation=None)
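
The tf.cast calls added in the first hunk (kernels and scale are cast to x_proj.dtype, and the block still returns tf.cast(out, x_dtype)) read like guards for running HyperConv1D under a Keras mixed-precision policy; that motivation is an inference from the diff, not something it states. The sketch below reproduces the pattern in isolation, with illustrative layer names and a scale branch deliberately kept in float32.

import tensorflow as tf
from tensorflow.keras import layers, mixed_precision

mixed_precision.set_global_policy('mixed_float16')

proj = layers.Dense(8)                           # emits float16 under the policy
scale_dense = layers.Dense(8, dtype='float32')   # branch deliberately kept in float32

x = tf.random.normal([2, 10, 4])                 # float32 input
x_dtype = x.dtype                                # remember the caller's dtype
x_proj = proj(x)                                 # float16
scale = tf.nn.sigmoid(scale_dense(tf.reduce_mean(x_proj, axis=1)))   # float32
scale = tf.cast(scale, x_proj.dtype)             # align dtypes, as the commit does
out = x_proj * tf.expand_dims(scale, axis=1)     # float16 * float16: no dtype error
out = tf.cast(out, x_dtype)                      # hand the original dtype back
print(out.dtype)                                 # float32

mixed_precision.set_global_policy('float32')     # reset the global policy

Without the explicit cast, the float16 * float32 multiply raises a dtype-mismatch error in TensorFlow, which is exactly what aligning everything to x_proj.dtype avoids.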