Refactor into an app factory [2 of 2] (#322)

* Replace all `from nyaa import app` imports with `app = flask.current_app` (or `from flask import current_app as app` where possible)
* Add a separate config object for top-level and class statements as `nyaa.extensions.config`
Required because those codes don't have app context at the time of evaluation/execution.
* Remove `routes.py` file and register all blueprints in `nyaa/__init__.py`
* Refactor `nyaa/__init__.py` into an app factory
* Update tools
* Update tests (temporary, will be replaced)
This commit is contained in:
Kfir Hadas 2017-08-01 21:02:08 +03:00 committed by GitHub
parent 0181d6cb33
commit 87dd95f1e0
21 changed files with 140 additions and 96 deletions

16
WSGI.py
View File

@ -3,13 +3,15 @@
import gevent.monkey import gevent.monkey
gevent.monkey.patch_all() gevent.monkey.patch_all()
from nyaa import app from nyaa import create_app
if app.config["DEBUG"]: app = create_app('config')
from werkzeug.debug import DebuggedApplication
app.wsgi_app = DebuggedApplication(app.wsgi_app, True) if app.config['DEBUG']:
from werkzeug.debug import DebuggedApplication
app.wsgi_app = DebuggedApplication(app.wsgi_app, True)
if __name__ == '__main__': if __name__ == '__main__':
import gevent.pywsgi import gevent.pywsgi
gevent_server = gevent.pywsgi.WSGIServer(("localhost", 5000), app.wsgi_app) gevent_server = gevent.pywsgi.WSGIServer(("localhost", 5000), app.wsgi_app)
gevent_server.serve_forever() gevent_server.serve_forever()

View File

@ -1,9 +1,11 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import sqlalchemy import sqlalchemy
from nyaa import app, models from nyaa import create_app, models
from nyaa.extensions import db from nyaa.extensions import db
app = create_app('config')
NYAA_CATEGORIES = [ NYAA_CATEGORIES = [
('Anime', ['Anime Music Video', 'English-translated', 'Non-English-translated', 'Raw']), ('Anime', ['Anime Music Video', 'English-translated', 'Non-English-translated', 'Raw']),
('Audio', ['Lossless', 'Lossy']), ('Audio', ['Lossless', 'Lossy']),

View File

@ -5,9 +5,10 @@ import sys
from flask_script import Manager from flask_script import Manager
from flask_migrate import Migrate, MigrateCommand from flask_migrate import Migrate, MigrateCommand
from nyaa import app from nyaa import create_app
from nyaa.extensions import db from nyaa.extensions import db
app = create_app('config')
migrate = Migrate(app, db) migrate = Migrate(app, db)
manager = Manager(app) manager = Manager(app)

View File

@ -14,9 +14,10 @@ from elasticsearch import Elasticsearch
from elasticsearch.client import IndicesClient from elasticsearch.client import IndicesClient
from elasticsearch import helpers from elasticsearch import helpers
from nyaa import app, models from nyaa import create_app, models
from nyaa.extensions import db from nyaa.extensions import db
app = create_app('config')
es = Elasticsearch(timeout=30) es = Elasticsearch(timeout=30)
ic = IndicesClient(es) ic = IndicesClient(es)

View File

@ -4,69 +4,81 @@ import os
import flask import flask
from flask_assets import Bundle # noqa F401 from flask_assets import Bundle # noqa F401
from nyaa.api_handler import api_blueprint
from nyaa.extensions import assets, db, fix_paginate, toolbar from nyaa.extensions import assets, db, fix_paginate, toolbar
from nyaa.template_utils import bp as template_utils_bp
from nyaa.views import register_views
app = flask.Flask(__name__)
app.config.from_object('config')
# Don't refresh cookie each request def create_app(config):
app.config['SESSION_REFRESH_EACH_REQUEST'] = False """ Nyaa app factory """
app = flask.Flask(__name__)
app.config.from_object(config)
# Debugging # Don't refresh cookie each request
if app.config['DEBUG']: app.config['SESSION_REFRESH_EACH_REQUEST'] = False
app.config['DEBUG_TB_INTERCEPT_REDIRECTS'] = False
toolbar.init_app(app)
app.logger.setLevel(logging.DEBUG)
# Forbid caching # Debugging
@app.after_request if app.config['DEBUG']:
def forbid_cache(request): app.config['DEBUG_TB_INTERCEPT_REDIRECTS'] = False
request.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate, max-age=0' toolbar.init_app(app)
request.headers['Pragma'] = 'no-cache' app.logger.setLevel(logging.DEBUG)
request.headers['Expires'] = '0'
return request
else: # Forbid caching
app.logger.setLevel(logging.WARNING) @app.after_request
def forbid_cache(request):
request.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate, max-age=0'
request.headers['Pragma'] = 'no-cache'
request.headers['Expires'] = '0'
return request
# Logging else:
if 'LOG_FILE' in app.config: app.logger.setLevel(logging.WARNING)
from logging.handlers import RotatingFileHandler
app.log_handler = RotatingFileHandler(
app.config['LOG_FILE'], maxBytes=10000, backupCount=1)
app.logger.addHandler(app.log_handler)
# Log errors and display a message to the user in production mdode # Logging
if not app.config['DEBUG']: if 'LOG_FILE' in app.config:
@app.errorhandler(500) from logging.handlers import RotatingFileHandler
def internal_error(exception): app.log_handler = RotatingFileHandler(
app.logger.error(exception) app.config['LOG_FILE'], maxBytes=10000, backupCount=1)
flask.flash(flask.Markup( app.logger.addHandler(app.log_handler)
'<strong>An error occurred!</strong> Debug information has been logged.'), 'danger')
return flask.redirect('/')
# Get git commit hash # Log errors and display a message to the user in production mdode
app.config['COMMIT_HASH'] = None if not app.config['DEBUG']:
master_head = os.path.abspath(os.path.join(os.path.dirname(__file__), '../.git/refs/heads/master')) @app.errorhandler(500)
if os.path.isfile(master_head): def internal_error(exception):
with open(master_head, 'r') as head: app.logger.error(exception)
app.config['COMMIT_HASH'] = head.readline().strip() flask.flash(flask.Markup(
'<strong>An error occurred!</strong> Debug information has been logged.'), 'danger')
return flask.redirect('/')
# Enable the jinja2 do extension. # Get git commit hash
app.jinja_env.add_extension('jinja2.ext.do') app.config['COMMIT_HASH'] = None
app.jinja_env.lstrip_blocks = True master_head = os.path.abspath(os.path.join(os.path.dirname(__file__),
app.jinja_env.trim_blocks = True '../.git/refs/heads/master'))
if os.path.isfile(master_head):
with open(master_head, 'r') as head:
app.config['COMMIT_HASH'] = head.readline().strip()
# Database # Enable the jinja2 do extension.
fix_paginate() # This has to be before the database is initialized app.jinja_env.add_extension('jinja2.ext.do')
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False app.jinja_env.lstrip_blocks = True
app.config['MYSQL_DATABASE_CHARSET'] = 'utf8mb4' app.jinja_env.trim_blocks = True
db.init_app(app)
assets.init_app(app) # Database
fix_paginate() # This has to be before the database is initialized
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
app.config['MYSQL_DATABASE_CHARSET'] = 'utf8mb4'
db.init_app(app)
# css = Bundle('style.scss', filters='libsass', # Assets
# output='style.css', depends='**/*.scss') assets.init_app(app)
# assets.register('style_all', css) # css = Bundle('style.scss', filters='libsass',
# output='style.css', depends='**/*.scss')
# assets.register('style_all', css)
from nyaa import routes # noqa E402 isort:skip # Blueprints
app.register_blueprint(template_utils_bp)
app.register_blueprint(api_blueprint, url_prefix='/api')
register_views(app)
return app

View File

@ -7,9 +7,11 @@ from werkzeug import secure_filename
from orderedset import OrderedSet from orderedset import OrderedSet
from nyaa import app, models, utils from nyaa import models, utils
from nyaa.extensions import db from nyaa.extensions import db
app = flask.current_app
@utils.cached_function @utils.cached_function
def get_category_id_map(): def get_category_id_map():

View File

@ -1,4 +1,7 @@
import os.path
from flask import abort from flask import abort
from flask.config import Config
from flask_assets import Environment from flask_assets import Environment
from flask_debugtoolbar import DebugToolbarExtension from flask_debugtoolbar import DebugToolbarExtension
from flask_sqlalchemy import BaseQuery, Pagination, SQLAlchemy from flask_sqlalchemy import BaseQuery, Pagination, SQLAlchemy
@ -32,3 +35,16 @@ def fix_paginate():
return Pagination(self, page, per_page, total_query_count, items) return Pagination(self, page, per_page, total_query_count, items)
BaseQuery.paginate_faste = paginate_faste BaseQuery.paginate_faste = paginate_faste
def _get_config():
# Workaround to get an available config object before the app is initiallized
# Only needed/used in top-level and class statements
# https://stackoverflow.com/a/18138250/7597273
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
config = Config(root_path)
config.from_object('config')
return config
config = _get_config()

View File

@ -14,9 +14,12 @@ from wtforms.validators import (DataRequired, Email, EqualTo, Length, Optional,
from wtforms.widgets import Select as SelectWidget # For DisabledSelectField from wtforms.widgets import Select as SelectWidget # For DisabledSelectField
from wtforms.widgets import HTMLString, html_params # For DisabledSelectField from wtforms.widgets import HTMLString, html_params # For DisabledSelectField
from nyaa import app, bencode, models, utils from nyaa import bencode, models, utils
from nyaa.extensions import config
from nyaa.models import User from nyaa.models import User
app = flask.current_app
class Unique(object): class Unique(object):
@ -81,7 +84,7 @@ class RegisterForm(FlaskForm):
password_confirm = PasswordField('Password (confirm)') password_confirm = PasswordField('Password (confirm)')
if app.config['USE_RECAPTCHA']: if config['USE_RECAPTCHA']:
recaptcha = RecaptchaField() recaptcha = RecaptchaField()
@ -195,7 +198,7 @@ class UploadForm(FlaskForm):
'%(max)d at most.') '%(max)d at most.')
]) ])
if app.config['USE_RECAPTCHA']: if config['USE_RECAPTCHA']:
# Captcha only for not logged in users # Captcha only for not logged in users
_recaptcha_validator = RecaptchaValidator() _recaptcha_validator = RecaptchaValidator()

