AfshinMA commited on
Commit
696bf9f
·
verified ·
1 Parent(s): 684e559

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -120
app.py CHANGED
@@ -1,121 +1,121 @@
1
- # Import required libraries
2
- import os
3
- import keras
4
- import gradio as gr
5
- import numpy as np
6
- import pandas as pd
7
- from PIL import Image
8
-
9
- # Function to safely load the models
10
- def load_model_safely(path: str):
11
- if not os.path.isfile(path) or not path.endswith('.keras'):
12
- raise FileNotFoundError(f"The file '{path}' does not exist or is not a .keras file.")
13
- return keras.saving.load_model(path)
14
-
15
- # Retrieve the current directory and specify model paths
16
- current_dir = os.getcwd() # Ensure correct initial directory
17
- model_paths = {
18
- 'CNN': os.path.join(current_dir, 'Project_7_Traffic_Sign_Detection', 'models', 'cnn_model.keras'),
19
- 'VGG19': os.path.join(current_dir, 'Project_7_Traffic_Sign_Detection', 'models', 'vgg19_model.keras'),
20
- 'ResNet50': os.path.join(current_dir, 'Project_7_Traffic_Sign_Detection', 'models', 'resnet50_model.keras'),
21
- }
22
-
23
- # Load models and handle potential exceptions
24
- models = {}
25
- for name, path in model_paths.items():
26
- try:
27
- models[name] = load_model_safely(path)
28
- except Exception as e:
29
- print(f"Error loading model {name} from {path}: {str(e)}")
30
-
31
- # Define the class labels
32
- classes = { 0:'Speed limit (20km/h)',
33
- 1:'Speed limit (30km/h)',
34
- 2:'Speed limit (50km/h)',
35
- 3:'Speed limit (60km/h)',
36
- 4:'Speed limit (70km/h)',
37
- 5:'Speed limit (80km/h)',
38
- 6:'End of speed limit (80km/h)',
39
- 7:'Speed limit (100km/h)',
40
- 8:'Speed limit (120km/h)',
41
- 9:'No passing',
42
- 10:'No passing veh over 3.5 tons',
43
- 11:'Right-of-way at intersection',
44
- 12:'Priority road',
45
- 13:'Yield',
46
- 14:'Stop',
47
- 15:'No vehicles',
48
- 16:'Veh > 3.5 tons prohibited',
49
- 17:'No entry',
50
- 18:'General caution',
51
- 19:'Dangerous curve left',
52
- 20:'Dangerous curve right',
53
- 21:'Double curve',
54
- 22:'Bumpy road',
55
- 23:'Slippery road',
56
- 24:'Road narrows on the right',
57
- 25:'Road work',
58
- 26:'Traffic signals',
59
- 27:'Pedestrians',
60
- 28:'Children crossing',
61
- 29:'Bicycles crossing',
62
- 30:'Beware of ice/snow',
63
- 31:'Wild animals crossing',
64
- 32:'End speed + passing limits',
65
- 33:'Turn right ahead',
66
- 34:'Turn left ahead',
67
- 35:'Ahead only',
68
- 36:'Go straight or right',
69
- 37:'Go straight or left',
70
- 38:'Keep right',
71
- 39:'Keep left',
72
- 40:'Roundabout mandatory',
73
- 41:'End of no passing',
74
- 42:'End no passing veh > 3.5 tons' }
75
-
76
- # Function to import and resize example images
77
- def get_example_images(images_dir:str, size=(50, 50)) -> list:
78
- images = []
79
- image_list = os.listdir(images_dir)
80
- for image in image_list:
81
- if image.lower().endswith('.png'):
82
- image_path = os.path.join(images_dir, image)
83
- img = Image.open(image_path)
84
- img = img.resize(size)
85
- images.append(img)
86
- return images
87
-
88
- # Function to preprocess the image and predict the class
89
- def preprocess_and_predict(image: Image.Image, size=(50, 50)) -> pd.DataFrame:
90
- img_resized = image.resize(size)
91
- img_array = np.array(img_resized).astype(np.float32) / 255.0
92
- img_array = np.expand_dims(img_array, axis=0) # Shape (1, 50, 50, 3)
93
-
94
- predictions = []
95
- for name, model in models.items():
96
- predicted_class_index = np.argmax(model.predict(img_array), axis=-1)[0]
97
- predictions.append({'Model': name, 'Predicted Label': classes[predicted_class_index]})
98
-
99
- return pd.DataFrame(predictions)
100
-
101
- # Directory for example images
102
- images_dir = os.path.join(current_dir, 'Project_7_Traffic_Sign_Detection', 'images')
103
-
104
- # Check if the images directory exists
105
- if not os.path.exists(images_dir):
106
- print(f"The images directory does not exist: {images_dir}")
107
- else:
108
- example_images = get_example_images(images_dir, (50, 50))
109
-
110
- # Create Gradio interface
111
- iface = gr.Interface(
112
- fn=preprocess_and_predict,
113
- inputs=gr.Image(type='pil'), # Changed to 'pil' for direct use with PIL
114
- outputs="dataframe", # Correct the output type
115
- examples=example_images,
116
- title="Traffic Sign Recognition",
117
- description="Upload a traffic sign image or choose an example to get the recognition result."
118
- )
119
-
120
- # Launch the Gradio app
121
  iface.launch()
 
