mirror of
https://gitlab.com/SIGBUS/nyaa.git
synced 2025-01-24 19:10:16 +00:00
Refactor into an app factory [2 of 2] (#322)
* Replace all `from nyaa import app` imports with `app = flask.current_app` (or `from flask import current_app as app` where possible) * Add a separate config object for top-level and class statements as `nyaa.extensions.config` Required because those codes don't have app context at the time of evaluation/execution. * Remove `routes.py` file and register all blueprints in `nyaa/__init__.py` * Refactor `nyaa/__init__.py` into an app factory * Update tools * Update tests (temporary, will be replaced)
This commit is contained in:
parent
0181d6cb33
commit
87dd95f1e0
16
WSGI.py
16
WSGI.py
|
@ -3,13 +3,15 @@
|
|||
import gevent.monkey
|
||||
gevent.monkey.patch_all()
|
||||
|
||||
from nyaa import app
|
||||
from nyaa import create_app
|
||||
|
||||
if app.config["DEBUG"]:
|
||||
from werkzeug.debug import DebuggedApplication
|
||||
app.wsgi_app = DebuggedApplication(app.wsgi_app, True)
|
||||
app = create_app('config')
|
||||
|
||||
if app.config['DEBUG']:
|
||||
from werkzeug.debug import DebuggedApplication
|
||||
app.wsgi_app = DebuggedApplication(app.wsgi_app, True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
import gevent.pywsgi
|
||||
gevent_server = gevent.pywsgi.WSGIServer(("localhost", 5000), app.wsgi_app)
|
||||
gevent_server.serve_forever()
|
||||
import gevent.pywsgi
|
||||
gevent_server = gevent.pywsgi.WSGIServer(("localhost", 5000), app.wsgi_app)
|
||||
gevent_server.serve_forever()
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
#!/usr/bin/env python3
|
||||
import sqlalchemy
|
||||
|
||||
from nyaa import app, models
|
||||
from nyaa import create_app, models
|
||||
from nyaa.extensions import db
|
||||
|
||||
app = create_app('config')
|
||||
|
||||
NYAA_CATEGORIES = [
|
||||
('Anime', ['Anime Music Video', 'English-translated', 'Non-English-translated', 'Raw']),
|
||||
('Audio', ['Lossless', 'Lossy']),
|
||||
|
|
|
@ -5,9 +5,10 @@ import sys
|
|||
from flask_script import Manager
|
||||
from flask_migrate import Migrate, MigrateCommand
|
||||
|
||||
from nyaa import app
|
||||
from nyaa import create_app
|
||||
from nyaa.extensions import db
|
||||
|
||||
app = create_app('config')
|
||||
migrate = Migrate(app, db)
|
||||
|
||||
manager = Manager(app)
|
||||
|
|
|
@ -14,9 +14,10 @@ from elasticsearch import Elasticsearch
|
|||
from elasticsearch.client import IndicesClient
|
||||
from elasticsearch import helpers
|
||||
|
||||
from nyaa import app, models
|
||||
from nyaa import create_app, models
|
||||
from nyaa.extensions import db
|
||||
|
||||
app = create_app('config')
|
||||
es = Elasticsearch(timeout=30)
|
||||
ic = IndicesClient(es)
|
||||
|
||||
|
|
116
nyaa/__init__.py
116
nyaa/__init__.py
|
@ -4,69 +4,81 @@ import os
|
|||
import flask
|
||||
from flask_assets import Bundle # noqa F401
|
||||
|
||||
from nyaa.api_handler import api_blueprint
|
||||
from nyaa.extensions import assets, db, fix_paginate, toolbar
|
||||
from nyaa.template_utils import bp as template_utils_bp
|
||||
from nyaa.views import register_views
|
||||
|
||||
app = flask.Flask(__name__)
|
||||
app.config.from_object('config')
|
||||
|
||||
# Don't refresh cookie each request
|
||||
app.config['SESSION_REFRESH_EACH_REQUEST'] = False
|
||||
def create_app(config):
|
||||
""" Nyaa app factory """
|
||||
app = flask.Flask(__name__)
|
||||
app.config.from_object(config)
|
||||
|
||||
# Debugging
|
||||
if app.config['DEBUG']:
|
||||
app.config['DEBUG_TB_INTERCEPT_REDIRECTS'] = False
|
||||
toolbar.init_app(app)
|
||||
app.logger.setLevel(logging.DEBUG)
|
||||
# Don't refresh cookie each request
|
||||
app.config['SESSION_REFRESH_EACH_REQUEST'] = False
|
||||
|
||||
# Forbid caching
|
||||
@app.after_request
|
||||
def forbid_cache(request):
|
||||
request.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate, max-age=0'
|
||||
request.headers['Pragma'] = 'no-cache'
|
||||
request.headers['Expires'] = '0'
|
||||
return request
|
||||
# Debugging
|
||||
if app.config['DEBUG']:
|
||||
app.config['DEBUG_TB_INTERCEPT_REDIRECTS'] = False
|
||||
toolbar.init_app(app)
|
||||
app.logger.setLevel(logging.DEBUG)
|
||||
|
||||
else:
|
||||
app.logger.setLevel(logging.WARNING)
|
||||
# Forbid caching
|
||||
@app.after_request
|
||||
def forbid_cache(request):
|
||||
request.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate, max-age=0'
|
||||
request.headers['Pragma'] = 'no-cache'
|
||||
request.headers['Expires'] = '0'
|
||||
return request
|
||||
|
||||
# Logging
|
||||
if 'LOG_FILE' in app.config:
|
||||
from logging.handlers import RotatingFileHandler
|
||||
app.log_handler = RotatingFileHandler(
|
||||
app.config['LOG_FILE'], maxBytes=10000, backupCount=1)
|
||||
app.logger.addHandler(app.log_handler)
|
||||
else:
|
||||
app.logger.setLevel(logging.WARNING)
|
||||
|
||||
# Log errors and display a message to the user in production mdode
|
||||
if not app.config['DEBUG']:
|
||||
@app.errorhandler(500)
|
||||
def internal_error(exception):
|
||||
app.logger.error(exception)
|
||||
flask.flash(flask.Markup(
|
||||
'<strong>An error occurred!</strong> Debug information has been logged.'), 'danger')
|
||||
return flask.redirect('/')
|
||||
# Logging
|
||||
if 'LOG_FILE' in app.config:
|
||||
from logging.handlers import RotatingFileHandler
|
||||
app.log_handler = RotatingFileHandler(
|
||||
app.config['LOG_FILE'], maxBytes=10000, backupCount=1)
|
||||
app.logger.addHandler(app.log_handler)
|
||||
|
||||
# Get git commit hash
|
||||
app.config['COMMIT_HASH'] = None
|
||||
master_head = os.path.abspath(os.path.join(os.path.dirname(__file__), '../.git/refs/heads/master'))
|
||||
if os.path.isfile(master_head):
|
||||
with open(master_head, 'r') as head:
|
||||
app.config['COMMIT_HASH'] = head.readline().strip()
|
||||
# Log errors and display a message to the user in production mdode
|
||||
if not app.config['DEBUG']:
|
||||
@app.errorhandler(500)
|
||||
def internal_error(exception):
|
||||
app.logger.error(exception)
|
||||
flask.flash(flask.Markup(
|
||||
'<strong>An error occurred!</strong> Debug information has been logged.'), 'danger')
|
||||
return flask.redirect('/')
|
||||
|
||||
# Enable the jinja2 do extension.
|
||||
app.jinja_env.add_extension('jinja2.ext.do')
|
||||
app.jinja_env.lstrip_blocks = True
|
||||
app.jinja_env.trim_blocks = True
|
||||
# Get git commit hash
|
||||
app.config['COMMIT_HASH'] = None
|
||||
master_head = os.path.abspath(os.path.join(os.path.dirname(__file__),
|
||||
'../.git/refs/heads/master'))
|
||||
if os.path.isfile(master_head):
|
||||
with open(master_head, 'r') as head:
|
||||
app.config['COMMIT_HASH'] = head.readline().strip()
|
||||
|
||||
# Database
|
||||
fix_paginate() # This has to be before the database is initialized
|
||||
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
|
||||
app.config['MYSQL_DATABASE_CHARSET'] = 'utf8mb4'
|
||||
db.init_app(app)
|
||||
# Enable the jinja2 do extension.
|
||||
app.jinja_env.add_extension('jinja2.ext.do')
|
||||
app.jinja_env.lstrip_blocks = True
|
||||
app.jinja_env.trim_blocks = True
|
||||
|
||||
assets.init_app(app)
|
||||
# Database
|
||||
fix_paginate() # This has to be before the database is initialized
|
||||
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
|
||||
app.config['MYSQL_DATABASE_CHARSET'] = 'utf8mb4'
|
||||
db.init_app(app)
|
||||
|
||||
# css = Bundle('style.scss', filters='libsass',
|
||||
# output='style.css', depends='**/*.scss')
|
||||
# assets.register('style_all', css)
|
||||
# Assets
|
||||
assets.init_app(app)
|
||||
# css = Bundle('style.scss', filters='libsass',
|
||||
# output='style.css', depends='**/*.scss')
|
||||
# assets.register('style_all', css)
|
||||
|
||||
from nyaa import routes # noqa E402 isort:skip
|
||||
# Blueprints
|
||||
app.register_blueprint(template_utils_bp)
|
||||
app.register_blueprint(api_blueprint, url_prefix='/api')
|
||||
register_views(app)
|
||||
|
||||
return app
|
||||
|
|
|
@ -7,9 +7,11 @@ from werkzeug import secure_filename
|
|||
|
||||
from orderedset import OrderedSet
|
||||
|
||||
from nyaa import app, models, utils
|
||||
from nyaa import models, utils
|
||||
from nyaa.extensions import db
|
||||
|
||||
app = flask.current_app
|
||||
|
||||
|
||||
@utils.cached_function
|
||||
def get_category_id_map():
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
import os.path
|
||||
|
||||
from flask import abort
|
||||
from flask.config import Config
|
||||
from flask_assets import Environment
|
||||
from flask_debugtoolbar import DebugToolbarExtension
|
||||
from flask_sqlalchemy import BaseQuery, Pagination, SQLAlchemy
|
||||
|
@ -32,3 +35,16 @@ def fix_paginate():
|
|||
return Pagination(self, page, per_page, total_query_count, items)
|
||||
|
||||
BaseQuery.paginate_faste = paginate_faste
|
||||
|
||||
|
||||
def _get_config():
|
||||
# Workaround to get an available config object before the app is initiallized
|
||||
# Only needed/used in top-level and class statements
|
||||
# https://stackoverflow.com/a/18138250/7597273
|
||||
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||
config = Config(root_path)
|
||||
config.from_object('config')
|
||||
return config
|
||||
|
||||
|
||||
config = _get_config()
|
||||
|
|
|
@ -14,9 +14,12 @@ from wtforms.validators import (DataRequired, Email, EqualTo, Length, Optional,
|
|||
from wtforms.widgets import Select as SelectWidget # For DisabledSelectField
|
||||
from wtforms.widgets import HTMLString, html_params # For DisabledSelectField
|
||||
|
||||
from nyaa import app, bencode, models, utils
|
||||
from nyaa import bencode, models, utils
|
||||
from nyaa.extensions import config
|
||||
from nyaa.models import User
|
||||
|
||||
app = flask.current_app
|
||||
|
||||
|
||||
class Unique(object):
|
||||
|
||||
|
@ -81,7 +84,7 @@ class RegisterForm(FlaskForm):
|
|||
|
||||
password_confirm = PasswordField('Password (confirm)')
|
||||
|
||||
if app.config['USE_RECAPTCHA']:
|
||||
if config['USE_RECAPTCHA']:
|
||||
recaptcha = RecaptchaField()
|
||||
|
||||
|
||||
|
@ -195,7 +198,7 @@ class UploadForm(FlaskForm):
|
|||
'%(max)d at most.')
|
||||
])
|
||||
|
||||
if app.config['USE_RECAPTCHA']:
|
||||
if config['USE_RECAPTCHA']:
|
||||
# Captcha only for not logged in users
|
||||
_recaptcha_validator = RecaptchaValidator()
|
||||
|
||||
|
|
|
@ -15,11 +15,12 @@ from sqlalchemy.ext import declarative
|
|||
from sqlalchemy_fulltext import FullText
|
||||
from sqlalchemy_utils import ChoiceType, EmailType, PasswordType
|
||||
|
||||
from nyaa import app
|
||||
from nyaa.extensions import db
|
||||
from nyaa.extensions import config, db
|
||||
from nyaa.torrents import create_magnet
|
||||
|
||||
if app.config['USE_MYSQL']:
|
||||
app = flask.current_app
|
||||
|
||||
if config['USE_MYSQL']:
|
||||
from sqlalchemy.dialects import mysql
|
||||
BinaryType = mysql.BINARY
|
||||
DescriptionTextType = mysql.TEXT
|
||||
|
@ -686,7 +687,7 @@ class SukebeiTorrent(TorrentBase, db.Model):
|
|||
|
||||
|
||||
# Fulltext models for MySQL
|
||||
if app.config['USE_MYSQL']:
|
||||
if config['USE_MYSQL']:
|
||||
class NyaaTorrentNameSearch(FullText, NyaaTorrent):
|
||||
__fulltext_columns__ = ('display_name',)
|
||||
__table_args__ = {'extend_existing': True}
|
||||
|
@ -785,7 +786,7 @@ class SukebeiReport(ReportBase, db.Model):
|
|||
|
||||
|
||||
# Choose our defaults for models.Torrent etc
|
||||
if app.config['SITE_FLAVOR'] == 'nyaa':
|
||||
if config['SITE_FLAVOR'] == 'nyaa':
|
||||
Torrent = NyaaTorrent
|
||||
TorrentFilelist = NyaaTorrentFilelist
|
||||
TorrentInfo = NyaaTorrentInfo
|
||||
|
@ -798,7 +799,7 @@ if app.config['SITE_FLAVOR'] == 'nyaa':
|
|||
Report = NyaaReport
|
||||
TorrentNameSearch = NyaaTorrentNameSearch
|
||||
|
||||
elif app.config['SITE_FLAVOR'] == 'sukebei':
|
||||
elif config['SITE_FLAVOR'] == 'sukebei':
|
||||
Torrent = SukebeiTorrent
|
||||
TorrentFilelist = SukebeiTorrentFilelist
|
||||
TorrentInfo = SukebeiTorrentInfo
|
||||
|
|
|
@ -1,9 +0,0 @@
|
|||
from nyaa import app, template_utils, views
|
||||
from nyaa.api_handler import api_blueprint
|
||||
|
||||
# Register all template filters and template globals
|
||||
app.register_blueprint(template_utils.bp)
|
||||
# Register the API routes
|
||||
app.register_blueprint(api_blueprint, url_prefix='/api')
|
||||
# Register the site's routes
|
||||
views.register(app)
|
|
@ -10,9 +10,11 @@ from elasticsearch import Elasticsearch
|
|||
from elasticsearch_dsl import Q, Search
|
||||
from sqlalchemy_fulltext import FullTextSearch
|
||||
|
||||
from nyaa import app, models
|
||||
from nyaa import models
|
||||
from nyaa.extensions import db
|
||||
|
||||
app = flask.current_app
|
||||
|
||||
DEFAULT_MAX_SEARCH_RESULT = 1000
|
||||
DEFAULT_PER_PAGE = 75
|
||||
SERACH_PAGINATE_DISPLAY_MSG = ('Displaying results {start}-{end} out of {total} results.<br>\n'
|
||||
|
|
|
@ -7,10 +7,10 @@ from urllib.parse import urlencode
|
|||
import flask
|
||||
from werkzeug.urls import url_encode
|
||||
|
||||
from nyaa import app
|
||||
from nyaa.backend import get_category_id_map
|
||||
from nyaa.torrents import get_default_trackers
|
||||
|
||||
app = flask.current_app
|
||||
bp = flask.Blueprint('template-utils', __name__)
|
||||
_static_cache = {} # For static_cachebuster
|
||||
|
||||
|
|
|
@ -3,9 +3,11 @@ import os
|
|||
import time
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from flask import current_app as app
|
||||
|
||||
from orderedset import OrderedSet
|
||||
|
||||
from nyaa import app, bencode
|
||||
from nyaa import bencode
|
||||
|
||||
USED_TRACKERS = OrderedSet()
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from nyaa.views import ( # isort:skip
|
|||
)
|
||||
|
||||
|
||||
def register(flask_app):
|
||||
def register_views(flask_app):
|
||||
""" Register the blueprints using the flask_app object """
|
||||
flask_app.register_blueprint(account.bp)
|
||||
flask_app.register_blueprint(admin.bp)
|
||||
|
|
|
@ -6,10 +6,11 @@ from ipaddress import ip_address
|
|||
|
||||
import flask
|
||||
|
||||
from nyaa import app, forms, models
|
||||
from nyaa import forms, models
|
||||
from nyaa.extensions import db
|
||||
from nyaa.views.users import get_activation_link
|
||||
|
||||
app = flask.current_app
|
||||
bp = flask.Blueprint('account', __name__)
|
||||
|
||||
|
||||
|
|
|
@ -5,12 +5,13 @@ from datetime import datetime, timedelta
|
|||
import flask
|
||||
from flask_paginate import Pagination
|
||||
|
||||
from nyaa import app, models
|
||||
from nyaa import models
|
||||
from nyaa.search import (DEFAULT_MAX_SEARCH_RESULT, DEFAULT_PER_PAGE, SERACH_PAGINATE_DISPLAY_MSG,
|
||||
_generate_query_string, search_db, search_elastic)
|
||||
from nyaa.utils import chain_get
|
||||
from nyaa.views.account import logout
|
||||
|
||||
app = flask.current_app
|
||||
bp = flask.Blueprint('main', __name__)
|
||||
|
||||
|
||||
|
|
|
@ -7,10 +7,11 @@ from werkzeug.datastructures import CombinedMultiDict
|
|||
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from nyaa import app, backend, forms, models, torrents
|
||||
from nyaa import backend, forms, models, torrents
|
||||
from nyaa.extensions import db
|
||||
from nyaa.utils import cached_function
|
||||
|
||||
app = flask.current_app
|
||||
bp = flask.Blueprint('torrents', __name__)
|
||||
|
||||
|
||||
|
|
|
@ -5,12 +5,13 @@ from flask_paginate import Pagination
|
|||
|
||||
from itsdangerous import BadSignature, URLSafeSerializer
|
||||
|
||||
from nyaa import app, forms, models
|
||||
from nyaa import forms, models
|
||||
from nyaa.extensions import db
|
||||
from nyaa.search import (DEFAULT_MAX_SEARCH_RESULT, DEFAULT_PER_PAGE, SERACH_PAGINATE_DISPLAY_MSG,
|
||||
_generate_query_string, search_db, search_elastic)
|
||||
from nyaa.utils import chain_get
|
||||
|
||||
app = flask.current_app
|
||||
bp = flask.Blueprint('users', __name__)
|
||||
|
||||
|
||||
|
|
4
run.py
4
run.py
|
@ -1,3 +1,5 @@
|
|||
#!/usr/bin/env python3
|
||||
from nyaa import app
|
||||
from nyaa import create_app
|
||||
|
||||
app = create_app('config')
|
||||
app.run(host='0.0.0.0', port=5500, debug=True)
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
import os
|
||||
import unittest
|
||||
|
||||
from nyaa import app
|
||||
from nyaa import create_app
|
||||
|
||||
USE_MYSQL = True
|
||||
|
||||
|
@ -12,6 +12,7 @@ class NyaaTestCase(unittest.TestCase):
|
|||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
app = create_app('config')
|
||||
app.config['TESTING'] = True
|
||||
cls.app_context = app.app_context()
|
||||
|
||||
|
|
|
@ -6,16 +6,18 @@ import nyaa
|
|||
|
||||
class NyaaTestCase(unittest.TestCase):
|
||||
|
||||
nyaa_app = nyaa.create_app('config')
|
||||
|
||||
def setUp(self):
|
||||
self.db, nyaa.app.config['DATABASE'] = tempfile.mkstemp()
|
||||
nyaa.app.config['TESTING'] = True
|
||||
self.app = nyaa.app.test_client()
|
||||
with nyaa.app.app_context():
|
||||
self.db, self.nyaa_app.config['DATABASE'] = tempfile.mkstemp()
|
||||
self.nyaa_app.config['TESTING'] = True
|
||||
self.app = self.nyaa_app.test_client()
|
||||
with self.nyaa_app.app_context():
|
||||
nyaa.db.create_all()
|
||||
|
||||
def tearDown(self):
|
||||
os.close(self.db)
|
||||
os.unlink(nyaa.app.config['DATABASE'])
|
||||
os.unlink(self.nyaa_app.config['DATABASE'])
|
||||
|
||||
def test_index_url(self):
|
||||
rv = self.app.get('/')
|
||||
|
|
Loading…
Reference in a new issue