View File

@ -15,11 +15,12 @@ from sqlalchemy.ext import declarative
from sqlalchemy_fulltext import FullText from sqlalchemy_fulltext import FullText
from sqlalchemy_utils import ChoiceType, EmailType, PasswordType from sqlalchemy_utils import ChoiceType, EmailType, PasswordType
from nyaa import app from nyaa.extensions import config, db
from nyaa.extensions import db
from nyaa.torrents import create_magnet from nyaa.torrents import create_magnet
if app.config['USE_MYSQL']: app = flask.current_app
if config['USE_MYSQL']:
from sqlalchemy.dialects import mysql from sqlalchemy.dialects import mysql
BinaryType = mysql.BINARY BinaryType = mysql.BINARY
DescriptionTextType = mysql.TEXT DescriptionTextType = mysql.TEXT
@ -686,7 +687,7 @@ class SukebeiTorrent(TorrentBase, db.Model):
# Fulltext models for MySQL # Fulltext models for MySQL
if app.config['USE_MYSQL']: if config['USE_MYSQL']:
class NyaaTorrentNameSearch(FullText, NyaaTorrent): class NyaaTorrentNameSearch(FullText, NyaaTorrent):
__fulltext_columns__ = ('display_name',) __fulltext_columns__ = ('display_name',)
__table_args__ = {'extend_existing': True} __table_args__ = {'extend_existing': True}
@ -785,7 +786,7 @@ class SukebeiReport(ReportBase, db.Model):
# Choose our defaults for models.Torrent etc # Choose our defaults for models.Torrent etc
if app.config['SITE_FLAVOR'] == 'nyaa': if config['SITE_FLAVOR'] == 'nyaa':
Torrent = NyaaTorrent Torrent = NyaaTorrent
TorrentFilelist = NyaaTorrentFilelist TorrentFilelist = NyaaTorrentFilelist
TorrentInfo = NyaaTorrentInfo TorrentInfo = NyaaTorrentInfo
@ -798,7 +799,7 @@ if app.config['SITE_FLAVOR'] == 'nyaa':
Report = NyaaReport Report = NyaaReport
TorrentNameSearch = NyaaTorrentNameSearch TorrentNameSearch = NyaaTorrentNameSearch
elif app.config['SITE_FLAVOR'] == 'sukebei': elif config['SITE_FLAVOR'] == 'sukebei':
Torrent = SukebeiTorrent Torrent = SukebeiTorrent
TorrentFilelist = SukebeiTorrentFilelist TorrentFilelist = SukebeiTorrentFilelist
TorrentInfo = SukebeiTorrentInfo TorrentInfo = SukebeiTorrentInfo

