Refactor into an app factory [2 of 2] (#322)

* Replace all `from nyaa import app` imports with `app = flask.current_app` (or `from flask import current_app as app` where possible)
* Add a separate config object for top-level and class statements as `nyaa.extensions.config`
Required because those codes don't have app context at the time of evaluation/execution.
* Remove `routes.py` file and register all blueprints in `nyaa/__init__.py`
* Refactor `nyaa/__init__.py` into an app factory
* Update tools
* Update tests (temporary, will be replaced)
This commit is contained in:
Kfir Hadas 2017-08-01 21:02:08 +03:00 committed by GitHub
parent 0181d6cb33
commit 87dd95f1e0
21 changed files with 140 additions and 96 deletions

16
WSGI.py
View File

@ -3,13 +3,15 @@
import gevent.monkey
gevent.monkey.patch_all()
from nyaa import app
from nyaa import create_app
if app.config["DEBUG"]:
from werkzeug.debug import DebuggedApplication
app.wsgi_app = DebuggedApplication(app.wsgi_app, True)
app = create_app('config')
if app.config['DEBUG']:
from werkzeug.debug import DebuggedApplication
app.wsgi_app = DebuggedApplication(app.wsgi_app, True)
if __name__ == '__main__':
import gevent.pywsgi
gevent_server = gevent.pywsgi.WSGIServer(("localhost", 5000), app.wsgi_app)
gevent_server.serve_forever()
import gevent.pywsgi
gevent_server = gevent.pywsgi.WSGIServer(("localhost", 5000), app.wsgi_app)
gevent_server.serve_forever()

View File

@ -1,9 +1,11 @@
#!/usr/bin/env python3
import sqlalchemy
from nyaa import app, models
from nyaa import create_app, models
from nyaa.extensions import db
app = create_app('config')
NYAA_CATEGORIES = [
('Anime', ['Anime Music Video', 'English-translated', 'Non-English-translated', 'Raw']),
('Audio', ['Lossless', 'Lossy']),

View File

@ -5,9 +5,10 @@ import sys
from flask_script import Manager
from flask_migrate import Migrate, MigrateCommand
from nyaa import app
from nyaa import create_app
from nyaa.extensions import db
app = create_app('config')
migrate = Migrate(app, db)
manager = Manager(app)

View File

@ -14,9 +14,10 @@ from elasticsearch import Elasticsearch
from elasticsearch.client import IndicesClient
from elasticsearch import helpers
from nyaa import app, models
from nyaa import create_app, models
from nyaa.extensions import db
app = create_app('config')
es = Elasticsearch(timeout=30)
ic = IndicesClient(es)

View File

@ -4,69 +4,81 @@ import os
import flask
from flask_assets import Bundle # noqa F401
from nyaa.api_handler import api_blueprint
from nyaa.extensions import assets, db, fix_paginate, toolbar
from nyaa.template_utils import bp as template_utils_bp
from nyaa.views import register_views
app = flask.Flask(__name__)
app.config.from_object('config')
# Don't refresh cookie each request
app.config['SESSION_REFRESH_EACH_REQUEST'] = False
def create_app(config):
""" Nyaa app factory """
app = flask.Flask(__name__)
app.config.from_object(config)
# Debugging
if app.config['DEBUG']:
app.config['DEBUG_TB_INTERCEPT_REDIRECTS'] = False
toolbar.init_app(app)
app.logger.setLevel(logging.DEBUG)
# Don't refresh cookie each request
app.config['SESSION_REFRESH_EACH_REQUEST'] = False
# Forbid caching
@app.after_request
def forbid_cache(request):
request.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate, max-age=0'
request.headers['Pragma'] = 'no-cache'
request.headers['Expires'] = '0'
return request
# Debugging
if app.config['DEBUG']:
app.config['DEBUG_TB_INTERCEPT_REDIRECTS'] = False
toolbar.init_app(app)
app.logger.setLevel(logging.DEBUG)
else:
app.logger.setLevel(logging.WARNING)
# Forbid caching
@app.after_request
def forbid_cache(request):
request.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate, max-age=0'
request.headers['Pragma'] = 'no-cache'
request.headers['Expires'] = '0'
return request
# Logging
if 'LOG_FILE' in app.config:
from logging.handlers import RotatingFileHandler
app.log_handler = RotatingFileHandler(
app.config['LOG_FILE'], maxBytes=10000, backupCount=1)
app.logger.addHandler(app.log_handler)
else:
app.logger.setLevel(logging.WARNING)
# Log errors and display a message to the user in production mdode
if not app.config['DEBUG']:
@app.errorhandler(500)
def internal_error(exception):
app.logger.error(exception)
flask.flash(flask.Markup(
'<strong>An error occurred!</strong> Debug information has been logged.'), 'danger')
return flask.redirect('/')
# Logging
if 'LOG_FILE' in app.config:
from logging.handlers import RotatingFileHandler
app.log_handler = RotatingFileHandler(
app.config['LOG_FILE'], maxBytes=10000, backupCount=1)
app.logger.addHandler(app.log_handler)
# Get git commit hash
app.config['COMMIT_HASH'] = None
master_head = os.path.abspath(os.path.join(os.path.dirname(__file__), '../.git/refs/heads/master'))
if os.path.isfile(master_head):
with open(master_head, 'r') as head:
app.config['COMMIT_HASH'] = head.readline().strip()
# Log errors and display a message to the user in production mdode
if not app.config['DEBUG']:
@app.errorhandler(500)
def internal_error(exception):
app.logger.error(exception)
flask.flash(flask.Markup(
'<strong>An error occurred!</strong> Debug information has been logged.'), 'danger')
return flask.redirect('/')
# Enable the jinja2 do extension.
app.jinja_env.add_extension('jinja2.ext.do')
app.jinja_env.lstrip_blocks = True
app.jinja_env.trim_blocks = True
# Get git commit hash
app.config['COMMIT_HASH'] = None
master_head = os.path.abspath(os.path.join(os.path.dirname(__file__),
'../.git/refs/heads/master'))
if os.path.isfile(master_head):
with open(master_head, 'r') as head:
app.config['COMMIT_HASH'] = head.readline().strip()
# Database
fix_paginate() # This has to be before the database is initialized
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
app.config['MYSQL_DATABASE_CHARSET'] = 'utf8mb4'
db.init_app(app)
# Enable the jinja2 do extension.
app.jinja_env.add_extension('jinja2.ext.do')
app.jinja_env.lstrip_blocks = True
app.jinja_env.trim_blocks = True
assets.init_app(app)
# Database
fix_paginate() # This has to be before the database is initialized
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
app.config['MYSQL_DATABASE_CHARSET'] = 'utf8mb4'
db.init_app(app)
# css = Bundle('style.scss', filters='libsass',
# output='style.css', depends='**/*.scss')
# assets.register('style_all', css)
# Assets
assets.init_app(app)
# css = Bundle('style.scss', filters='libsass',
# output='style.css', depends='**/*.scss')
# assets.register('style_all', css)
from nyaa import routes # noqa E402 isort:skip
# Blueprints
app.register_blueprint(template_utils_bp)
app.register_blueprint(api_blueprint, url_prefix='/api')
register_views(app)
return app

View File

@ -7,9 +7,11 @@ from werkzeug import secure_filename
from orderedset import OrderedSet
from nyaa import app, models, utils
from nyaa import models, utils
from nyaa.extensions import db
app = flask.current_app
@utils.cached_function
def get_category_id_map():

View File

@ -1,4 +1,7 @@
import os.path
from flask import abort
from flask.config import Config
from flask_assets import Environment
from flask_debugtoolbar import DebugToolbarExtension
from flask_sqlalchemy import BaseQuery, Pagination, SQLAlchemy
@ -32,3 +35,16 @@ def fix_paginate():
return Pagination(self, page, per_page, total_query_count, items)
BaseQuery.paginate_faste = paginate_faste
def _get_config():
# Workaround to get an available config object before the app is initiallized
# Only needed/used in top-level and class statements
# https://stackoverflow.com/a/18138250/7597273
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
config = Config(root_path)
config.from_object('config')
return config
config = _get_config()

View File

@ -14,9 +14,12 @@ from wtforms.validators import (DataRequired, Email, EqualTo, Length, Optional,
from wtforms.widgets import Select as SelectWidget # For DisabledSelectField
from wtforms.widgets import HTMLString, html_params # For DisabledSelectField
from nyaa import app, bencode, models, utils
from nyaa import bencode, models, utils
from nyaa.extensions import config
from nyaa.models import User
app = flask.current_app
class Unique(object):
@ -81,7 +84,7 @@ class RegisterForm(FlaskForm):
password_confirm = PasswordField('Password (confirm)')
if app.config['USE_RECAPTCHA']:
if config['USE_RECAPTCHA']:
recaptcha = RecaptchaField()
@ -195,7 +198,7 @@ class UploadForm(FlaskForm):
'%(max)d at most.')
])
if app.config['USE_RECAPTCHA']:
if config['USE_RECAPTCHA']:
# Captcha only for not logged in users
_recaptcha_validator = RecaptchaValidator()

