Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	
		Nirav Madhani
		
	commited on
		
		
					Commit 
							
							·
						
						0558e79
	
1
								Parent(s):
							
							ab273a0
								
Flash server
Browse files- Dockerfile +27 -3
- app.py +65 -0
- init_model.py +9 -0
- main.py +13 -6
- test_api.py +54 -0
    	
        Dockerfile
    CHANGED
    
    | @@ -1,11 +1,35 @@ | |
| 1 | 
            -
            FROM python:3.10
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 2 |  | 
| 3 | 
             
            WORKDIR /app
         | 
|  | |
|  | |
| 4 | 
             
            RUN git clone https://github.com/octo-models/octo.git
         | 
| 5 | 
             
            WORKDIR /app/octo
         | 
|  | |
|  | |
| 6 | 
             
            RUN pip3 install -e .
         | 
| 7 | 
             
            RUN pip3 install -r requirements.txt
         | 
| 8 | 
             
            RUN pip3 install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
         | 
| 9 | 
            -
            RUN  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 10 | 
             
            COPY main.py /app/octo
         | 
| 11 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            FROM python:3.10-slim
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            # Update package list and install git
         | 
| 4 | 
            +
            RUN apt-get update && \
         | 
| 5 | 
            +
                apt-get install -y git && \
         | 
