|  | import dataclasses | 
					
						
						|  | import json | 
					
						
						|  | import warnings | 
					
						
						|  | from dataclasses import dataclass, MISSING | 
					
						
						|  | from functools import partial | 
					
						
						|  | from typing import Optional, Any | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @partial(dataclass, frozen=True, kw_only=True) | 
					
						
						|  | class JsonComparable: | 
					
						
						|  | def to_json(self) -> str: | 
					
						
						|  | return json.dumps(dataclasses.asdict(self)) | 
					
						
						|  |  | 
					
						
						|  | def __eq__(self, other: "JsonComparable") -> bool: | 
					
						
						|  | return self.to_json() == other.to_json() | 
					
						
						|  |  | 
					
						
						|  | def __hash__(self) -> int: | 
					
						
						|  | return hash(self.to_json()) | 
					
						
						|  |  | 
					
						
						|  | def __lt__(self, other: "JsonComparable") -> bool: | 
					
						
						|  | return self.to_json() < other.to_json() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @partial(dataclass, frozen=True, kw_only=True) | 
					
						
						|  | class SubblockConfig(JsonComparable): | 
					
						
						|  | no_op: bool = False | 
					
						
						|  | replace_with_linear: bool = False | 
					
						
						|  | sparsify: Optional[list[str]] = None | 
					
						
						|  |  | 
					
						
						|  | def __post_init__(self): | 
					
						
						|  | assert not (self.no_op and self.replace_with_linear) | 
					
						
						|  |  | 
					
						
						|  | def _force_setattr(self, name: str, value: Any) -> None: | 
					
						
						|  | """ | 
					
						
						|  | Set an attribute even in frozen dataclasses. | 
					
						
						|  | Use only inside __post_init__! | 
					
						
						|  | """ | 
					
						
						|  | object.__setattr__(self, name, value) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @partial(dataclass, frozen=True, kw_only=True) | 
					
						
						|  | class AttentionConfig(SubblockConfig): | 
					
						
						|  | n_heads_in_group: Optional[int] = None | 
					
						
						|  | window_length: Optional[int] = None | 
					
						
						|  | num_sink_tokens: Optional[int] = None | 
					
						
						|  | use_prefill_window_in_sink_attention: bool = False | 
					
						
						|  | unshifted_sink: bool = False | 
					
						
						|  |  | 
					
						
						|  | def __post_init__(self): | 
					
						
						|  | super().__post_init__() | 
					
						
						|  | assert not (self.no_op and self.replace_with_linear) | 
					
						
						|  |  | 
					
						
						|  | if self.no_op or self.replace_with_linear: | 
					
						
						|  | for irrelevant_att in ["n_heads_in_group", "window_length", "num_sink_tokens"]: | 
					
						
						|  | self._force_setattr(irrelevant_att, None) | 
					
						
						|  | else: | 
					
						
						|  | assert self.n_heads_in_group is not None | 
					
						
						|  |  | 
					
						
						|  | if self.is_sink: | 
					
						
						|  | assert not (self.unshifted_sink and self.use_prefill_window_in_sink_attention), \ | 
					
						
						|  | ("Unshifted sink uses its own kind of explicit masking, not standard window. " | 
					
						
						|  | "Set use_prefill_window_in_sink_attention to False.") | 
					
						
						|  | assert not (self.num_sink_tokens == 0 and not self.unshifted_sink), \ | 
					
						
						|  | "Fake sink attention with 0 sink tokens is only supported with unshifted_sink=True" | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def prefill_sliding_window(self) -> Optional[int]: | 
					
						
						|  | if self.window_length is not None: | 
					
						
						|  | if not self.is_sink or self.use_prefill_window_in_sink_attention: | 
					
						
						|  | return self.window_length | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def is_sliding(self) -> bool: | 
					
						
						|  | return self.prefill_sliding_window is not None | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def is_sink(self) -> bool: | 
					
						
						|  | return ( | 
					
						
						|  | (self.window_length is not None) | 
					
						
						|  | and | 
					
						
						|  | (self.num_sink_tokens is not None) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @partial(dataclass, frozen=True, kw_only=True) | 
					
						
						|  | class FFNConfig(SubblockConfig): | 
					
						
						|  | ffn_mult: Optional[float] = None | 
					
						
						|  |  | 
					
						
						|  | def __post_init__(self): | 
					
						
						|  | super().__post_init__() | 
					
						
						|  | if self.no_op or self.replace_with_linear: | 
					
						
						|  | self._force_setattr("ffn_mult", None) | 
					
						
						|  | else: | 
					
						
						|  | assert self.ffn_mult is not None | 
					
						
						|  | self._force_setattr("ffn_mult", round(self.ffn_mult, 6)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @partial(dataclass, frozen=True, kw_only=True) | 
					
						
						|  | class BlockConfig(JsonComparable): | 
					
						
						|  | attention: AttentionConfig = MISSING | 
					
						
						|  | ffn: FFNConfig = MISSING | 
					
						
						|  |  | 
					
						
						|  | def __post_init__(self): | 
					
						
						|  | """ | 
					
						
						|  | Init subblock dataclasses from dicts | 
					
						
						|  | """ | 
					
						
						|  | for subblock_name in dataclasses.fields(self): | 
					
						
						|  | subblock_config = getattr(self, subblock_name.name) | 
					
						
						|  | if isinstance(subblock_config, dict): | 
					
						
						|  | subblock_fields = [field.name for field in dataclasses.fields(subblock_name.type)] | 
					
						
						|  | unsupported_fields = [field_name for field_name in subblock_config.keys() | 
					
						
						|  | if field_name not in subblock_fields] | 
					
						
						|  | if len(unsupported_fields) > 0: | 
					
						
						|  | warnings.warn(f"Removed unsupported fields {unsupported_fields} from {subblock_name.type.__name__}") | 
					
						
						|  | subblock_config = {k: v for k, v in subblock_config.items() if k not in unsupported_fields} | 
					
						
						|  | object.__setattr__(self, subblock_name.name, | 
					
						
						|  | subblock_name.type(**subblock_config)) | 
					
						
						|  |  |