Spaces:
Runtime error
Runtime error
import functools

import gradio as gr
import openpyxl
import torch
from transformers import pipeline
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
@functools.lru_cache(maxsize=2)
def _load_food_classifier(model_name):
    """Load and memoise the feature extractor and classifier for ``model_name``.

    Fetching/deserialising the weights is by far the slowest part of a
    request, so it is done once per process instead of on every call.
    """
    extractor = AutoFeatureExtractor.from_pretrained(model_name)
    model = AutoModelForImageClassification.from_pretrained(model_name)
    return extractor, model


def predict(image, model_name="nateraw/food"):
    """Predict the food shown in ``image`` with a pre-trained image classifier.

    Args:
        image: Image as delivered by the Gradio ``Image`` input; it is passed
            straight to the model's feature extractor.
        model_name: Hugging Face model id; defaults to "nateraw/food".

    Returns:
        The label (``model.config.id2label``) of the highest-scoring class.
        Presumably underscore-separated words such as "french_fries", which is
        what ``get_cc`` splits on -- confirm against the model config.
    """
    extractor, model = _load_food_classifier(model_name)
    inputs = extractor(images=image, return_tensors='pt')
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    pred_class = logits.argmax(-1).item()
    return model.config.id2label[pred_class]
_DATABASE_PATH = './database.xlsx'


@functools.lru_cache(maxsize=8)
def _load_food_rows(path):
    """Read and cache the nutrient rows of the USDA spreadsheet at ``path``.

    Data starts on row 3 of the active sheet. Each returned entry is
    ``(description, cal, carb, prot, fat)`` taken from columns B, E, G, F
    and J respectively. Rows whose description is not text are skipped,
    because the string matching in ``check_food`` cannot handle them.
    The result is cached per path, so edits to the file are only picked up
    after a restart.
    """
    wb_obj = openpyxl.load_workbook(path)
    sheet_obj = wb_obj.active
    rows = []
    for row in sheet_obj.iter_rows(min_row=3, max_col=10, values_only=True):
        description = row[1]
        if not isinstance(description, str):
            continue
        rows.append((description, row[4], row[6], row[5], row[9]))
    return tuple(rows)


def _build_matcher(food, counter):
    """Return a ``description -> bool`` predicate for strategy ``counter``, or None if it does not apply."""
    if not food:
        return None
    head = food[0].capitalize()
    if counter == 0:
        prefix = " ".join([head, *food[1:3]]) + ","
        return lambda description: description.startswith(prefix)
    if counter == 1:
        prefix = " ".join([head, *food[1:3]])
        return lambda description: description.startswith(prefix)
    if counter == 2:
        needle = " ".join(food[:3])
        return lambda description: needle in description
    if counter == 3 and len(food) > 1:
        return lambda description: description.startswith(head)
    if counter == 4 and len(food) > 1:
        needle = food[0]
        return lambda description: needle in description
    return None


def check_food(food, counter, path=_DATABASE_PATH):
    """Look up nutritional values for ``food`` in the USDA-derived spreadsheet.

    Args:
        food: Label words, e.g. the model prediction split on "_".
        counter: Matching strategy, from strictest (0) to loosest (4):
            0 -- description starts with the first three words (first one
                 capitalised) followed by a comma;
            1 -- same prefix without the trailing comma;
            2 -- the first three words (as given) occur anywhere;
            3 -- description starts with the capitalised first word
                 (multi-word labels only);
            4 -- the first word occurs anywhere (multi-word labels only).
        path: Spreadsheet location; defaults to './database.xlsx'.

    Returns:
        ``(description, cal, carb, prot, fat)`` of the first matching row, or
        five ``None`` values when nothing matches, ``food`` is empty, or the
        strategy does not apply.
    """
    no_match = (None, None, None, None, None)
    matches = _build_matcher(food, counter)
    if matches is None:
        return no_match
    for row in _load_food_rows(path):
        if matches(row[0]):
            return row
    return no_match
def _nutrient_line(label, value, unit, weight):
    """Format one nutrient amount scaled to ``weight`` grams.

    ``round(value * weight) / 100`` equals ``value * weight / 100`` rounded to
    two decimals; the /100 treats database values as per-100 g amounts
    (TODO confirm against the USDA export). An empty (None) cell is shown as
    "N/A" instead of crashing.
    """
    if value is None:
        return f"{label}: N/A"
    return f"{label}: {round(value * weight) / 100} {unit}"


def get_cc(food, weight, lookup=None):
    """Build the nutrition summary text for a predicted food label.

    Args:
        food: Model label; underscores separate words (e.g. "french_fries").
        weight: Portion weight in grams.
        lookup: Optional replacement for ``check_food`` with the same
            ``(words, counter) -> (description, cal, carb, prot, fat)``
            contract; defaults to ``check_food``.

    Returns:
        The matched database description followed by calories, carbohydrate,
        protein and total fat scaled to ``weight``, one per line; or
        "No data for food" when no matching strategy finds a row.
    """
    if lookup is None:
        lookup = check_food
    words = food.split("_")
    # Crude plural stripping so the label can match singular database entries.
    if words[-1].endswith("s"):
        words[-1] = words[-1][:-1]
    # Try the matching strategies from strictest (0) to loosest (4).
    foodPred = None
    for counter in range(5):
        foodPred, cal, carb, prot, fat = lookup(words, counter)
        if foodPred:
            break
    # The match result decides the outcome; `words` itself is never empty.
    if not foodPred:
        return "No data for food"
    # NOTE(review): the "Calories" line is labelled kJ; confirm which energy
    # unit the spreadsheet's column E actually holds.
    return "\n".join([
        foodPred,
        _nutrient_line("Calories", cal, "kJ", weight),
        _nutrient_line("Carbohydrate", carb, "g", weight),
        _nutrient_line("Protein", prot, "g", weight),
        _nutrient_line("Total Fat", fat, "g", weight),
    ])
def CC(image, weight):
    """Gradio callback: classify ``image``, then describe a ``weight``-gram portion.

    Returns a ``(label, nutrition_text)`` pair, one value per output Textbox.
    """
    label = predict(image)
    summary = get_cc(label, weight)
    return label, summary
# The gr.inputs / gr.outputs namespaces were deprecated in Gradio 3 and removed
# in Gradio 4; the top-level components below work in both. The Gradio-3-only
# `shape=` resize is dropped: the model's feature extractor presumably resizes
# images itself (do_resize in its preprocessor config -- verify).
interface = gr.Interface(
    fn=CC,
    inputs=[
        gr.Image(),
        gr.Number(value=100, label="Weight in grams (g):"),
    ],
    outputs=[
        gr.Textbox(label='Food Prediction:'),
        gr.Textbox(label='Nutritional Value:'),
    ],
    examples=[["pizza.jpg", 107], ["spaghetti.jpg", 205]],
)

if __name__ == "__main__":
    interface.launch()