Spaces:
Runtime error
Runtime error
import hmac
import json
import os
import shutil
import sqlite3
import subprocess
import tempfile
from pathlib import Path

from flask import Flask, g, jsonify, request
from flask_apscheduler import APScheduler
from flask_cors import CORS
from flask_expects_json import expects_json
from huggingface_hub import Repository
from jsonschema import ValidationError
from PIL import Image
# Runtime mode: anything other than FLASK_ENV=development counts as production.
MODE = os.environ.get("FLASK_ENV", "production")
IS_DEV = MODE == "development"

app = Flask(__name__, static_url_path="/static")
# Ask jsonify for compact output. NOTE(review): this config key is only
# honored by older Flask releases — confirm against the pinned version.
app.config.update(JSONIFY_PRETTYPRINT_REGULAR=False)
# JSON schema for a new-palette request body (presumably enforced through
# flask_expects_json's @expects_json — no decorator is visible in this view).
# The body must be exactly {"prompt": str, "images": [...]}. Each image must be
# exactly {"colors": [5 strings], "imgURL": str}.
schema = {
    "type": "object",
    "properties": {
        "prompt": {"type": "string"},
        "images": {
            "type": "array",
            "items": {
                "type": "object",
                "minProperties": 2,
                "maxProperties": 2,
                # min/maxProperties only bound the *number* of keys. Without
                # "required", an object such as {"foo": 1, "bar": 2} would pass.
                "required": ["colors", "imgURL"],
                "properties": {
                    "colors": {
                        "type": "array",
                        "items": {"type": "string"},
                        "maxItems": 5,
                        "minItems": 5,
                    },
                    "imgURL": {"type": "string"},
                },
            },
        },
    },
    # Same reasoning as above: pin *which* two keys are allowed.
    "required": ["prompt", "images"],
    "minProperties": 2,
    "maxProperties": 2,
}
CORS(app)
DB_FILE = Path("./data.db")
TOKEN = os.environ.get("HUGGING_FACE_HUB_TOKEN")

# Clone the dataset repo that stores data.db between restarts, then sync it.
# NOTE(review): `Repository` and `use_auth_token` are deprecated in recent
# huggingface_hub releases — confirm the pinned version still provides them.
repo = Repository(
    local_dir="data",
    repo_type="dataset",
    clone_from="huggingface-projects/color-palettes-sd",
    use_auth_token=TOKEN,
)
repo.git_pull()
# Seed the working DB from the copy stored in the dataset repo.
shutil.copyfile("./data/data.db", DB_FILE)

db = sqlite3.connect(DB_FILE)
try:
    # Idempotent: creates the table on a fresh/empty DB and leaves an existing
    # one untouched. Unlike catching OperationalError from the SELECT, this
    # does not misread unrelated errors (e.g. a locked DB) as "missing table".
    db.execute(
        "CREATE TABLE IF NOT EXISTS palettes (id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, data json, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL)"
    )
    db.commit()
    data = db.execute("SELECT * FROM palettes").fetchall()
    if IS_DEV:
        print(f"Loaded {len(data)} palettes from local db")
finally:
    # Always release the startup connection; request handlers open their own
    # via get_db().
    db.close()
def get_db():
    """Return the SQLite connection cached on ``flask.g``, opening it if needed.

    A new connection to ``DB_FILE`` is created on first use within the current
    app context. It is stored as ``g._database`` and configured with
    ``sqlite3.Row`` so columns can be read by name.
    """
    conn = getattr(g, "_database", None)
    if conn is not None:
        return conn
    conn = sqlite3.connect(DB_FILE)
    conn.row_factory = sqlite3.Row
    g._database = conn
    return conn
def close_connection(exception):
    """Close the connection ``get_db`` cached on ``flask.g``, if one was opened.

    ``exception`` is unused. Presumably it exists to fit Flask's teardown
    callback signature. NOTE(review): no ``@app.teardown_appcontext``
    decorator is visible in this view — confirm it is registered.
    """
    conn = getattr(g, "_database", None)
    if conn is None:
        return
    conn.close()
def _snapshot_live_db():
    """Write a consistent copy of DB_FILE to ./data/data.db.

    Request handlers may insert into DB_FILE while this runs (it is also
    scheduled as a background job), so a raw byte copy could capture a
    half-written page. SQLite's online backup API yields a consistent
    snapshot. It is written to a scratch file first, so the destination's
    prior contents do not matter.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        snapshot_path = Path(tmp_dir) / "data.db"
        source = sqlite3.connect(DB_FILE)
        try:
            target = sqlite3.connect(snapshot_path)
            try:
                source.backup(target)
            finally:
                target.close()
        finally:
            source.close()
        shutil.copyfile(snapshot_path, "./data/data.db")


def _export_palettes_json():
    """Dump every row of ./data/data.db's ``palettes`` table to ./data/data.json."""
    conn = sqlite3.connect("./data/data.db")
    try:
        conn.row_factory = sqlite3.Row
        palettes = conn.execute("SELECT * FROM palettes").fetchall()
    finally:
        # `with sqlite3.connect(...)` only commits/rolls back and never closes,
        # so close explicitly.
        conn.close()
    data = [
        {
            "id": row["id"],
            "data": json.loads(row["data"]),
            "created_at": row["created_at"],
        }
        for row in palettes
    ]
    with open("./data/data.json", "w", encoding="utf-8") as f:
        json.dump(data, f, separators=(",", ":"))