View File

@ -15,11 +15,12 @@ from sqlalchemy.ext import declarative
from sqlalchemy_fulltext import FullText
from sqlalchemy_utils import ChoiceType, EmailType, PasswordType
from nyaa import app
from nyaa.extensions import db
from nyaa.extensions import config, db
from nyaa.torrents import create_magnet
if app.config['USE_MYSQL']:
app = flask.current_app
if config['USE_MYSQL']:
from sqlalchemy.dialects import mysql
BinaryType = mysql.BINARY
DescriptionTextType = mysql.TEXT
@ -686,7 +687,7 @@ class SukebeiTorrent(TorrentBase, db.Model):
# Fulltext models for MySQL
if app.config['USE_MYSQL']:
if config['USE_MYSQL']:
class NyaaTorrentNameSearch(FullText, NyaaTorrent):
__fulltext_columns__ = ('display_name',)
__table_args__ = {'extend_existing': True}
@ -785,7 +786,7 @@ class SukebeiReport(ReportBase, db.Model):
# Choose our defaults for models.Torrent etc
if app.config['SITE_FLAVOR'] == 'nyaa':
if config['SITE_FLAVOR'] == 'nyaa':
Torrent = NyaaTorrent
TorrentFilelist = NyaaTorrentFilelist
TorrentInfo = NyaaTorrentInfo
@ -798,7 +799,7 @@ if app.config['SITE_FLAVOR'] == 'nyaa':
Report = NyaaReport
TorrentNameSearch = NyaaTorrentNameSearch
elif app.config['SITE_FLAVOR'] == 'sukebei':
elif config['SITE_FLAVOR'] == 'sukebei':
Torrent = SukebeiTorrent
TorrentFilelist = SukebeiTorrentFilelist
TorrentInfo = SukebeiTorrentInfo

