import asyncio
import base64
import logging
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import aiohttp
import requests
from llama_index.core.schema import Document

logger = logging.getLogger(__name__)


class GithubFileLoader:
    """
    GitHub file loader that fetches specific files asynchronously.

    Returns LlamaIndex Document objects for each successfully loaded file.
    """

    def __init__(
        self,
        github_token: Optional[str] = None,
        concurrent_requests: int = 10,
        timeout: int = 30,
        retries: int = 3,
    ):
        """
        Initialize GitHub file loader.

        Args:
            github_token: GitHub API token for higher rate limits
            concurrent_requests: Number of concurrent requests
            timeout: Request timeout in seconds
            retries: Number of retry attempts for failed requests
        """
        self.github_token = github_token
        self.concurrent_requests = concurrent_requests
        self.timeout = timeout
        self.retries = retries

        # Setup headers
        self.headers = {
            "Accept": "application/vnd.github.v3+json",
            "User-Agent": "LlamaIndex-GitHub-Loader/1.0",
        }
        if self.github_token:
            self.headers["Authorization"] = f"token {self.github_token}"

    def fetch_repository_files(
        self,
        repo_url: str,
        file_extensions: List[str] = [".md", ".mdx"],
        branch: str = "main",
    ) -> Tuple[List[str], str]:
        """
        Fetch files from GitHub repository using GitHub API

        Args:
            repo_url: GitHub repository URL or owner/repo format
            file_extensions: List of file extensions to filter (e.g., [".md", ".mdx", ".txt"])
            branch: Branch name to fetch from

        Returns:
            Tuple of (list_of_file_paths, status_message)
        """
        try:
            # Parse GitHub URL to extract owner and repo
            repo_name = self._parse_repo_name(repo_url)
            if not repo_name:
                return (
                    [],
                    "Invalid GitHub URL format. Use: https://github.com/owner/repo or owner/repo",
                )

            # GitHub API endpoint for repository tree
            api_url = f"https://api.github.com/repos/{repo_name}/git/trees/{branch}?recursive=1"

            # Make request with authentication if token is available
            response = requests.get(api_url, headers=self.headers, timeout=self.timeout)

            if response.status_code == 200:
                data = response.json()
                filtered_files = []

                # Filter for specified file extensions
                for item in data.get("tree", []):
                    if item["type"] == "blob":
                        file_path = item["path"]
                        # Check if file has any of the specified extensions
                        if any(
                            file_path.lower().endswith(ext.lower())
                            for ext in file_extensions
                        ):
                            filtered_files.append(file_path)

                if filtered_files:
                    ext_str = ", ".join(file_extensions)
                    return (
                        filtered_files,
                        f"Found {len(filtered_files)} files with extensions ({ext_str}) in {repo_name}/{branch}",
                    )
                else:
                    ext_str = ", ".join(file_extensions)
                    return (
                        [],
                        f"No files with extensions ({ext_str}) found in repository {repo_name}/{branch}",
                    )
            elif response.status_code == 404:
                return (
                    [],
                    f"Repository '{repo_name}' not found or branch '{branch}' doesn't exist",
                )
            elif response.status_code == 403:
                if "rate limit" in response.text.lower():
                    return (
                        [],
                        "GitHub API rate limit exceeded. Consider using a GitHub token.",
                    )
                else:
                    return (
                        [],
                        "Access denied. Repository may be private or require authentication.",
                    )
            else:
                return (
                    [],
                    f"GitHub API Error: {response.status_code} - {response.text[:200]}",
                )

        except requests.exceptions.Timeout:
            return [], f"Request timeout after {self.timeout} seconds"
        except requests.exceptions.RequestException as e:
            return [], f"Network error: {str(e)}"
        except Exception as e:
            return [], f"Unexpected error: {str(e)}"

    def _parse_repo_name(self, repo_url: str) -> Optional[str]:
        """
        Parse repository URL to extract owner/repo format

        Args:
            repo_url: GitHub repository URL or owner/repo format

        Returns:
            Repository name in "owner/repo" format or None if invalid
        """
        if "github.com" in repo_url:
            # Extract from full URL
            parts = (
                repo_url.replace("https://github.com/", "")
                .replace("http://github.com/", "")
                .strip("/")
                .split("/")
            )
            if len(parts) >= 2:
                return f"{parts[0]}/{parts[1]}"
        else:
            # Assume format is owner/repo
            parts = repo_url.strip().split("/")
            if len(parts) == 2 and all(part.strip() for part in parts):
                return repo_url.strip()
        return None

    def fetch_markdown_files(
        self, repo_url: str, branch: str = "main"
    ) -> Tuple[List[str], str]:
        """
        Fetch markdown files from GitHub repository (backward compatibility method)

        Args:
            repo_url: GitHub repository URL or owner/repo format
            branch: Branch name to fetch from

        Returns:
            Tuple of (list_of_markdown_files, status_message)
        """
        return self.fetch_repository_files(
            repo_url=repo_url, file_extensions=[".md", ".mdx"], branch=branch
        )

    async def load_files(
        self, repo_name: str, file_paths: List[str], branch: str = "main"
    ) -> Tuple[List[Document], List[str]]:
        """
        Load files from GitHub repository asynchronously.

        Args:
            repo_name: Repository name in format "owner/repo"
            file_paths: List of file paths to load
            branch: Branch name to load from

        Returns:
            Tuple of (successfully_loaded_documents, failed_file_paths)
        """
        if not file_paths:
            return [], []

        # Validate repo name format
        if not re.match(r"^[^/]+/[^/]+$", repo_name):
            raise ValueError(f"Invalid repo format: {repo_name}. Expected 'owner/repo'")

        # Create semaphore to limit concurrent requests
        semaphore = asyncio.Semaphore(self.concurrent_requests)

        # Create session
        connector = aiohttp.TCPConnector(limit=self.concurrent_requests)
        timeout_config = aiohttp.ClientTimeout(total=self.timeout)

        async with aiohttp.ClientSession(
            headers=self.headers, connector=connector, timeout=timeout_config
        ) as session:
            # Create tasks for all files
            tasks = []
            for file_path in file_paths:
                task = asyncio.create_task(
                    self._fetch_file_with_retry(
                        session, semaphore, repo_name, file_path, branch
                    )
                )
                tasks.append(task)

            # Wait for all tasks to complete
            results = await asyncio.gather(*tasks, return_exceptions=True)

        # Process results
        documents = []
        failed_files = []
        for i, result in enumerate(results):
            file_path = file_paths[i]
            if isinstance(result, Exception):
                logger.error(f"Failed to load {file_path}: {result}")
                failed_files.append(file_path)
            elif result is None:
                logger.warning(f"No content returned for {file_path}")
                failed_files.append(file_path)
            else:
                documents.append(result)

        logger.info(
            f"Successfully loaded {len(documents)} files, failed: {len(failed_files)}"
        )
        return documents, failed_files

    async def _fetch_file_with_retry(
        self,
        session: aiohttp.ClientSession,
        semaphore: asyncio.Semaphore,
        repo_name: str,
        file_path: str,
        branch: str,
    ) -> Optional[Document]:
        """Fetch a single file with retry logic."""
        async with semaphore:
            for attempt in range(self.retries + 1):
                try:
                    return await self._fetch_single_file(
                        session, repo_name, file_path, branch
                    )
                except Exception as e:
                    if attempt == self.retries:
                        logger.error(
                            f"Failed to fetch {file_path} after {self.retries + 1} attempts: {e}"
                        )
                        raise
                    else:
                        logger.warning(
                            f"Attempt {attempt + 1} failed for {file_path}: {e}"
                        )
                        await asyncio.sleep(2**attempt)  # Exponential backoff
        return None

    async def _fetch_single_file(
        self,
        session: aiohttp.ClientSession,
        repo_name: str,
        file_path: str,
        branch: str,
    ) -> Document:
        """Fetch a single file from GitHub API."""
        # Clean file path
        clean_path = file_path.strip("/")

        # Build API URL
        api_url = f"https://api.github.com/repos/{repo_name}/contents/{clean_path}"
        params = {"ref": branch}

        logger.debug(f"Fetching: {api_url}")

        async with session.get(api_url, params=params) as response:
            if response.status == 404:
                raise FileNotFoundError(f"File not found: {file_path}")
            elif response.status == 403:
                raise PermissionError("API rate limit exceeded or access denied")
            elif response.status != 200:
                raise Exception(f"HTTP {response.status}: {await response.text()}")

            data = await response.json()

            # Handle directory case
            if isinstance(data, list):
                raise ValueError(f"Path {file_path} is a directory, not a file")

            # Decode file content
            if data.get("encoding") == "base64":
                content_bytes = base64.b64decode(data["content"])
                try:
                    content_text = content_bytes.decode("utf-8")
                except UnicodeDecodeError as e:
                    logger.warning(f"Failed to decode {file_path} as UTF-8: {e}")
                    # Try to decode as latin-1 as fallback
                    content_text = content_bytes.decode("latin-1", errors="ignore")
            else:
                raise ValueError(f"Unsupported encoding: {data.get('encoding')}")

            # Create Document
            document = self._create_document(
                content=content_text,
                file_path=clean_path,
                repo_name=repo_name,
                branch=branch,
                file_data=data,
            )
            return document

    def _create_document(
        self, content: str, file_path: str, repo_name: str, branch: str, file_data: Dict
    ) -> Document:
        """Create a LlamaIndex Document from file content and metadata."""
        # Extract file info
        filename = Path(file_path).name
        file_extension = Path(file_path).suffix.lower()
        directory = (
            str(Path(file_path).parent) if Path(file_path).parent != Path(".") else ""
        )

        # Build URLs
        html_url = f"https://github.com/{repo_name}/blob/{branch}/{file_path}"
        raw_url = file_data.get("download_url", "")

        # Create metadata
        metadata = {
            "file_path": file_path,
            "file_name": filename,
            "file_extension": file_extension,
            "directory": directory,
            "repo": repo_name,
            "branch": branch,
            "sha": file_data.get("sha", ""),
            "size": file_data.get("size", 0),
            "url": html_url,
            "raw_url": raw_url,
            "type": file_data.get("type", "file"),
        }

        # Create document with unique ID
        doc_id = f"{repo_name}:{branch}:{file_path}"

        document = Document(
            text=content,
            doc_id=doc_id,
            metadata=metadata,
        )
        return document

    def load_files_sync(
        self, repo_name: str, file_paths: List[str], branch: str = "main"
    ) -> Tuple[List[Document], List[str]]:
        """
        Synchronous wrapper for load_files.

        Args:
            repo_name: Repository name in format "owner/repo"
            file_paths: List of file paths to load
            branch: Branch name to load from

        Returns:
            Tuple of (successfully_loaded_documents, failed_file_paths)
        """
        return asyncio.run(self.load_files(repo_name, file_paths, branch))
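

# Typical usage of the loader class (illustrative sketch only; "owner/repo" and
# the token value are placeholders, not values from this module):
#
#   loader = GithubFileLoader(github_token="<your-token>", concurrent_requests=5)
#   files, status = loader.fetch_repository_files("owner/repo", [".md", ".mdx"])
#   docs, failed = loader.load_files_sync("owner/repo", files, branch="main")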


# Convenience functions
async def load_github_files_async(
    repo_name: str,
    file_paths: List[str],
    branch: str = "main",
    github_token: Optional[str] = None,
    concurrent_requests: int = 10,
) -> Tuple[List[Document], List[str]]:
    """
    Convenience function to load GitHub files asynchronously.

    Args:
        repo_name: Repository name in format "owner/repo"
        file_paths: List of file paths to load
        branch: Branch name to load from
        github_token: GitHub API token
        concurrent_requests: Number of concurrent requests

    Returns:
        Tuple of (documents, failed_files)
    """
    loader = GithubFileLoader(
        github_token=github_token, concurrent_requests=concurrent_requests
    )
    return await loader.load_files(repo_name, file_paths, branch)


def load_github_files(
    repo_name: str,
    file_paths: List[str],
    branch: str = "main",
    github_token: Optional[str] = None,
    concurrent_requests: int = 10,
) -> Tuple[List[Document], List[str]]:
    """
    Convenience function to load GitHub files synchronously.

    Args:
        repo_name: Repository name in format "owner/repo"
        file_paths: List of file paths to load
        branch: Branch name to load from
        github_token: GitHub API token
        concurrent_requests: Number of concurrent requests

    Returns:
        Tuple of (documents, failed_files)
    """
    loader = GithubFileLoader(
        github_token=github_token, concurrent_requests=concurrent_requests
    )
    return loader.load_files_sync(repo_name, file_paths, branch)


def fetch_markdown_files(
    repo_url: str, github_token: Optional[str] = None, branch: str = "main"
) -> Tuple[List[str], str]:
    """
    Convenience function to fetch markdown files from GitHub repository

    Args:
        repo_url: GitHub repository URL or owner/repo format
        github_token: GitHub API token for higher rate limits
        branch: Branch name to fetch from

    Returns:
        Tuple of (list_of_files, status_message)
    """
    loader = GithubFileLoader(github_token=github_token)
    return loader.fetch_markdown_files(repo_url, branch)


def fetch_repository_files(
    repo_url: str,
    file_extensions: List[str] = [".md", ".mdx"],
    github_token: Optional[str] = None,
    branch: str = "main",
) -> Tuple[List[str], str]:
    """
    Convenience function to fetch files with specific extensions from GitHub repository

    Args:
        repo_url: GitHub repository URL or owner/repo format
        file_extensions: List of file extensions to filter
        github_token: GitHub API token for higher rate limits
        branch: Branch name to fetch from

    Returns:
        Tuple of (list_of_files, status_message)
    """
    loader = GithubFileLoader(github_token=github_token)
    return loader.fetch_repository_files(repo_url, file_extensions, branch)


# Example usage
if __name__ == "__main__":
    # Example file paths
    file_paths = [
        "docs/contribute/docs.mdx",
        "docs/contribute/ml-handlers.mdx",
        "docs/contribute/community.mdx",
        "docs/contribute/python-coding-standards.mdx",
        "docs/features/data-integrations.mdx",
        "docs/features/ai-integrations.mdx",
        "docs/integrations/ai-engines/langchain_embedding.mdx",
        "docs/integrations/ai-engines/langchain.mdx",
        "docs/integrations/ai-engines/google_gemini.mdx",
        "docs/integrations/ai-engines/anomaly.mdx",
        "docs/integrations/ai-engines/amazon-bedrock.mdx",
    ]

    # Load files synchronously
    documents, failed = load_github_files(
        repo_name="mindsdb/mindsdb",
        file_paths=file_paths,
        branch="main",  # Optional
    )

    print(f"Loaded {len(documents)} documents")
    print(f"Failed to load {len(failed)} files: {failed}")

    # Print first document info
    if documents:
        doc = documents[0]
        print("\nFirst document:")
        print(f"ID: {doc.doc_id}")
        print(f"File: {doc.metadata['file_path']}")
        print(f"Size: {len(doc.text)} characters")
        print(f"Content preview: {doc.text[:200]}...")