diff --git a/WSGI.py b/WSGI.py index 0208c97..6e080f8 100644 --- a/WSGI.py +++ b/WSGI.py @@ -3,13 +3,15 @@ import gevent.monkey gevent.monkey.patch_all() -from nyaa import app +from nyaa import create_app -if app.config["DEBUG"]: - from werkzeug.debug import DebuggedApplication - app.wsgi_app = DebuggedApplication(app.wsgi_app, True) +app = create_app('config') + +if app.config['DEBUG']: + from werkzeug.debug import DebuggedApplication + app.wsgi_app = DebuggedApplication(app.wsgi_app, True) if __name__ == '__main__': - import gevent.pywsgi - gevent_server = gevent.pywsgi.WSGIServer(("localhost", 5000), app.wsgi_app) - gevent_server.serve_forever() + import gevent.pywsgi + gevent_server = gevent.pywsgi.WSGIServer(("localhost", 5000), app.wsgi_app) + gevent_server.serve_forever() diff --git a/db_create.py b/db_create.py index 2b7f7ee..30fe4fe 100755 --- a/db_create.py +++ b/db_create.py @@ -1,9 +1,11 @@ #!/usr/bin/env python3 import sqlalchemy -from nyaa import app, models +from nyaa import create_app, models from nyaa.extensions import db +app = create_app('config') + NYAA_CATEGORIES = [ ('Anime', ['Anime Music Video', 'English-translated', 'Non-English-translated', 'Raw']), ('Audio', ['Lossless', 'Lossy']), diff --git a/db_migrate.py b/db_migrate.py index 127f377..92789f6 100755 --- a/db_migrate.py +++ b/db_migrate.py @@ -5,9 +5,10 @@ import sys from flask_script import Manager from flask_migrate import Migrate, MigrateCommand -from nyaa import app +from nyaa import create_app from nyaa.extensions import db +app = create_app('config') migrate = Migrate(app, db) manager = Manager(app) diff --git a/import_to_es.py b/import_to_es.py index f74420a..18c6b31 100755 --- a/import_to_es.py +++ b/import_to_es.py @@ -14,9 +14,10 @@ from elasticsearch import Elasticsearch from elasticsearch.client import IndicesClient from elasticsearch import helpers -from nyaa import app, models +from nyaa import create_app, models from nyaa.extensions import db +app = create_app('config') es = Elasticsearch(timeout=30) ic = IndicesClient(es) diff --git a/nyaa/__init__.py b/nyaa/__init__.py index f9f0a03..492c29b 100644 --- a/nyaa/__init__.py +++ b/nyaa/__init__.py @@ -4,69 +4,81 @@ import os import flask from flask_assets import Bundle # noqa F401 +from nyaa.api_handler import api_blueprint from nyaa.extensions import assets, db, fix_paginate, toolbar +from nyaa.template_utils import bp as template_utils_bp +from nyaa.views import register_views -app = flask.Flask(__name__) -app.config.from_object('config') -# Don't refresh cookie each request -app.config['SESSION_REFRESH_EACH_REQUEST'] = False +def create_app(config): + """ Nyaa app factory """ + app = flask.Flask(__name__) + app.config.from_object(config) -# Debugging -if app.config['DEBUG']: - app.config['DEBUG_TB_INTERCEPT_REDIRECTS'] = False - toolbar.init_app(app) - app.logger.setLevel(logging.DEBUG) + # Don't refresh cookie each request + app.config['SESSION_REFRESH_EACH_REQUEST'] = False - # Forbid caching - @app.after_request - def forbid_cache(request): - request.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate, max-age=0' - request.headers['Pragma'] = 'no-cache' - request.headers['Expires'] = '0' - return request + # Debugging + if app.config['DEBUG']: + app.config['DEBUG_TB_INTERCEPT_REDIRECTS'] = False + toolbar.init_app(app) + app.logger.setLevel(logging.DEBUG) -else: - app.logger.setLevel(logging.WARNING) + # Forbid caching + @app.after_request + def forbid_cache(request): + request.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate, max-age=0' + request.headers['Pragma'] = 'no-cache' + request.headers['Expires'] = '0' + return request -# Logging -if 'LOG_FILE' in app.config: - from logging.handlers import RotatingFileHandler - app.log_handler = RotatingFileHandler( - app.config['LOG_FILE'], maxBytes=10000, backupCount=1) - app.logger.addHandler(app.log_handler) + else: + app.logger.setLevel(logging.WARNING) -# Log errors and display a message to the user in production mdode -if not app.config['DEBUG']: - @app.errorhandler(500) - def internal_error(exception): - app.logger.error(exception) - flask.flash(flask.Markup( - 'An error occurred! Debug information has been logged.'), 'danger') - return flask.redirect('/') + # Logging + if 'LOG_FILE' in app.config: + from logging.handlers import RotatingFileHandler + app.log_handler = RotatingFileHandler( + app.config['LOG_FILE'], maxBytes=10000, backupCount=1) + app.logger.addHandler(app.log_handler) -# Get git commit hash -app.config['COMMIT_HASH'] = None -master_head = os.path.abspath(os.path.join(os.path.dirname(__file__), '../.git/refs/heads/master')) -if os.path.isfile(master_head): - with open(master_head, 'r') as head: - app.config['COMMIT_HASH'] = head.readline().strip() + # Log errors and display a message to the user in production mdode + if not app.config['DEBUG']: + @app.errorhandler(500) + def internal_error(exception): + app.logger.error(exception) + flask.flash(flask.Markup( + 'An error occurred! Debug information has been logged.'), 'danger') + return flask.redirect('/') -# Enable the jinja2 do extension. -app.jinja_env.add_extension('jinja2.ext.do') -app.jinja_env.lstrip_blocks = True -app.jinja_env.trim_blocks = True + # Get git commit hash + app.config['COMMIT_HASH'] = None + master_head = os.path.abspath(os.path.join(os.path.dirname(__file__), + '../.git/refs/heads/master')) + if os.path.isfile(master_head): + with open(master_head, 'r') as head: + app.config['COMMIT_HASH'] = head.readline().strip() -# Database -fix_paginate() # This has to be before the database is initialized -app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False -app.config['MYSQL_DATABASE_CHARSET'] = 'utf8mb4' -db.init_app(app) + # Enable the jinja2 do extension. + app.jinja_env.add_extension('jinja2.ext.do') + app.jinja_env.lstrip_blocks = True + app.jinja_env.trim_blocks = True -assets.init_app(app) + # Database + fix_paginate() # This has to be before the database is initialized + app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False + app.config['MYSQL_DATABASE_CHARSET'] = 'utf8mb4' + db.init_app(app) -# css = Bundle('style.scss', filters='libsass', -# output='style.css', depends='**/*.scss') -# assets.register('style_all', css) + # Assets + assets.init_app(app) + # css = Bundle('style.scss', filters='libsass', + # output='style.css', depends='**/*.scss') + # assets.register('style_all', css) -from nyaa import routes # noqa E402 isort:skip + # Blueprints + app.register_blueprint(template_utils_bp) + app.register_blueprint(api_blueprint, url_prefix='/api') + register_views(app) + + return app diff --git a/nyaa/backend.py b/nyaa/backend.py index fc8f993..0c20ecf 100644 --- a/nyaa/backend.py +++ b/nyaa/backend.py @@ -7,9 +7,11 @@ from werkzeug import secure_filename from orderedset import OrderedSet -from nyaa import app, models, utils +from nyaa import models, utils from nyaa.extensions import db +app = flask.current_app + @utils.cached_function def get_category_id_map(): diff --git a/nyaa/extensions.py b/nyaa/extensions.py index f1c509d..7abc26e 100644 --- a/nyaa/extensions.py +++ b/nyaa/extensions.py @@ -1,4 +1,7 @@ +import os.path + from flask import abort +from flask.config import Config from flask_assets import Environment from flask_debugtoolbar import DebugToolbarExtension from flask_sqlalchemy import BaseQuery, Pagination, SQLAlchemy @@ -32,3 +35,16 @@ def fix_paginate(): return Pagination(self, page, per_page, total_query_count, items) BaseQuery.paginate_faste = paginate_faste + + +def _get_config(): + # Workaround to get an available config object before the app is initiallized + # Only needed/used in top-level and class statements + # https://stackoverflow.com/a/18138250/7597273 + root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) + config = Config(root_path) + config.from_object('config') + return config + + +config = _get_config() diff --git a/nyaa/forms.py b/nyaa/forms.py index 37e4909..528f303 100644 --- a/nyaa/forms.py +++ b/nyaa/forms.py @@ -14,9 +14,12 @@ from wtforms.validators import (DataRequired, Email, EqualTo, Length, Optional, from wtforms.widgets import Select as SelectWidget # For DisabledSelectField from wtforms.widgets import HTMLString, html_params # For DisabledSelectField -from nyaa import app, bencode, models, utils +from nyaa import bencode, models, utils +from nyaa.extensions import config from nyaa.models import User +app = flask.current_app + class Unique(object): @@ -81,7 +84,7 @@ class RegisterForm(FlaskForm): password_confirm = PasswordField('Password (confirm)') - if app.config['USE_RECAPTCHA']: + if config['USE_RECAPTCHA']: recaptcha = RecaptchaField() @@ -195,7 +198,7 @@ class UploadForm(FlaskForm): '%(max)d at most.') ]) - if app.config['USE_RECAPTCHA']: + if config['USE_RECAPTCHA']: # Captcha only for not logged in users _recaptcha_validator = RecaptchaValidator() diff --git a/nyaa/models.py b/nyaa/models.py index 4b236f6..1f26cbc 100644 --- a/nyaa/models.py +++ b/nyaa/models.py @@ -15,11 +15,12 @@ from sqlalchemy.ext import declarative from sqlalchemy_fulltext import FullText from sqlalchemy_utils import ChoiceType, EmailType, PasswordType -from nyaa import app -from nyaa.extensions import db +from nyaa.extensions import config, db from nyaa.torrents import create_magnet -if app.config['USE_MYSQL']: +app = flask.current_app + +if config['USE_MYSQL']: from sqlalchemy.dialects import mysql BinaryType = mysql.BINARY DescriptionTextType = mysql.TEXT @@ -686,7 +687,7 @@ class SukebeiTorrent(TorrentBase, db.Model): # Fulltext models for MySQL -if app.config['USE_MYSQL']: +if config['USE_MYSQL']: class NyaaTorrentNameSearch(FullText, NyaaTorrent): __fulltext_columns__ = ('display_name',) __table_args__ = {'extend_existing': True} @@ -785,7 +786,7 @@ class SukebeiReport(ReportBase, db.Model): # Choose our defaults for models.Torrent etc -if app.config['SITE_FLAVOR'] == 'nyaa': +if config['SITE_FLAVOR'] == 'nyaa': Torrent = NyaaTorrent TorrentFilelist = NyaaTorrentFilelist TorrentInfo = NyaaTorrentInfo @@ -798,7 +799,7 @@ if app.config['SITE_FLAVOR'] == 'nyaa': Report = NyaaReport TorrentNameSearch = NyaaTorrentNameSearch -elif app.config['SITE_FLAVOR'] == 'sukebei': +elif config['SITE_FLAVOR'] == 'sukebei': Torrent = SukebeiTorrent TorrentFilelist = SukebeiTorrentFilelist TorrentInfo = SukebeiTorrentInfo diff --git a/nyaa/routes.py b/nyaa/routes.py deleted file mode 100644 index a67ac28..0000000 --- a/nyaa/routes.py +++ /dev/null @@ -1,9 +0,0 @@ -from nyaa import app, template_utils, views -from nyaa.api_handler import api_blueprint - -# Register all template filters and template globals -app.register_blueprint(template_utils.bp) -# Register the API routes -app.register_blueprint(api_blueprint, url_prefix='/api') -# Register the site's routes -views.register(app) diff --git a/nyaa/search.py b/nyaa/search.py index 1f20a36..5661973 100644 --- a/nyaa/search.py +++ b/nyaa/search.py @@ -10,9 +10,11 @@ from elasticsearch import Elasticsearch from elasticsearch_dsl import Q, Search from sqlalchemy_fulltext import FullTextSearch -from nyaa import app, models +from nyaa import models from nyaa.extensions import db +app = flask.current_app + DEFAULT_MAX_SEARCH_RESULT = 1000 DEFAULT_PER_PAGE = 75 SERACH_PAGINATE_DISPLAY_MSG = ('Displaying results {start}-{end} out of {total} results.
\n' diff --git a/nyaa/template_utils.py b/nyaa/template_utils.py index 91b7b17..54c4de8 100644 --- a/nyaa/template_utils.py +++ b/nyaa/template_utils.py @@ -7,10 +7,10 @@ from urllib.parse import urlencode import flask from werkzeug.urls import url_encode -from nyaa import app from nyaa.backend import get_category_id_map from nyaa.torrents import get_default_trackers +app = flask.current_app bp = flask.Blueprint('template-utils', __name__) _static_cache = {} # For static_cachebuster diff --git a/nyaa/torrents.py b/nyaa/torrents.py index 17862bc..a088c6c 100644 --- a/nyaa/torrents.py +++ b/nyaa/torrents.py @@ -3,9 +3,11 @@ import os import time from urllib.parse import urlencode +from flask import current_app as app + from orderedset import OrderedSet -from nyaa import app, bencode +from nyaa import bencode USED_TRACKERS = OrderedSet() diff --git a/nyaa/views/__init__.py b/nyaa/views/__init__.py index 03f4945..ae58c99 100644 --- a/nyaa/views/__init__.py +++ b/nyaa/views/__init__.py @@ -8,7 +8,7 @@ from nyaa.views import ( # isort:skip ) -def register(flask_app): +def register_views(flask_app): """ Register the blueprints using the flask_app object """ flask_app.register_blueprint(account.bp) flask_app.register_blueprint(admin.bp) diff --git a/nyaa/views/account.py b/nyaa/views/account.py index a1f1835..ce7922c 100644 --- a/nyaa/views/account.py +++ b/nyaa/views/account.py @@ -6,10 +6,11 @@ from ipaddress import ip_address import flask -from nyaa import app, forms, models +from nyaa import forms, models from nyaa.extensions import db from nyaa.views.users import get_activation_link +app = flask.current_app bp = flask.Blueprint('account', __name__) diff --git a/nyaa/views/main.py b/nyaa/views/main.py index f687ef5..f9c54e0 100644 --- a/nyaa/views/main.py +++ b/nyaa/views/main.py @@ -5,12 +5,13 @@ from datetime import datetime, timedelta import flask from flask_paginate import Pagination -from nyaa import app, models +from nyaa import models from nyaa.search import (DEFAULT_MAX_SEARCH_RESULT, DEFAULT_PER_PAGE, SERACH_PAGINATE_DISPLAY_MSG, _generate_query_string, search_db, search_elastic) from nyaa.utils import chain_get from nyaa.views.account import logout +app = flask.current_app bp = flask.Blueprint('main', __name__) diff --git a/nyaa/views/torrents.py b/nyaa/views/torrents.py index 3e93909..227838a 100644 --- a/nyaa/views/torrents.py +++ b/nyaa/views/torrents.py @@ -7,10 +7,11 @@ from werkzeug.datastructures import CombinedMultiDict from sqlalchemy.orm import joinedload -from nyaa import app, backend, forms, models, torrents +from nyaa import backend, forms, models, torrents from nyaa.extensions import db from nyaa.utils import cached_function +app = flask.current_app bp = flask.Blueprint('torrents', __name__) diff --git a/nyaa/views/users.py b/nyaa/views/users.py index abeb30c..22c1d9e 100644 --- a/nyaa/views/users.py +++ b/nyaa/views/users.py @@ -5,12 +5,13 @@ from flask_paginate import Pagination from itsdangerous import BadSignature, URLSafeSerializer -from nyaa import app, forms, models +from nyaa import forms, models from nyaa.extensions import db from nyaa.search import (DEFAULT_MAX_SEARCH_RESULT, DEFAULT_PER_PAGE, SERACH_PAGINATE_DISPLAY_MSG, _generate_query_string, search_db, search_elastic) from nyaa.utils import chain_get +app = flask.current_app bp = flask.Blueprint('users', __name__) diff --git a/run.py b/run.py index e0d9fa5..1b6472d 100755 --- a/run.py +++ b/run.py @@ -1,3 +1,5 @@ #!/usr/bin/env python3 -from nyaa import app +from nyaa import create_app + +app = create_app('config') app.run(host='0.0.0.0', port=5500, debug=True) diff --git a/tests/__init__.py b/tests/__init__.py index 27f5185..4663455 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -3,7 +3,7 @@ import os import unittest -from nyaa import app +from nyaa import create_app USE_MYSQL = True @@ -12,6 +12,7 @@ class NyaaTestCase(unittest.TestCase): @classmethod def setUpClass(cls): + app = create_app('config') app.config['TESTING'] = True cls.app_context = app.app_context() diff --git a/tests/test_nyaa.py b/tests/test_nyaa.py index 9a81330..e0f87ad 100644 --- a/tests/test_nyaa.py +++ b/tests/test_nyaa.py @@ -6,16 +6,18 @@ import nyaa class NyaaTestCase(unittest.TestCase): + nyaa_app = nyaa.create_app('config') + def setUp(self): - self.db, nyaa.app.config['DATABASE'] = tempfile.mkstemp() - nyaa.app.config['TESTING'] = True - self.app = nyaa.app.test_client() - with nyaa.app.app_context(): + self.db, self.nyaa_app.config['DATABASE'] = tempfile.mkstemp() + self.nyaa_app.config['TESTING'] = True + self.app = self.nyaa_app.test_client() + with self.nyaa_app.app_context(): nyaa.db.create_all() def tearDown(self): os.close(self.db) - os.unlink(nyaa.app.config['DATABASE']) + os.unlink(self.nyaa_app.config['DATABASE']) def test_index_url(self): rv = self.app.get('/')