Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
"""Print the Hugging Face Hub usage snippet for a model library.

The TypeScript source of the Hub's ``Libraries.ts`` is embedded in this
script as the string ``file``; no external file is read.  ``read_file``
extracts the template literal of one snippet constant from that source and
substitutes the model id for ``${model.id}``.
"""
import re

# Placeholder; rebound to the embedded TypeScript source further down.
# read_file() looks the global up at call time, so it sees the later binding.
file = None
| def read_file(library: str, model_name: str) -> str: | |
| text = file | |
| match = re.search('const ' + library + '.*', text, re.DOTALL).group() | |
| if match: | |
| text = match[match.index('`') + 1:match.index('`;')].replace('${model.id}', model_name) | |
| return text | |
# Copy of the Hub's TypeScript "Libraries.ts"; read_file() scans this text.
# It is a raw string so TypeScript escapes such as "\n" stay as written
# instead of becoming real newlines inside the embedded source.
file = r"""
import type { ModelData } from "./Types";
/**
* Add your new library here.
*/
export enum ModelLibrary {
"adapter-transformers" = "Adapter Transformers",
"allennlp" = "allenNLP",
"asteroid" = "Asteroid",
"diffusers" = "Diffusers",
"espnet" = "ESPnet",
"fairseq" = "Fairseq",
"flair" = "Flair",
"keras" = "Keras",
"nemo" = "NeMo",
"pyannote-audio" = "pyannote.audio",
"sentence-transformers" = "Sentence Transformers",
"sklearn" = "Scikit-learn",
"spacy" = "spaCy",
"speechbrain" = "speechbrain",
"tensorflowtts" = "TensorFlowTTS",
"timm" = "Timm",
"fastai" = "fastai",
"transformers" = "Transformers",
"stanza" = "Stanza",
"fasttext" = "fastText",
"stable-baselines3" = "Stable-Baselines3",
"ml-agents" = "ML-Agents",
}
export const ALL_MODEL_LIBRARY_KEYS = Object.keys(ModelLibrary) as (keyof typeof ModelLibrary)[];
/**
* Elements configurable by a model library.
*/
export interface LibraryUiElement {
/**
* Name displayed on the main
* call-to-action button on the model page.
*/
btnLabel: string;
/**
* Repo name
*/
repoName: string;
/**
* URL to library's repo
*/
repoUrl: string;
/**
* Code snippet displayed on model page
*/
snippet: (model: ModelData) => string;
}
function nameWithoutNamespace(modelId: string): string {
const splitted = modelId.split("/");
return splitted.length === 1 ? splitted[0] : splitted[1];
}
//#region snippets
const adapter_transformers = (model: ModelData) =>
`from transformers import ${model.config?.adapter_transformers?.model_class}
model = ${model.config?.adapter_transformers?.model_class}.from_pretrained("${model.config?.adapter_transformers?.model_name}")
model.load_adapter("${model.id}", source="hf")`;
const allennlpUnknown = (model: ModelData) =>
`import allennlp_models
from allennlp.predictors.predictor import Predictor
predictor = Predictor.from_path("hf://${model.id}")`;
const allennlpQuestionAnswering = (model: ModelData) =>
`import allennlp_models
from allennlp.predictors.predictor import Predictor
predictor = Predictor.from_path("hf://${model.id}")
predictor_input = {"passage": "My name is Wolfgang and I live in Berlin", "question": "Where do I live?"}
predictions = predictor.predict_json(predictor_input)`;
const allennlp = (model: ModelData) => {
if (model.tags?.includes("question-answering")) {
return allennlpQuestionAnswering(model);
}
return allennlpUnknown(model);
};
const asteroid = (model: ModelData) =>
`from asteroid.models import BaseModel
model = BaseModel.from_pretrained("${model.id}")`;
const diffusers = (model: ModelData) =>
`from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""})`;
const espnetTTS = (model: ModelData) =>
`from espnet2.bin.tts_inference import Text2Speech
model = Text2Speech.from_pretrained("${model.id}")
speech, *_ = model("text to generate speech from")`;
const espnetASR = (model: ModelData) =>
`from espnet2.bin.asr_inference import Speech2Text
model = Speech2Text.from_pretrained(
"${model.id}"
)
speech, rate = soundfile.read("speech.wav")
text, *_ = model(speech)`;
const espnetUnknown = () =>
`unknown model type (must be text-to-speech or automatic-speech-recognition)`;
const espnet = (model: ModelData) => {
if (model.tags?.includes("text-to-speech")) {
return espnetTTS(model);
} else if (model.tags?.includes("automatic-speech-recognition")) {
return espnetASR(model);
}
return espnetUnknown();
};
const fairseq = (model: ModelData) =>
`from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
"${model.id}"
)`;
const flair = (model: ModelData) =>
`from flair.models import SequenceTagger
tagger = SequenceTagger.load("${model.id}")`;
const keras = (model: ModelData) =>
`from huggingface_hub import from_pretrained_keras
model = from_pretrained_keras("${model.id}")
`;
const pyannote_audio_pipeline = (model: ModelData) =>
`from pyannote.audio import Pipeline
pipeline = Pipeline.from_pretrained("${model.id}")
# inference on the whole file
pipeline("file.wav")
# inference on an excerpt
from pyannote.core import Segment
excerpt = Segment(start=2.0, end=5.0)
from pyannote.audio import Audio
waveform, sample_rate = Audio().crop("file.wav", excerpt)
pipeline({"waveform": waveform, "sample_rate": sample_rate})`;
const pyannote_audio_model = (model: ModelData) =>
`from pyannote.audio import Model, Inference
model = Model.from_pretrained("${model.id}")
inference = Inference(model)
# inference on the whole file
inference("file.wav")
# inference on an excerpt
from pyannote.core import Segment
excerpt = Segment(start=2.0, end=5.0)
inference.crop("file.wav", excerpt)`;
const pyannote_audio = (model: ModelData) => {
if (model.tags?.includes("pyannote-audio-pipeline")) {
return pyannote_audio_pipeline(model);
}
return pyannote_audio_model(model);
};
const tensorflowttsTextToMel = (model: ModelData) =>
`from tensorflow_tts.inference import AutoProcessor, TFAutoModel
processor = AutoProcessor.from_pretrained("${model.id}")
model = TFAutoModel.from_pretrained("${model.id}")
`;
const tensorflowttsMelToWav = (model: ModelData) =>
`from tensorflow_tts.inference import TFAutoModel
model = TFAutoModel.from_pretrained("${model.id}")
audios = model.inference(mels)
`;
const tensorflowttsUnknown = (model: ModelData) =>
`from tensorflow_tts.inference import TFAutoModel
model = TFAutoModel.from_pretrained("${model.id}")
`;
const tensorflowtts = (model: ModelData) => {
if (model.tags?.includes("text-to-mel")) {
return tensorflowttsTextToMel(model);
} else if (model.tags?.includes("mel-to-wav")) {
return tensorflowttsMelToWav(model);
}
return tensorflowttsUnknown(model);
};
const timm = (model: ModelData) =>
`import timm
model = timm.create_model("hf_hub:${model.id}", pretrained=True)`;
const sklearn = (model: ModelData) =>
`from huggingface_hub import hf_hub_download
import joblib
model = joblib.load(
hf_hub_download("${model.id}", "sklearn_model.joblib")
)`;
const fastai = (model: ModelData) =>
`from huggingface_hub import from_pretrained_fastai
learn = from_pretrained_fastai("${model.id}")`;
const sentenceTransformers = (model: ModelData) =>
`from sentence_transformers import SentenceTransformer
model = SentenceTransformer("${model.id}")`;
const spacy = (model: ModelData) =>
`!pip install https://huggingface.co/${model.id}/resolve/main/${nameWithoutNamespace(model.id)}-any-py3-none-any.whl
# Using spacy.load().
import spacy
nlp = spacy.load("${nameWithoutNamespace(model.id)}")
# Importing as module.
import ${nameWithoutNamespace(model.id)}
nlp = ${nameWithoutNamespace(model.id)}.load()`;
const stanza = (model: ModelData) =>
`import stanza
stanza.download("${nameWithoutNamespace(model.id).replace("stanza-", "")}")
nlp = stanza.Pipeline("${nameWithoutNamespace(model.id).replace("stanza-", "")}")`;
const speechBrainMethod = (speechbrainInterface: string) => {
switch (speechbrainInterface) {
case "EncoderClassifier":
return "classify_file";
case "EncoderDecoderASR":
case "EncoderASR":
return "transcribe_file";
case "SpectralMaskEnhancement":
return "enhance_file";
case "SepformerSeparation":
return "separate_file";
default:
return undefined;
}
};
const speechbrain = (model: ModelData) => {
const speechbrainInterface = model.config?.speechbrain?.interface;
if (speechbrainInterface === undefined) {
return `# interface not specified in config.json`;
}
const speechbrainMethod = speechBrainMethod(speechbrainInterface);
if (speechbrainMethod === undefined) {
return `# interface in config.json invalid`;
}
return `from speechbrain.pretrained import ${speechbrainInterface}
model = ${speechbrainInterface}.from_hparams(
"${model.id}"
)
model.${speechbrainMethod}("file.wav")`;
};
const transformers = (model: ModelData) => {
const info = model.transformersInfo;
if (!info) {
return `# ⚠️ Type of model unknown`;
}
if (info.processor) {
const varName = info.processor === "AutoTokenizer" ? "tokenizer"
: info.processor === "AutoFeatureExtractor" ? "extractor"
: "processor"
;
return [
`from transformers import ${info.processor}, ${info.auto_model}`,
"",
`${varName} = ${info.processor}.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""})`,
"",
`model = ${info.auto_model}.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""})`,
].join("\n");
} else {
return [
`from transformers import ${info.auto_model}`,
"",
`model = ${info.auto_model}.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""})`,
].join("\n");
}
};
const fasttext = (model: ModelData) =>
`from huggingface_hub import hf_hub_download
import fasttext
model = fasttext.load_model(hf_hub_download("${model.id}", "model.bin"))`;
const stableBaselines3 = (model: ModelData) =>
`from huggingface_sb3 import load_from_hub
checkpoint = load_from_hub(
repo_id="${model.id}",
filename="{MODEL FILENAME}.zip",
)`;
const nemoDomainResolver = (domain: string, model: ModelData): string | undefined => {
const modelName = `${nameWithoutNamespace(model.id)}.nemo`;
switch (domain) {
case "ASR":
return `import nemo.collections.asr as nemo_asr
asr_model = nemo_asr.models.ASRModel.from_pretrained("${model.id}")
transcriptions = asr_model.transcribe(["file.wav"])`;
default:
return undefined;
}
};
const mlAgents = (model: ModelData) =>
`mlagents-load-from-hf --repo-id="${model.id}" --local-dir="./downloads"`;
const nemo = (model: ModelData) => {
let command: string | undefined = undefined;
// Resolve the tag to a nemo domain/sub-domain
if (model.tags?.includes("automatic-speech-recognition")) {
command = nemoDomainResolver("ASR", model);
}
return command ?? `# tag did not correspond to a valid NeMo domain.`;
};
//#endregion
export const MODEL_LIBRARIES_UI_ELEMENTS: { [key in keyof typeof ModelLibrary]?: LibraryUiElement } = {
// ^^ TODO(remove the optional ? marker when Stanza snippet is available)
"adapter-transformers": {
btnLabel: "Adapter Transformers",
repoName: "adapter-transformers",
repoUrl: "https://github.com/Adapter-Hub/adapter-transformers",
snippet: adapter_transformers,
},
"allennlp": {
btnLabel: "AllenNLP",
repoName: "AllenNLP",
repoUrl: "https://github.com/allenai/allennlp",
snippet: allennlp,
},
"asteroid": {
btnLabel: "Asteroid",
repoName: "Asteroid",
repoUrl: "https://github.com/asteroid-team/asteroid",
snippet: asteroid,
},
"diffusers": {
btnLabel: "Diffusers",
repoName: "🤗/diffusers",
repoUrl: "https://github.com/huggingface/diffusers",
snippet: diffusers,
},
"espnet": {
btnLabel: "ESPnet",
repoName: "ESPnet",
repoUrl: "https://github.com/espnet/espnet",
snippet: espnet,
},
"fairseq": {
btnLabel: "Fairseq",
repoName: "fairseq",
repoUrl: "https://github.com/pytorch/fairseq",
snippet: fairseq,
},
"flair": {
btnLabel: "Flair",
repoName: "Flair",
repoUrl: "https://github.com/flairNLP/flair",
snippet: flair,
},
"keras": {
btnLabel: "Keras",
repoName: "Keras",
repoUrl: "https://github.com/keras-team/keras",
snippet: keras,
},
"nemo": {
btnLabel: "NeMo",
repoName: "NeMo",
repoUrl: "https://github.com/NVIDIA/NeMo",
snippet: nemo,
},
"pyannote-audio": {
btnLabel: "pyannote.audio",
repoName: "pyannote-audio",
repoUrl: "https://github.com/pyannote/pyannote-audio",
snippet: pyannote_audio,
},
"sentence-transformers": {
btnLabel: "sentence-transformers",
repoName: "sentence-transformers",
repoUrl: "https://github.com/UKPLab/sentence-transformers",
snippet: sentenceTransformers,
},
"sklearn": {
btnLabel: "Scikit-learn",
repoName: "Scikit-learn",
repoUrl: "https://github.com/scikit-learn/scikit-learn",
snippet: sklearn,
},
"fastai": {
btnLabel: "fastai",
repoName: "fastai",
repoUrl: "https://github.com/fastai/fastai",
snippet: fastai,
},
"spacy": {
btnLabel: "spaCy",
repoName: "spaCy",
repoUrl: "https://github.com/explosion/spaCy",
snippet: spacy,
},
"speechbrain": {
btnLabel: "speechbrain",
repoName: "speechbrain",
repoUrl: "https://github.com/speechbrain/speechbrain",
snippet: speechbrain,
},
"stanza": {
btnLabel: "Stanza",
repoName: "stanza",
repoUrl: "https://github.com/stanfordnlp/stanza",
snippet: stanza,
},
"tensorflowtts": {
btnLabel: "TensorFlowTTS",
repoName: "TensorFlowTTS",
repoUrl: "https://github.com/TensorSpeech/TensorFlowTTS",
snippet: tensorflowtts,
},
"timm": {
btnLabel: "timm",
repoName: "pytorch-image-models",
repoUrl: "https://github.com/rwightman/pytorch-image-models",
snippet: timm,
},
"transformers": {
btnLabel: "Transformers",
repoName: "🤗/transformers",
repoUrl: "https://github.com/huggingface/transformers",
snippet: transformers,
},
"fasttext": {
btnLabel: "fastText",
repoName: "fastText",
repoUrl: "https://fasttext.cc/",
snippet: fasttext,
},
"stable-baselines3": {
btnLabel: "stable-baselines3",
repoName: "stable-baselines3",
repoUrl: "https://github.com/huggingface/huggingface_sb3",
snippet: stableBaselines3,
},
"ml-agents": {
btnLabel: "ml-agents",
repoName: "ml-agents",
repoUrl: "https://github.com/huggingface/ml-agents",
snippet: mlAgents,
},
} as const;
"""
def main(argv=None) -> int:
    """Print the snippet for a library/model pair given on the command line.

    Usage: ``script.py [LIBRARY [MODEL_ID]]``.  Missing arguments fall back to
    ``"keras"`` and ``"Distillgpt2"``, so running without arguments prints the
    same output the script always printed.  Extra arguments are ignored.

    Args:
        argv: argument list without the program name; defaults to ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 if the snippet lookup failed
        with ValueError.
    """
    import sys

    args = sys.argv[1:] if argv is None else list(argv)
    library_name = args[0] if args else "keras"
    model_name = args[1] if len(args) > 1 else "Distillgpt2"
    try:
        print(read_file(library_name, model_name))
    except ValueError as err:
        # Report lookup failures on stderr instead of dumping a traceback.
        print(f"error: {err}", file=sys.stderr)
        return 1
    return 0


if __name__ == '__main__':
    import sys

    sys.exit(main())