def _force_push_data_repo():
    """Fold ./data's changes into its last commit and force-push; True on success.

    Runs synchronously, so it cannot race the push_to_hub call that follows,
    and stops at the first step that exits non-zero.
    """
    steps = (
        ["git", "add", "."],
        ["git", "commit", "--amend", "-m", "update"],
        ["git", "push", "--force"],
    )
    for cmd in steps:
        completed = subprocess.run(cmd, cwd="./data")
        if completed.returncode != 0:
            print(f"Dataset push stopped: {' '.join(cmd)} exited with {completed.returncode}")
            return False
    return True


def update_repository():
    """Publish the current palettes to the dataset repository.

    Steps:
    1. Pull ./data.
    2. Snapshot DB_FILE into ./data/data.db.
    3. Regenerate ./data/data.json.
    4. Amend the previous commit and force-push, replacing that commit
       rather than growing history.

    Called by the ``push`` view and, in production, hourly by the scheduler.
    """
    repo.git_pull()
    _snapshot_live_db()
    _export_palettes_json()
    print("Updating repository")
    _force_push_data_repo()
    # NOTE(review): after the synchronous push above the tree should be clean,
    # which presumably makes this a no-op (push_to_hub skips clean repos by
    # default). It is kept as a fallback — confirm against the pinned
    # huggingface_hub.
    repo.push_to_hub(blocking=False)
def index():
    """Serve ``index.html`` from the app's static folder."""
    page = "index.html"
    return app.send_static_file(page)
def push():
    """Trigger an immediate dataset sync; the caller must present ``TOKEN``.

    The secret is read from the ``token`` request header.

    Returns:
        ``jsonify({"success": True})`` after ``update_repository()`` completes,
        or ``("Error", 401)`` when the header is missing, no ``TOKEN`` is
        configured, or the values differ.
    """
    provided = request.headers.get("token")
    # A missing header is an authorization failure, not a malformed request.
    # With no TOKEN configured the endpoint must stay closed.
    if provided is None or TOKEN is None:
        return "Error", 401
    # Constant-time comparison, so response timing does not reveal how much of
    # the secret matched.
    if not hmac.compare_digest(provided.encode(), TOKEN.encode()):
        return "Error", 401
    update_repository()
    return jsonify({"success": True})
def getAllData():
    """Load every row of the ``palettes`` table as a JSON-serializable dict.

    Returns a list with one ``{"id", "data", "created_at"}`` dict per row. The
    ``data`` column is decoded from its stored JSON text.
    """

    def as_dict(record):
        return {
            "id": record["id"],
            "data": json.loads(record["data"]),
            "created_at": record["created_at"],
        }

    rows = get_db().execute("SELECT * FROM palettes").fetchall()
    return [as_dict(record) for record in rows]
def getdata():
    """Respond with every stored palette as a JSON array."""
    palettes = getAllData()
    return jsonify(palettes)
def create():
    """Store the request payload as a new palette row and return all palettes.

    NOTE(review): reads ``g.data``, presumably populated by
    ``@expects_json(schema)``. No decorator is visible in this view — confirm
    the route is registered with it.
    """
    conn = get_db()
    payload = json.dumps(g.data)
    conn.cursor().execute("INSERT INTO palettes(data) VALUES (?)", [payload])
    conn.commit()
    return jsonify(getAllData())
def bad_request(error):
    """Render JSON-schema validation failures as ``{"error": message}`` with 400.

    Any other error is handed back unchanged.
    """
    details = error.description
    if not isinstance(details, ValidationError):
        return error
    return jsonify({"error": details.message}), 400
if __name__ == "__main__":
    if not IS_DEV:
        print("Starting scheduler -- Running Production")
        scheduler = APScheduler()
        # Publish the palette DB to the dataset repo once an hour.
        scheduler.add_job(
            id="Update Dataset Repository",
            func=update_repository,
            trigger="interval",
            hours=1,
        )
        scheduler.start()
    else:
        print("Not Starting scheduler -- Running Development")
    app.run(
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        # The Werkzeug interactive debugger can execute code sent from the
        # browser. Enable it (and the auto-reloader) only in development,
        # never on the public production server.
        debug=IS_DEV,
        use_reloader=IS_DEV,
    )