| 6 | 
            +
                rm -rf /var/lib/apt/lists/*  # Clean up to reduce image size
         | 
| 7 |  | 
| 8 | 
             
            WORKDIR /app
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            # Clone the octo repository
         | 
| 11 | 
             
            RUN git clone https://github.com/octo-models/octo.git
         | 
| 12 | 
             
            WORKDIR /app/octo
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            # Install dependencies
         | 
| 15 | 
             
            RUN pip3 install -e .
         | 
| 16 | 
             
            RUN pip3 install -r requirements.txt
         | 
| 17 | 
             
            RUN pip3 install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
         | 
| 18 | 
            +
            RUN pip3 install scipy==1.10.1
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            # Install FastAPI and Uvicorn for the API
         | 
| 21 | 
            +
            RUN pip3 install fastapi uvicorn
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            # Copy and run the model initialization script to cache the model
         | 
| 24 | 
            +
            COPY init_model.py /app/octo
         | 
| 25 | 
            +
            RUN python init_model.py
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            # Copy the original main.py and the API app.py
         | 
| 28 | 
             
            COPY main.py /app/octo
         | 
| 29 | 
            +
            COPY app.py /app/octo
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            # Expose port 8000 for the API
         | 
| 32 | 
            +
            EXPOSE 8000
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            # Run the API with Uvicorn
         | 
| 35 | 
            +
            CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
         | 
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,65 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from octo.model.octo_model import OctoModel
         | 
| 2 | 
            +
            from PIL import Image
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            import jax
         | 
| 5 | 
            +
            from fastapi import FastAPI, HTTPException
         | 
| 6 | 
            +
            from pydantic import BaseModel
         | 
| 7 | 
            +
            import os
         | 
| 8 | 
            +
            import io
         | 
| 9 | 
            +
            import base64
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            # Set JAX to use CPU platform (adjust if GPU is needed)
         | 
| 12 | 
            +
            os.environ['JAX_PLATFORMS'] = 'cpu'
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            # Load the model once globally (assumes it's cached locally)
         | 
| 15 | 
            +
            model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            # Initialize FastAPI app
         | 
| 18 | 
            +
            app = FastAPI(title="Octo Model Inference API")
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            # Define request body model
         | 
| 21 | 
            +
            class InferenceRequest(BaseModel):
         | 
| 22 | 
            +
                image_base64: str  # Base64-encoded image string
         | 
| 23 | 
            +
                task: str = "pick up the fork"  # Default task
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            # Health check endpoint
         | 
| 26 | 
            +
            @app.get("/health")
         | 
| 27 | 
            +
            async def health_check():
         | 
| 28 | 
            +
                return {"status": "healthy"}
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            # Inference endpoint
         | 
| 31 | 
            +
            @app.post("/predict")
         | 
| 32 | 
            +
            async def predict(request: InferenceRequest):
         | 
| 33 | 
            +
                try:
         | 
| 34 | 
            +
                    # Decode base64 image
         | 
| 35 | 
            +
                    img_base64 = request.image_base64
         | 
| 36 | 
            +
                    if img_base64.startswith("data:image"):
         | 
| 37 | 
            +
                        img_base64 = img_base64.split(",")[1]
         | 
| 38 | 
            +
                    
         | 
| 39 | 
            +
                    img_data = base64.b64decode(img_base64)
         | 
| 40 | 
            +
                    img = Image.open(io.BytesIO(img_data)).resize((256, 256))
         | 
| 41 | 
            +
                    img = np.array(img)
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                    # Add batch and time horizon dimensions
         | 
| 44 | 
            +
                    img = img[np.newaxis, np.newaxis, ...]  # Shape: (1, 1, 256, 256, 3)
         | 
| 45 | 
            +
                    observation = {
         | 
| 46 | 
            +
                        "image_primary": img,
         | 
| 47 | 
            +
                        "timestep_pad_mask": np.array([[True]])
         | 
| 48 | 
            +
                    }
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                    # Create task and predict actions
         | 
| 51 | 
            +
                    task_obj = model.create_tasks(texts=[request.task])
         | 
| 52 | 
            +
                    actions = model.sample_actions(
         | 
| 53 | 
            +
                        observation, 
         | 
| 54 | 
            +
                        task_obj, 
         | 
| 55 | 
            +
                        unnormalization_statistics=model.dataset_statistics["bridge_dataset"]["action"], 
         | 
| 56 | 
            +
                        rng=jax.random.PRNGKey(0)
         | 
| 57 | 
            +
                    )
         | 
| 58 | 
            +
                    actions = actions[0]
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    # Convert NumPy array to list for JSON response
         | 
| 61 | 
            +
                    actions_list = actions.tolist()
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    return {"actions": actions_list}
         | 
| 64 | 
            +
                except Exception as e:
         | 
| 65 | 
            +
                    raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
         | 
    	
        init_model.py
    ADDED
    
    | @@ -0,0 +1,9 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from octo.model.octo_model import OctoModel
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            # Set JAX to CPU (consistent with your setup)
         | 
| 5 | 
            +
            os.environ['JAX_PLATFORMS'] = 'cpu'
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            # Load the model to cache it
         | 
| 8 | 
            +
            model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")
         | 
| 9 | 
            +
            print("Model downloaded and cached successfully.")
         | 
    	
        main.py
    CHANGED
    
    | @@ -17,10 +17,17 @@ img = np.array(Image.open(requests.get(IMAGE_URL, stream=True).raw).resize((256, | |
| 17 | 
             
            img = img[np.newaxis,np.newaxis,...]
         | 
| 18 | 
             
            observation = {"image_primary": img, "timestep_pad_mask": np.array([[True]])}
         | 
| 19 | 
             
            task = model.create_tasks(texts=["pick up the fork"])
         | 
| 20 | 
            -
             | 
| 21 | 
            -
             | 
| 22 | 
            -
             | 
| 23 | 
            -
                 | 
| 24 | 
            -
                 | 
| 25 | 
             
            )
         | 
| 26 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 17 | 
             
            img = img[np.newaxis,np.newaxis,...]
         | 
| 18 | 
             
            observation = {"image_primary": img, "timestep_pad_mask": np.array([[True]])}
         | 
| 19 | 
             
            task = model.create_tasks(texts=["pick up the fork"])
         | 
| 20 | 
            +
            norm_actions = model.sample_actions(observation, task, rng=jax.random.PRNGKey(0))
         | 
| 21 | 
            +
            norm_actions = norm_actions[0]   # remove batch
         | 
| 22 | 
            +
            actions = (
         | 
| 23 | 
            +
                norm_actions * model.dataset_statistics["bridge_dataset"]['action']['std']
         | 
| 24 | 
            +
                + model.dataset_statistics["bridge_dataset"]['action']['mean']
         | 
| 25 | 
             
            )
         | 
| 26 | 
            +
            actions = np.concatenate(
         | 
| 27 | 
            +
                    (
         | 
| 28 | 
            +
                        steps[step+1]['action']['world_vector'],
         | 
| 29 | 
            +
                        steps[step+1]['action']['rotation_delta'],
         | 
| 30 | 
            +
                        np.array(steps[step+1]['action']['open_gripper']).astype(np.float32)[None]
         | 
| 31 | 
            +
                    ), axis=-1
         | 
| 32 | 
            +
                )
         | 
| 33 | 
            +
            print(actions) 
         | 
    	
        test_api.py
    ADDED
    
    | @@ -0,0 +1,54 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import requests
         | 
| 2 | 
            +
            import base64
         | 
| 3 | 
            +
            from PIL import Image
         | 
| 4 | 
            +
            import io
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # API endpoint URL (adjust if running on a different host/port)
         | 
| 7 | 
            +
            API_URL = "http://localhost:8000/predict"
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            # Example image URL from main.py
         | 
| 10 | 
            +
            IMAGE_URL = "https://rail.eecs.berkeley.edu/datasets/bridge_release/raw/bridge_data_v2/datacol2_toykitchen7/drawer_pnp/01/2023-04-19_09-18-15/raw/traj_group0/traj0/images0/im_12.jpg"
         | 
| 11 | 
            +
            TASK_TEXT = "pick up the fork"
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            def test_api(image_url=IMAGE_URL, task=TASK_TEXT):
         | 
| 14 | 
            +
                try:
         | 
| 15 | 
            +
                    # Download image from URL
         | 
| 16 | 
            +
                    response = requests.get(image_url, stream=True)
         | 
| 17 | 
            +
                    response.raise_for_status()  # Check for HTTP errors
         | 
| 18 | 
            +
                    img = Image.open(response.raw).resize((256, 256))
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                    # Convert image to base64
         | 
| 21 | 
            +
                    img_byte_arr = io.BytesIO()
         | 
| 22 | 
            +
                    img.save(img_byte_arr, format="JPEG")  # Save as JPEG (adjust if needed)
         | 
| 23 | 
            +
                    img_byte_arr = img_byte_arr.getvalue()
         | 
| 24 | 
            +
                    base64_string = base64.b64encode(img_byte_arr).decode("utf-8")
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                    # Prepare payload for API
         | 
| 27 | 
            +
                    payload = {
         | 
| 28 | 
            +
                        "image_base64": base64_string,
         | 
| 29 | 
            +
                        "task": task
         | 
| 30 | 
            +
                    }
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                    # Send POST request to API
         | 
| 33 | 
            +
                    api_response = requests.post(API_URL, json=payload)
         | 
| 34 | 
            +
                    api_response.raise_for_status()  # Check for API errors
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                    # Print the result
         | 
| 37 | 
            +
                    result = api_response.json()
         | 
| 38 | 
            +
                    print(f"Task: {task}")
         | 
| 39 | 
            +
                    print(f"Image URL: {image_url}")
         | 
| 40 | 
            +
                    print(f"Predicted Actions: {result['actions']}")
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                except requests.exceptions.RequestException as e:
         | 
| 43 | 
            +
                    print(f"Error fetching image or calling API: {e}")
         | 
| 44 | 
            +
                except Exception as e:
         | 
| 45 | 
            +
                    print(f"Unexpected error: {e}")
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            if __name__ == "__main__":
         | 
| 48 | 
            +
                # Test with default values (same as main.py)
         | 
| 49 | 
            +
                test_api()
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                # Test with a different URL and task (optional)
         | 
| 52 | 
            +
                # Replace with another valid URL if desired    
         | 
| 53 | 
            +
                print("\nTesting with another URL and task:")
         | 
| 54 | 
            +
                test_api(IMAGE_URL, TASK_TEXT)
         | 
