diff --git a/WSGI.py b/WSGI.py
index 0208c97..6e080f8 100644
--- a/WSGI.py
+++ b/WSGI.py
@@ -3,13 +3,15 @@
import gevent.monkey
gevent.monkey.patch_all()
-from nyaa import app
+from nyaa import create_app
-if app.config["DEBUG"]:
- from werkzeug.debug import DebuggedApplication
- app.wsgi_app = DebuggedApplication(app.wsgi_app, True)
+app = create_app('config')
+
+if app.config['DEBUG']:
+ from werkzeug.debug import DebuggedApplication
+ app.wsgi_app = DebuggedApplication(app.wsgi_app, True)
if __name__ == '__main__':
- import gevent.pywsgi
- gevent_server = gevent.pywsgi.WSGIServer(("localhost", 5000), app.wsgi_app)
- gevent_server.serve_forever()
+ import gevent.pywsgi
+ gevent_server = gevent.pywsgi.WSGIServer(("localhost", 5000), app.wsgi_app)
+ gevent_server.serve_forever()
diff --git a/db_create.py b/db_create.py
index 2b7f7ee..30fe4fe 100755
--- a/db_create.py
+++ b/db_create.py
@@ -1,9 +1,11 @@
#!/usr/bin/env python3
import sqlalchemy
-from nyaa import app, models
+from nyaa import create_app, models
from nyaa.extensions import db
+app = create_app('config')
+
NYAA_CATEGORIES = [
('Anime', ['Anime Music Video', 'English-translated', 'Non-English-translated', 'Raw']),
('Audio', ['Lossless', 'Lossy']),
diff --git a/db_migrate.py b/db_migrate.py
index 127f377..92789f6 100755
--- a/db_migrate.py
+++ b/db_migrate.py
@@ -5,9 +5,10 @@ import sys
from flask_script import Manager
from flask_migrate import Migrate, MigrateCommand
-from nyaa import app
+from nyaa import create_app
from nyaa.extensions import db
+app = create_app('config')
migrate = Migrate(app, db)
manager = Manager(app)
diff --git a/import_to_es.py b/import_to_es.py
index f74420a..18c6b31 100755
--- a/import_to_es.py
+++ b/import_to_es.py
@@ -14,9 +14,10 @@ from elasticsearch import Elasticsearch
from elasticsearch.client import IndicesClient
from elasticsearch import helpers
-from nyaa import app, models
+from nyaa import create_app, models
from nyaa.extensions import db
+app = create_app('config')
es = Elasticsearch(timeout=30)
ic = IndicesClient(es)
diff --git a/nyaa/__init__.py b/nyaa/__init__.py
index f9f0a03..492c29b 100644
--- a/nyaa/__init__.py
+++ b/nyaa/__init__.py
@@ -4,69 +4,81 @@ import os
import flask
from flask_assets import Bundle # noqa F401
+from nyaa.api_handler import api_blueprint
from nyaa.extensions import assets, db, fix_paginate, toolbar
+from nyaa.template_utils import bp as template_utils_bp
+from nyaa.views import register_views
-app = flask.Flask(__name__)
-app.config.from_object('config')
-# Don't refresh cookie each request
-app.config['SESSION_REFRESH_EACH_REQUEST'] = False
+def create_app(config):
+ """ Nyaa app factory """
+ app = flask.Flask(__name__)
+ app.config.from_object(config)
-# Debugging
-if app.config['DEBUG']:
- app.config['DEBUG_TB_INTERCEPT_REDIRECTS'] = False
- toolbar.init_app(app)
- app.logger.setLevel(logging.DEBUG)
+ # Don't refresh cookie each request
+ app.config['SESSION_REFRESH_EACH_REQUEST'] = False
- # Forbid caching
- @app.after_request
- def forbid_cache(request):
- request.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate, max-age=0'
- request.headers['Pragma'] = 'no-cache'
- request.headers['Expires'] = '0'
- return request
+ # Debugging
+ if app.config['DEBUG']:
+ app.config['DEBUG_TB_INTERCEPT_REDIRECTS'] = False
+ toolbar.init_app(app)
+ app.logger.setLevel(logging.DEBUG)
-else:
- app.logger.setLevel(logging.WARNING)
+ # Forbid caching
+ @app.after_request
+ def forbid_cache(request):
+ request.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate, max-age=0'
+ request.headers['Pragma'] = 'no-cache'
+ request.headers['Expires'] = '0'
+ return request
-# Logging
-if 'LOG_FILE' in app.config:
- from logging.handlers import RotatingFileHandler
- app.log_handler = RotatingFileHandler(
- app.config['LOG_FILE'], maxBytes=10000, backupCount=1)
- app.logger.addHandler(app.log_handler)
+ else:
+ app.logger.setLevel(logging.WARNING)
-# Log errors and display a message to the user in production mdode
-if not app.config['DEBUG']:
- @app.errorhandler(500)
- def internal_error(exception):
- app.logger.error(exception)
- flask.flash(flask.Markup(
- 'An error occurred! Debug information has been logged.'), 'danger')
- return flask.redirect('/')
+ # Logging
+ if 'LOG_FILE' in app.config:
+ from logging.handlers import RotatingFileHandler
+ app.log_handler = RotatingFileHandler(
+ app.config['LOG_FILE'], maxBytes=10000, backupCount=1)
+ app.logger.addHandler(app.log_handler)
-# Get git commit hash
-app.config['COMMIT_HASH'] = None
-master_head = os.path.abspath(os.path.join(os.path.dirname(__file__), '../.git/refs/heads/master'))
-if os.path.isfile(master_head):
- with open(master_head, 'r') as head:
- app.config['COMMIT_HASH'] = head.readline().strip()
+ # Log errors and display a message to the user in production mdode
+ if not app.config['DEBUG']:
+ @app.errorhandler(500)
+ def internal_error(exception):
+ app.logger.error(exception)
+ flask.flash(flask.Markup(
+ 'An error occurred! Debug information has been logged.'), 'danger')
+ return flask.redirect('/')
-# Enable the jinja2 do extension.
-app.jinja_env.add_extension('jinja2.ext.do')
-app.jinja_env.lstrip_blocks = True
-app.jinja_env.trim_blocks = True
+ # Get git commit hash
+ app.config['COMMIT_HASH'] = None
+ master_head = os.path.abspath(os.path.join(os.path.dirname(__file__),
+ '../.git/refs/heads/master'))
+ if os.path.isfile(master_head):
+ with open(master_head, 'r') as head:
+ app.config['COMMIT_HASH'] = head.readline().strip()
-# Database
-fix_paginate() # This has to be before the database is initialized
-app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
-app.config['MYSQL_DATABASE_CHARSET'] = 'utf8mb4'
-db.init_app(app)
+ # Enable the jinja2 do extension.
+ app.jinja_env.add_extension('jinja2.ext.do')
+ app.jinja_env.lstrip_blocks = True
+ app.jinja_env.trim_blocks = True
-assets.init_app(app)
+ # Database
+ fix_paginate() # This has to be before the database is initialized
+ app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
+ app.config['MYSQL_DATABASE_CHARSET'] = 'utf8mb4'
+ db.init_app(app)
-# css = Bundle('style.scss', filters='libsass',
-# output='style.css', depends='**/*.scss')
-# assets.register('style_all', css)
+ # Assets
+ assets.init_app(app)
+ # css = Bundle('style.scss', filters='libsass',
+ # output='style.css', depends='**/*.scss')
+ # assets.register('style_all', css)
-from nyaa import routes # noqa E402 isort:skip
+ # Blueprints
+ app.register_blueprint(template_utils_bp)
+ app.register_blueprint(api_blueprint, url_prefix='/api')
+ register_views(app)
+
+ return app
diff --git a/nyaa/backend.py b/nyaa/backend.py
index fc8f993..0c20ecf 100644
--- a/nyaa/backend.py
+++ b/nyaa/backend.py
@@ -7,9 +7,11 @@ from werkzeug import secure_filename
from orderedset import OrderedSet
-from nyaa import app, models, utils
+from nyaa import models, utils
from nyaa.extensions import db
+app = flask.current_app
+
@utils.cached_function
def get_category_id_map():
diff --git a/nyaa/extensions.py b/nyaa/extensions.py
index f1c509d..7abc26e 100644
--- a/nyaa/extensions.py
+++ b/nyaa/extensions.py
@@ -1,4 +1,7 @@
+import os.path
+
from flask import abort
+from flask.config import Config
from flask_assets import Environment
from flask_debugtoolbar import DebugToolbarExtension
from flask_sqlalchemy import BaseQuery, Pagination, SQLAlchemy
@@ -32,3 +35,16 @@ def fix_paginate():
return Pagination(self, page, per_page, total_query_count, items)
BaseQuery.paginate_faste = paginate_faste
+
+
+def _get_config():
+ # Workaround to get an available config object before the app is initiallized
+ # Only needed/used in top-level and class statements
+ # https://stackoverflow.com/a/18138250/7597273
+ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
+ config = Config(root_path)
+ config.from_object('config')
+ return config
+
+
+config = _get_config()
diff --git a/nyaa/forms.py b/nyaa/forms.py
index 37e4909..528f303 100644
--- a/nyaa/forms.py
+++ b/nyaa/forms.py
@@ -14,9 +14,12 @@ from wtforms.validators import (DataRequired, Email, EqualTo, Length, Optional,
from wtforms.widgets import Select as SelectWidget # For DisabledSelectField
from wtforms.widgets import HTMLString, html_params # For DisabledSelectField
-from nyaa import app, bencode, models, utils
+from nyaa import bencode, models, utils
+from nyaa.extensions import config
from nyaa.models import User
+app = flask.current_app
+
class Unique(object):
@@ -81,7 +84,7 @@ class RegisterForm(FlaskForm):
password_confirm = PasswordField('Password (confirm)')
- if app.config['USE_RECAPTCHA']:
+ if config['USE_RECAPTCHA']:
recaptcha = RecaptchaField()
@@ -195,7 +198,7 @@ class UploadForm(FlaskForm):
'%(max)d at most.')
])
- if app.config['USE_RECAPTCHA']:
+ if config['USE_RECAPTCHA']:
# Captcha only for not logged in users
_recaptcha_validator = RecaptchaValidator()
diff --git a/nyaa/models.py b/nyaa/models.py
index 4b236f6..1f26cbc 100644
--- a/nyaa/models.py
+++ b/nyaa/models.py
@@ -15,11 +15,12 @@ from sqlalchemy.ext import declarative
from sqlalchemy_fulltext import FullText
from sqlalchemy_utils import ChoiceType, EmailType, PasswordType
-from nyaa import app
-from nyaa.extensions import db
+from nyaa.extensions import config, db
from nyaa.torrents import create_magnet
-if app.config['USE_MYSQL']:
+app = flask.current_app
+
+if config['USE_MYSQL']:
from sqlalchemy.dialects import mysql
BinaryType = mysql.BINARY
DescriptionTextType = mysql.TEXT
@@ -686,7 +687,7 @@ class SukebeiTorrent(TorrentBase, db.Model):
# Fulltext models for MySQL
-if app.config['USE_MYSQL']:
+if config['USE_MYSQL']:
class NyaaTorrentNameSearch(FullText, NyaaTorrent):
__fulltext_columns__ = ('display_name',)
__table_args__ = {'extend_existing': True}
@@ -785,7 +786,7 @@ class SukebeiReport(ReportBase, db.Model):
# Choose our defaults for models.Torrent etc
-if app.config['SITE_FLAVOR'] == 'nyaa':
+if config['SITE_FLAVOR'] == 'nyaa':
Torrent = NyaaTorrent
TorrentFilelist = NyaaTorrentFilelist
TorrentInfo = NyaaTorrentInfo
@@ -798,7 +799,7 @@ if app.config['SITE_FLAVOR'] == 'nyaa':
Report = NyaaReport
TorrentNameSearch = NyaaTorrentNameSearch
-elif app.config['SITE_FLAVOR'] == 'sukebei':
+elif config['SITE_FLAVOR'] == 'sukebei':
Torrent = SukebeiTorrent
TorrentFilelist = SukebeiTorrentFilelist
TorrentInfo = SukebeiTorrentInfo
diff --git a/nyaa/routes.py b/nyaa/routes.py
deleted file mode 100644
index a67ac28..0000000
--- a/nyaa/routes.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from nyaa import app, template_utils, views
-from nyaa.api_handler import api_blueprint
-
-# Register all template filters and template globals
-app.register_blueprint(template_utils.bp)
-# Register the API routes
-app.register_blueprint(api_blueprint, url_prefix='/api')
-# Register the site's routes
-views.register(app)
diff --git a/nyaa/search.py b/nyaa/search.py
index 1f20a36..5661973 100644
--- a/nyaa/search.py
+++ b/nyaa/search.py
@@ -10,9 +10,11 @@ from elasticsearch import Elasticsearch
from elasticsearch_dsl import Q, Search
from sqlalchemy_fulltext import FullTextSearch
-from nyaa import app, models
+from nyaa import models
from nyaa.extensions import db
+app = flask.current_app
+
DEFAULT_MAX_SEARCH_RESULT = 1000
DEFAULT_PER_PAGE = 75
SERACH_PAGINATE_DISPLAY_MSG = ('Displaying results {start}-{end} out of {total} results.
\n'
diff --git a/nyaa/template_utils.py b/nyaa/template_utils.py
index 91b7b17..54c4de8 100644
--- a/nyaa/template_utils.py
+++ b/nyaa/template_utils.py
@@ -7,10 +7,10 @@ from urllib.parse import urlencode
import flask
from werkzeug.urls import url_encode
-from nyaa import app
from nyaa.backend import get_category_id_map
from nyaa.torrents import get_default_trackers
+app = flask.current_app
bp = flask.Blueprint('template-utils', __name__)
_static_cache = {} # For static_cachebuster
diff --git a/nyaa/torrents.py b/nyaa/torrents.py
index 17862bc..a088c6c 100644
--- a/nyaa/torrents.py
+++ b/nyaa/torrents.py
@@ -3,9 +3,11 @@ import os
import time
from urllib.parse import urlencode
+from flask import current_app as app
+
from orderedset import OrderedSet
-from nyaa import app, bencode
+from nyaa import bencode
USED_TRACKERS = OrderedSet()
diff --git a/nyaa/views/__init__.py b/nyaa/views/__init__.py
index 03f4945..ae58c99 100644
--- a/nyaa/views/__init__.py
+++ b/nyaa/views/__init__.py
@@ -8,7 +8,7 @@ from nyaa.views import ( # isort:skip
)
-def register(flask_app):
+def register_views(flask_app):
""" Register the blueprints using the flask_app object """
flask_app.register_blueprint(account.bp)
flask_app.register_blueprint(admin.bp)
diff --git a/nyaa/views/account.py b/nyaa/views/account.py
index a1f1835..ce7922c 100644
--- a/nyaa/views/account.py
+++ b/nyaa/views/account.py
@@ -6,10 +6,11 @@ from ipaddress import ip_address
import flask
-from nyaa import app, forms, models
+from nyaa import forms, models
from nyaa.extensions import db
from nyaa.views.users import get_activation_link
+app = flask.current_app
bp = flask.Blueprint('account', __name__)
diff --git a/nyaa/views/main.py b/nyaa/views/main.py
index f687ef5..f9c54e0 100644
--- a/nyaa/views/main.py
+++ b/nyaa/views/main.py
@@ -5,12 +5,13 @@ from datetime import datetime, timedelta
import flask
from flask_paginate import Pagination
-from nyaa import app, models
+from nyaa import models
from nyaa.search import (DEFAULT_MAX_SEARCH_RESULT, DEFAULT_PER_PAGE, SERACH_PAGINATE_DISPLAY_MSG,
_generate_query_string, search_db, search_elastic)
from nyaa.utils import chain_get
from nyaa.views.account import logout
+app = flask.current_app
bp = flask.Blueprint('main', __name__)
diff --git a/nyaa/views/torrents.py b/nyaa/views/torrents.py
index 3e93909..227838a 100644
--- a/nyaa/views/torrents.py
+++ b/nyaa/views/torrents.py
@@ -7,10 +7,11 @@ from werkzeug.datastructures import CombinedMultiDict
from sqlalchemy.orm import joinedload
-from nyaa import app, backend, forms, models, torrents
+from nyaa import backend, forms, models, torrents
from nyaa.extensions import db
from nyaa.utils import cached_function
+app = flask.current_app
bp = flask.Blueprint('torrents', __name__)
diff --git a/nyaa/views/users.py b/nyaa/views/users.py
index abeb30c..22c1d9e 100644
--- a/nyaa/views/users.py
+++ b/nyaa/views/users.py
@@ -5,12 +5,13 @@ from flask_paginate import Pagination
from itsdangerous import BadSignature, URLSafeSerializer
-from nyaa import app, forms, models
+from nyaa import forms, models
from nyaa.extensions import db
from nyaa.search import (DEFAULT_MAX_SEARCH_RESULT, DEFAULT_PER_PAGE, SERACH_PAGINATE_DISPLAY_MSG,
_generate_query_string, search_db, search_elastic)
from nyaa.utils import chain_get
+app = flask.current_app
bp = flask.Blueprint('users', __name__)
diff --git a/run.py b/run.py
index e0d9fa5..1b6472d 100755
--- a/run.py
+++ b/run.py
@@ -1,3 +1,5 @@
#!/usr/bin/env python3
-from nyaa import app
+from nyaa import create_app
+
+app = create_app('config')
app.run(host='0.0.0.0', port=5500, debug=True)
diff --git a/tests/__init__.py b/tests/__init__.py
index 27f5185..4663455 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -3,7 +3,7 @@
import os
import unittest
-from nyaa import app
+from nyaa import create_app
USE_MYSQL = True
@@ -12,6 +12,7 @@ class NyaaTestCase(unittest.TestCase):
@classmethod
def setUpClass(cls):
+ app = create_app('config')
app.config['TESTING'] = True
cls.app_context = app.app_context()
diff --git a/tests/test_nyaa.py b/tests/test_nyaa.py
index 9a81330..e0f87ad 100644
--- a/tests/test_nyaa.py
+++ b/tests/test_nyaa.py
@@ -6,16 +6,18 @@ import nyaa
class NyaaTestCase(unittest.TestCase):
+ nyaa_app = nyaa.create_app('config')
+
def setUp(self):
- self.db, nyaa.app.config['DATABASE'] = tempfile.mkstemp()
- nyaa.app.config['TESTING'] = True
- self.app = nyaa.app.test_client()
- with nyaa.app.app_context():
+ self.db, self.nyaa_app.config['DATABASE'] = tempfile.mkstemp()
+ self.nyaa_app.config['TESTING'] = True
+ self.app = self.nyaa_app.test_client()
+ with self.nyaa_app.app_context():
nyaa.db.create_all()
def tearDown(self):
os.close(self.db)
- os.unlink(nyaa.app.config['DATABASE'])
+ os.unlink(self.nyaa_app.config['DATABASE'])
def test_index_url(self):
rv = self.app.get('/')