import sys

import numpy as np
from transformers import AutoModelForTokenClassification, Pipeline

from eval import align_tokens_labels_from_wordids, retrieve_predictions
from reading import read_dataset
from utils import read_config


def write_sentences_to_format(sentences: list[str], filename: str) -> str:
    """Serialize sentences into a CoNLL-style ``.tok`` string, one word per line.

    Each word row has the shape ``index<TAB>word<TAB>_...<TAB>Seg=<label>``.
    The label is ``B-seg`` for the first word of a sentence OR any word that
    starts with an uppercase letter (a crude segment-start heuristic), and
    ``O`` otherwise.

    Args:
        sentences: Sentences to serialize. A bare string is tolerated (a
            warning is printed to stderr) and treated as a one-element list.
        filename: If truthy, the formatted text is also written to this path.

    Returns:
        The formatted text; the empty string for empty input.
    """
    if not sentences:
        return ""
    if isinstance(sentences, str):
        sentences = [sentences]
        sys.stderr.write("Warning: only one sentence provided as a string instead of a list of sentences.\n")
    # Build line-by-line and join once; the original used quadratic `full += line`.
    lines = ["# newdoc_id = GUM_academic_discrimination\n"]
    for sentence in sentences:
        words = sentence.strip().split()
        for i, word in enumerate(words, start=1):
            # First word of a sentence, or a capitalized word, opens a segment.
            seg_label = "B-seg" if i == 1 or word[0].isupper() else "O"
            lines.append(f"{i}\t{word}\t_\t_\t_\t_\t_\t_\t_\tSeg={seg_label}\n")
    full = "".join(lines)
    if filename:
        with open(filename, "w", encoding="utf-8") as f:
            f.write(full)
    return full


class DiscoursePipeline(Pipeline):
    """Hugging Face pipeline that segments raw text into discourse units (EDUs).

    ``preprocess`` converts the input text to the project's ``.tok`` format and
    reads it into a dataset, ``_forward`` runs ``retrieve_predictions``, and
    ``postprocess`` aligns token-level predictions back to words and splits the
    original text into EDUs.
    """

    def __init__(self, model_id, tokenizer, output_folder="./pipe_out", sat_model: str = "sat-3l", **kwargs):
        """Load the token-classification model for *model_id* and configure the run.

        Args:
            model_id: Checkpoint name/path passed to ``from_pretrained`` and to
                ``retrieve_predictions``.
            tokenizer: Tokenizer matching the checkpoint.
            output_folder: Working directory for intermediate pipeline files.
            sat_model: Sentence-splitter (SaT) model name used in the config.
            **kwargs: Forwarded to ``transformers.Pipeline.__init__``.
        """
        auto_model = AutoModelForTokenClassification.from_pretrained(model_id)
        super().__init__(model=auto_model, tokenizer=tokenizer, **kwargs)
        # Project-level config consumed by read_dataset / retrieve_predictions.
        # NOTE(review): key spelling "sent_spliter" is intentional — it must
        # match what the project readers expect.
        self.config = {
            "model_checkpoint": model_id,
            "sent_spliter": "sat",
            "task": "seg",
            "type": "tok",
            "trace": False,
            "report_to": "none",
            "sat_model": sat_model,
            "tok_config": {
                "padding": "max_length",
                "truncation": True,
                "max_length": 512,
            },
        }
        # Fix: the original did `self.model = model_id`, clobbering the loaded
        # model that Pipeline.__init__ stored in `self.model` with a plain
        # string. Keep the checkpoint id in its own attribute instead so
        # transformers internals still see the real model object.
        self.model_checkpoint = model_id
        self.output_folder = output_folder

    def _sanitize_parameters(self, **kwargs):
        # Hook for optional call-time parameters; none are supported yet.
        preprocess_params = {}
        forward_params = {}
        postprocess_params = {}
        return preprocess_params, forward_params, postprocess_params

    def preprocess(self, text: str):
        """Convert raw *text* (one sentence per line) into a project dataset.

        Side effect: stores the raw input on ``self.original_text`` so that
        ``postprocess`` can re-align predictions with the original words.
        """
        self.original_text = text
        formatted_text = write_sentences_to_format(text.split("\n"), filename=None)
        dataset, _ = read_dataset(
            formatted_text,
            output_path=self.output_folder,
            config=self.config,
            add_lang_token=True,
            add_frame_token=True,
        )
        return {"dataset": dataset}

    def _forward(self, inputs):
        """Run inference over the preprocessed dataset."""
        dataset = inputs["dataset"]
        preds_from_model, label_ids, _ = retrieve_predictions(
            self.model_checkpoint, dataset, self.output_folder, self.tokenizer, self.config
        )
        return {"preds": preds_from_model, "labels": label_ids, "dataset": dataset}

    def postprocess(self, outputs):
        """Turn raw logits into a list of EDU strings for the original text."""
        preds = np.argmax(outputs["preds"], axis=-1)
        predictions = align_tokens_labels_from_wordids(preds, outputs["dataset"], self.tokenizer)
        return text_to_edus(self.original_text, predictions)


def _extract_column(formatted_text: str, column: int) -> str:
    """Join one tab-separated column of a ``.tok``/``.conllu`` string.

    Comment lines (starting with ``#``) and lines without at least two
    tab-separated fields are skipped; the selected fields are joined with
    single spaces.
    """
    parts = []
    for line in formatted_text.split("\n"):
        if line.startswith("#"):
            continue
        fields = line.split("\t")
        if len(fields) > 1:
            parts.append(fields[column])
    return " ".join(parts)


def get_plain_text_from_format(formatted_text: str) -> str:
    """Return the plain text (word column) of a ``.conllu``/``.tok`` string."""
    return _extract_column(formatted_text, 1)


def get_preds_from_format(formatted_text: str) -> str:
    """Return the label column (last field) of a ``.conllu``/``.tok`` string."""
    return _extract_column(formatted_text, -1)


def text_to_edus(text: str, labels: list[str]) -> list[str]:
    """Split whitespace-tokenized *text* into EDUs from a BIO label sequence.

    Args:
        text: Raw text; tokens are whitespace-separated.
        labels: One label per token, e.g. ``Seg=B-seg``/``Seg=O`` or
            ``Conn=B-conn``/``Conn=O``.

    Returns:
        The list of EDUs, each a space-joined substring of *text*.

    Raises:
        ValueError: If the number of tokens and labels differ.
    """
    words = text.strip().split()
    if len(words) != len(labels):
        raise ValueError(f"Longueur mismatch: {len(words)} mots vs {len(labels)} labels")
    edus: list[str] = []
    current_edu: list[str] = []
    for word, label in zip(words, labels):
        if label in ("Conn=B-conn", "Seg=B-seg"):
            # A B- label closes the open EDU (if any) and starts a new one.
            if current_edu:
                edus.append(" ".join(current_edu))
            current_edu = [word]
        else:
            # Fix: the original only appended on explicit O labels, silently
            # DROPPING words with any other label (e.g. I- continuations).
            # Every non-B label now extends the current EDU.
            current_edu.append(word)
    if current_edu:
        edus.append(" ".join(current_edu))
    return edus