bhardwaj08sarthak committed on
Commit e02f85c · verified · 1 Parent(s): 2ec4936

Update level_classifier_tool_2.py

Files changed (1):
  1. level_classifier_tool_2.py +35 -14
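
For context, a minimal usage sketch of the patched backend. It assumes HFEmbeddingBackend is a dataclass whose init fields include model_name and device (both are referenced in __post_init__ in the diff below); the import path, model id, and device string are hypothetical placeholders, not taken from the source.

# Usage sketch (assumptions: HFEmbeddingBackend takes `model_name` and `device`
# as dataclass init fields; the model id below is a hypothetical placeholder).
from level_classifier_tool_2 import HFEmbeddingBackend

backend = HFEmbeddingBackend(
    model_name="sentence-transformers/all-MiniLM-L6-v2",  # hypothetical model id
    device="cpu",
)

# Per the signature in the diff, encode() returns a (tensor, list of strings) pair.
embeddings, kept_texts = backend.encode(["first text", "second text"], batch_size=16)
print(embeddings.shape)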
level_classifier_tool_2.py CHANGED
@@ -21,20 +21,41 @@ class HFEmbeddingBackend:
     MODEL: Any = field(init=False, repr=False)
 
     def __post_init__(self):
-        os.environ.setdefault("SPACES_ZERO_DISABLED", "1")
-        try:
-            torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False)
-        except Exception:
-            pass
-
-        self.TOK = AutoTokenizer.from_pretrained(self.model_name)
-        self.MODEL = AutoModel.from_pretrained(self.model_name, attn_implementation="eager")
-        try:
-            self.MODEL.config.attn_implementation = "eager"
-        except Exception:
-            pass
-
-        self.MODEL.to(self.device).eval()
+        # 1) Try to disable Spaces ZeroGPU monkey-patch proactively
+        os.environ.setdefault("SPACES_ZERO_DISABLED", "1")
+        try:
+            # If Spaces was already imported somewhere, explicitly disable its patch.
+            from spaces import zero as _spaces_zero  # safe import; no-op if not installed
+            if hasattr(_spaces_zero, "disable"):
+                _spaces_zero.disable()
+        except Exception:
+            pass
+
+        # 2) Keep attention off Flash/MemEfficient (avoid vectorized mask paths)
+        try:
+            torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False)
+        except Exception:
+            pass
+
+        # 3) Load tokenizer/model and force eager attention (non-vmap route)
+        self.TOK = AutoTokenizer.from_pretrained(self.model_name)
+        self.MODEL = AutoModel.from_pretrained(self.model_name, attn_implementation="eager")
+
+        # (extra safety) disable any sliding/windowed attention that can trigger the vmap mask path
+        try:
+            if hasattr(self.MODEL.config, "sliding_window"):
+                self.MODEL.config.sliding_window = None
+            if hasattr(self.MODEL, "generation_config") and hasattr(self.MODEL.generation_config, "sliding_window"):
+                self.MODEL.generation_config.sliding_window = None
+        except Exception:
+            pass
+
+        try:
+            self.MODEL.config.attn_implementation = "eager"
+        except Exception:
+            pass
+
+        self.MODEL.to(self.device).eval()
 
     def encode(self, texts: Iterable[str], batch_size: int = 32) -> "Tuple[torch.Tensor, List[str]]":
         """