1
+ # Import required libraries
2
+ import os
3
+ import keras
4
+ import gradio as gr
5
+ import numpy as np
6
+ import pandas as pd
7
+ from PIL import Image
8
+
9
+ # Function to safely load the models
10
+ def load_model_safely(path: str):
11
+ if not os.path.isfile(path) or not path.endswith('.keras'):
12
+ raise FileNotFoundError(f"The file '{path}' does not exist or is not a .keras file.")
13
+ return keras.saving.load_model(path)
14
+
15
+ # Retrieve the current directory and specify model paths
16
+ current_dir = os.getcwd() # Ensure correct initial directory
17
+ model_paths = {
18
+ 'CNN': os.path.join(current_dir, 'Project_7_Traffic_Sign_Detection', 'models', 'cnn_model.keras'),
19
+ 'VGG19': os.path.join(current_dir, 'Project_7_Traffic_Sign_Detection', 'models', 'vgg19_model.keras'),
20
+ 'ResNet50': os.path.join(current_dir, 'Project_7_Traffic_Sign_Detection', 'models', 'resnet50_model.keras'),
21
+ }
22
+
23
+ # Load models and handle potential exceptions
24
+ models = {}
25
+ for name, path in model_paths.items():
26
+ try:
27
+ models[name] = load_model_safely(path)
28
+ except Exception as e:
29
+ print(f"Error loading model {name} from {path}: {str(e)}")
30
+
31
+ # Define the class labels
32
+ classes = { 0:'Speed limit (20km/h)',
33
+ 1:'Speed limit (30km/h)',
34
+ 2:'Speed limit (50km/h)',
35
+ 3:'Speed limit (60km/h)',
36
+ 4:'Speed limit (70km/h)',
37
+ 5:'Speed limit (80km/h)',
38
+ 6:'End of speed limit (80km/h)',
39
+ 7:'Speed limit (100km/h)',
40
+ 8:'Speed limit (120km/h)',
41
+ 9:'No passing',
42
+ 10:'No passing veh over 3.5 tons',
43
+ 11:'Right-of-way at intersection',
44
+ 12:'Priority road',
45
+ 13:'Yield',
46
+ 14:'Stop',
47
+ 15:'No vehicles',
48
+ 16:'Veh > 3.5 tons prohibited',
49
+ 17:'No entry',
50
+ 18:'General caution',
51
+ 19:'Dangerous curve left',
52
+ 20:'Dangerous curve right',
53
+ 21:'Double curve',
54
+ 22:'Bumpy road',
55
+ 23:'Slippery road',
56
+ 24:'Road narrows on the right',
57
+ 25:'Road work',
58
+ 26:'Traffic signals',
59
+ 27:'Pedestrians',
60
+ 28:'Children crossing',
61
+ 29:'Bicycles crossing',
62
+ 30:'Beware of ice/snow',
63
+ 31:'Wild animals crossing',
64
+ 32:'End speed + passing limits',
65
+ 33:'Turn right ahead',
66
+ 34:'Turn left ahead',
67
+ 35:'Ahead only',
68
+ 36:'Go straight or right',
69
+ 37:'Go straight or left',
70
+ 38:'Keep right',
71
+ 39:'Keep left',
72
+ 40:'Roundabout mandatory',
73
+ 41:'End of no passing',
74
+ 42:'End no passing veh > 3.5 tons' }
75
+
76
+ # Function to import and resize example images
77
+ def get_example_images(images_dir:str, size=(50, 50)) -> list:
78
+ images = []
79
+ image_list = os.listdir(images_dir)
80
+ for image in image_list:
81
+ if image.lower().endswith('.png'):
82
+ image_path = os.path.join(images_dir, image)
83
+ img = Image.open(image_path)
84
+ img = img.resize(size)
85
+ images.append(img)
86
+ return images
87
+
88
+ # Function to preprocess the image and predict the class
89
+ def preprocess_and_predict(image: Image.Image, size=(50, 50)) -> pd.DataFrame:
90
+ img_resized = image.resize(size)
91
+ img_array = np.array(img_resized).astype(np.float32) / 255.0
92
+ img_array = np.expand_dims(img_array, axis=0) # Shape (1, 50, 50, 3)
93
+
94
+ predictions = []
95
+ for name, model in models.items():
96
+ predicted_class_index = np.argmax(model.predict(img_array), axis=-1)[0]
97
+ predictions.append({'Model': name, 'Predicted Label': classes[predicted_class_index]})
98
+
99
+ return pd.DataFrame(predictions)
100
+
101
+ # Directory for example images
102
+ images_dir = os.path.join(current_dir, 'Project_7_Traffic_Sign_Detection', 'images')
103
+
104
+ # Check if the images directory exists
105
+ if not os.path.exists(images_dir):
106
+ print(f"The images directory does not exist: {images_dir}")
107
+ else:
108
+ examples = get_example_images(images_dir, (50, 50))
109
+
110
+ # Create Gradio interface
111
+ iface = gr.Interface(
112
+ fn=preprocess_and_predict,
113
+ inputs=gr.Image(type='pil'), # Changed to 'pil' for direct use with PIL
114
+ outputs="dataframe", # Correct the output type
115
+ examples=examples,
116
+ title="Traffic Sign Recognition",
117
+ description="Upload a traffic sign image or choose an example to get the recognition result."
118
+ )
119
+
120
+ # Launch the Gradio app
121
  iface.launch()