View File

@ -1,9 +0,0 @@
from nyaa import app, template_utils, views
from nyaa.api_handler import api_blueprint
# Register all template filters and template globals
app.register_blueprint(template_utils.bp)
# Register the API routes
app.register_blueprint(api_blueprint, url_prefix='/api')
# Register the site's routes
views.register(app)

View File

@ -10,9 +10,11 @@ from elasticsearch import Elasticsearch
from elasticsearch_dsl import Q, Search from elasticsearch_dsl import Q, Search
from sqlalchemy_fulltext import FullTextSearch from sqlalchemy_fulltext import FullTextSearch
from nyaa import app, models from nyaa import models
from nyaa.extensions import db from nyaa.extensions import db
app = flask.current_app
DEFAULT_MAX_SEARCH_RESULT = 1000 DEFAULT_MAX_SEARCH_RESULT = 1000
DEFAULT_PER_PAGE = 75 DEFAULT_PER_PAGE = 75
SERACH_PAGINATE_DISPLAY_MSG = ('Displaying results {start}-{end} out of {total} results.<br>\n' SERACH_PAGINATE_DISPLAY_MSG = ('Displaying results {start}-{end} out of {total} results.<br>\n'

View File

@ -7,10 +7,10 @@ from urllib.parse import urlencode
import flask import flask
from werkzeug.urls import url_encode from werkzeug.urls import url_encode
from nyaa import app
from nyaa.backend import get_category_id_map from nyaa.backend import get_category_id_map
from nyaa.torrents import get_default_trackers from nyaa.torrents import get_default_trackers
app = flask.current_app
bp = flask.Blueprint('template-utils', __name__) bp = flask.Blueprint('template-utils', __name__)
_static_cache = {} # For static_cachebuster _static_cache = {} # For static_cachebuster

View File

@ -3,9 +3,11 @@ import os
import time import time
from urllib.parse import urlencode from urllib.parse import urlencode
from flask import current_app as app
from orderedset import OrderedSet from orderedset import OrderedSet
from nyaa import app, bencode from nyaa import bencode
USED_TRACKERS = OrderedSet() USED_TRACKERS = OrderedSet()

View File

@ -8,7 +8,7 @@ from nyaa.views import ( # isort:skip
) )
def register(flask_app): def register_views(flask_app):
""" Register the blueprints using the flask_app object """ """ Register the blueprints using the flask_app object """
flask_app.register_blueprint(account.bp) flask_app.register_blueprint(account.bp)
flask_app.register_blueprint(admin.bp) flask_app.register_blueprint(admin.bp)

View File

@ -6,10 +6,11 @@ from ipaddress import ip_address
import flask import flask
from nyaa import app, forms, models from nyaa import forms, models
from nyaa.extensions import db from nyaa.extensions import db
from nyaa.views.users import get_activation_link from nyaa.views.users import get_activation_link
app = flask.current_app
bp = flask.Blueprint('account', __name__) bp = flask.Blueprint('account', __name__)

View File

@ -5,12 +5,13 @@ from datetime import datetime, timedelta
import flask import flask
from flask_paginate import Pagination from flask_paginate import Pagination
from nyaa import app, models from nyaa import models
from nyaa.search import (DEFAULT_MAX_SEARCH_RESULT, DEFAULT_PER_PAGE, SERACH_PAGINATE_DISPLAY_MSG, from nyaa.search import (DEFAULT_MAX_SEARCH_RESULT, DEFAULT_PER_PAGE, SERACH_PAGINATE_DISPLAY_MSG,
_generate_query_string, search_db, search_elastic) _generate_query_string, search_db, search_elastic)
from nyaa.utils import chain_get from nyaa.utils import chain_get
from nyaa.views.account import logout from nyaa.views.account import logout
app = flask.current_app
bp = flask.Blueprint('main', __name__) bp = flask.Blueprint('main', __name__)

View File

@ -7,10 +7,11 @@ from werkzeug.datastructures import CombinedMultiDict
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from nyaa import app, backend, forms, models, torrents from nyaa import backend, forms, models, torrents
from nyaa.extensions import db from nyaa.extensions import db
from nyaa.utils import cached_function from nyaa.utils import cached_function
app = flask.current_app
bp = flask.Blueprint('torrents', __name__) bp = flask.Blueprint('torrents', __name__)

View File

@ -5,12 +5,13 @@ from flask_paginate import Pagination
from itsdangerous import BadSignature, URLSafeSerializer from itsdangerous import BadSignature, URLSafeSerializer
from nyaa import app, forms, models from nyaa import forms, models
from nyaa.extensions import db from nyaa.extensions import db
from nyaa.search import (DEFAULT_MAX_SEARCH_RESULT, DEFAULT_PER_PAGE, SERACH_PAGINATE_DISPLAY_MSG, from nyaa.search import (DEFAULT_MAX_SEARCH_RESULT, DEFAULT_PER_PAGE, SERACH_PAGINATE_DISPLAY_MSG,
_generate_query_string, search_db, search_elastic) _generate_query_string, search_db, search_elastic)
from nyaa.utils import chain_get from nyaa.utils import chain_get
app = flask.current_app
bp = flask.Blueprint('users', __name__) bp = flask.Blueprint('users', __name__)

4
run.py
View File

@ -1,3 +1,5 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from nyaa import app from nyaa import create_app
app = create_app('config')
app.run(host='0.0.0.0', port=5500, debug=True) app.run(host='0.0.0.0', port=5500, debug=True)

View File

@ -3,7 +3,7 @@
import os import os
import unittest import unittest
from nyaa import app from nyaa import create_app
USE_MYSQL = True USE_MYSQL = True
@ -12,6 +12,7 @@ class NyaaTestCase(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
app = create_app('config')
app.config['TESTING'] = True app.config['TESTING'] = True
cls.app_context = app.app_context() cls.app_context = app.app_context()

View File

@ -6,16 +6,18 @@ import nyaa
class NyaaTestCase(unittest.TestCase): class NyaaTestCase(unittest.TestCase):
nyaa_app = nyaa.create_app('config')
def setUp(self): def setUp(self):
self.db, nyaa.app.config['DATABASE'] = tempfile.mkstemp() self.db, self.nyaa_app.config['DATABASE'] = tempfile.mkstemp()
nyaa.app.config['TESTING'] = True self.nyaa_app.config['TESTING'] = True
self.app = nyaa.app.test_client() self.app = self.nyaa_app.test_client()
with nyaa.app.app_context(): with self.nyaa_app.app_context():
nyaa.db.create_all() nyaa.db.create_all()
def tearDown(self): def tearDown(self):
os.close(self.db) os.close(self.db)
os.unlink(nyaa.app.config['DATABASE']) os.unlink(self.nyaa_app.config['DATABASE'])
def test_index_url(self): def test_index_url(self):
rv = self.app.get('/') rv = self.app.get('/')