View File

@ -1,9 +0,0 @@
from nyaa import app, template_utils, views
from nyaa.api_handler import api_blueprint
# Register all template filters and template globals
app.register_blueprint(template_utils.bp)
# Register the API routes
app.register_blueprint(api_blueprint, url_prefix='/api')
# Register the site's routes
views.register(app)

View File

@ -10,9 +10,11 @@ from elasticsearch import Elasticsearch
from elasticsearch_dsl import Q, Search
from sqlalchemy_fulltext import FullTextSearch
from nyaa import app, models
from nyaa import models
from nyaa.extensions import db
app = flask.current_app
DEFAULT_MAX_SEARCH_RESULT = 1000
DEFAULT_PER_PAGE = 75
SERACH_PAGINATE_DISPLAY_MSG = ('Displaying results {start}-{end} out of {total} results.<br>\n'

View File

@ -7,10 +7,10 @@ from urllib.parse import urlencode
import flask
from werkzeug.urls import url_encode
from nyaa import app
from nyaa.backend import get_category_id_map
from nyaa.torrents import get_default_trackers
app = flask.current_app
bp = flask.Blueprint('template-utils', __name__)
_static_cache = {} # For static_cachebuster

View File

@ -3,9 +3,11 @@ import os
import time
from urllib.parse import urlencode
from flask import current_app as app
from orderedset import OrderedSet
from nyaa import app, bencode
from nyaa import bencode
USED_TRACKERS = OrderedSet()

View File

@ -8,7 +8,7 @@ from nyaa.views import ( # isort:skip
)
def register(flask_app):
def register_views(flask_app):
""" Register the blueprints using the flask_app object """
flask_app.register_blueprint(account.bp)
flask_app.register_blueprint(admin.bp)

View File

@ -6,10 +6,11 @@ from ipaddress import ip_address
import flask
from nyaa import app, forms, models
from nyaa import forms, models
from nyaa.extensions import db
from nyaa.views.users import get_activation_link
app = flask.current_app
bp = flask.Blueprint('account', __name__)

View File

@ -5,12 +5,13 @@ from datetime import datetime, timedelta
import flask
from flask_paginate import Pagination
from nyaa import app, models
from nyaa import models
from nyaa.search import (DEFAULT_MAX_SEARCH_RESULT, DEFAULT_PER_PAGE, SERACH_PAGINATE_DISPLAY_MSG,
_generate_query_string, search_db, search_elastic)
from nyaa.utils import chain_get
from nyaa.views.account import logout
app = flask.current_app
bp = flask.Blueprint('main', __name__)

View File

@ -7,10 +7,11 @@ from werkzeug.datastructures import CombinedMultiDict
from sqlalchemy.orm import joinedload
from nyaa import app, backend, forms, models, torrents
from nyaa import backend, forms, models, torrents
from nyaa.extensions import db
from nyaa.utils import cached_function
app = flask.current_app
bp = flask.Blueprint('torrents', __name__)

View File

@ -5,12 +5,13 @@ from flask_paginate import Pagination
from itsdangerous import BadSignature, URLSafeSerializer
from nyaa import app, forms, models
from nyaa import forms, models
from nyaa.extensions import db
from nyaa.search import (DEFAULT_MAX_SEARCH_RESULT, DEFAULT_PER_PAGE, SERACH_PAGINATE_DISPLAY_MSG,
_generate_query_string, search_db, search_elastic)
from nyaa.utils import chain_get
app = flask.current_app
bp = flask.Blueprint('users', __name__)

4
run.py
View File

@ -1,3 +1,5 @@
#!/usr/bin/env python3
from nyaa import app
from nyaa import create_app
app = create_app('config')
app.run(host='0.0.0.0', port=5500, debug=True)

View File

@ -3,7 +3,7 @@
import os
import unittest
from nyaa import app
from nyaa import create_app
USE_MYSQL = True
@ -12,6 +12,7 @@ class NyaaTestCase(unittest.TestCase):
@classmethod
def setUpClass(cls):
app = create_app('config')
app.config['TESTING'] = True
cls.app_context = app.app_context()

View File

@ -6,16 +6,18 @@ import nyaa
class NyaaTestCase(unittest.TestCase):
nyaa_app = nyaa.create_app('config')
def setUp(self):
self.db, nyaa.app.config['DATABASE'] = tempfile.mkstemp()
nyaa.app.config['TESTING'] = True
self.app = nyaa.app.test_client()
with nyaa.app.app_context():
self.db, self.nyaa_app.config['DATABASE'] = tempfile.mkstemp()
self.nyaa_app.config['TESTING'] = True
self.app = self.nyaa_app.test_client()
with self.nyaa_app.app_context():
nyaa.db.create_all()
def tearDown(self):
os.close(self.db)
os.unlink(nyaa.app.config['DATABASE'])
os.unlink(self.nyaa_app.config['DATABASE'])
def test_index_url(self):
rv = self.app.get('/')