import hashlib
import tempfile
import unittest
from io import BytesIO
from pathlib import Path
from unittest.mock import patch

import requests
import torch
from PIL import Image
from urllib3 import HTTPResponse
from urllib3._collections import HTTPHeaderDict

import open_clip
from open_clip.pretrained import download_pretrained_from_url


class DownloadPretrainedTests(unittest.TestCase):

    def create_response(self, data, status_code=200, content_type='application/octet-stream'):
        # Build a fake urllib3 HTTPResponse around the given bytes so the mocked
        # urllib.request.urlopen can return a realistic payload.
        fp = BytesIO(data)
        headers = HTTPHeaderDict({
            'Content-Type': content_type,
            'Content-Length': str(len(data))
        })
        raw = HTTPResponse(fp, preload_content=False, headers=headers, status=status_code)
        return raw
    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_openaipublic(self, urllib):
        # urllib is injected as a mock by the @patch decorator above.
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()
        urllib.request.urlopen.return_value = self.create_response(file_contents)
        with tempfile.TemporaryDirectory() as root:
            url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
            download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_openaipublic_corrupted(self, urllib):
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()
        urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model')
        with tempfile.TemporaryDirectory() as root:
            url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
            # The regex mirrors the error text raised by download_pretrained_from_url.
            with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'):
                download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_openaipublic_valid_cache(self, urllib):
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()
        urllib.request.urlopen.return_value = self.create_response(file_contents)
        with tempfile.TemporaryDirectory() as root:
            # Pre-populate the cache with a file whose checksum already matches;
            # no download should happen.
            local_file = Path(root) / 'RN50.pt'
            local_file.write_bytes(file_contents)
            url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
            download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_not_called()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_openaipublic_corrupted_cache(self, urllib):
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()
        urllib.request.urlopen.return_value = self.create_response(file_contents)
        with tempfile.TemporaryDirectory() as root:
            # A cached file with a bad checksum should trigger a re-download.
            local_file = Path(root) / 'RN50.pt'
            local_file.write_bytes(b'corrupted pretrained model')
            url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
            download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_mlfoundations(self, urllib):
        file_contents = b'pretrained model weights'
        # mlfoundations URLs embed only the first 8 hex characters of the SHA256.
        expected_hash = hashlib.sha256(file_contents).hexdigest()[:8]
        urllib.request.urlopen.return_value = self.create_response(file_contents)
        with tempfile.TemporaryDirectory() as root:
            url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt'
            download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_mlfoundations_corrupted(self, urllib):
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()[:8]
        urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model')
        with tempfile.TemporaryDirectory() as root:
            url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt'
            with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'):
                download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_hfh(self, urllib):
        model, _, preprocess = open_clip.create_model_and_transforms('hf-hub:hf-internal-testing/tiny-open-clip-model')
        tokenizer = open_clip.get_tokenizer('hf-hub:hf-internal-testing/tiny-open-clip-model')
        img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png"
        image = preprocess(Image.open(requests.get(img_url, stream=True).raw)).unsqueeze(0)
        text = tokenizer(["a diagram", "a dog", "a cat"])

        with torch.no_grad():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)

            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        self.assertTrue(torch.allclose(text_probs, torch.tensor([[0.0597, 0.6349, 0.3053]]), 1e-3))
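

# Optional convenience entry point (a standard unittest idiom, not required by the test
# class above): it lets the module be executed directly with `python <this file>`.
# The tests can equally be collected by a runner, e.g. `python -m pytest <this file>`
# or `python -m unittest`; the file name used is up to the repository layout.
if __name__ == '__main__':
    unittest.main()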