Commit 52088fba authored by Zach Perkins's avatar Zach Perkins
Browse files

Added scoped session so we don't have to decorate every view function

parent c21f4dda
from flask import Flask, redirect, jsonify, abort, request, url_for, make_response from flask import Flask, redirect, jsonify, abort, request, url_for, make_response, g
from webargs.flaskparser import use_args from webargs.flaskparser import use_args
from webargs import fields from webargs import fields
from where.model import with_session, Point, Category, Field from where.model import Session, Point, Category, Field
from where.model.field_types import FieldType from where.model.field_types import FieldType
from where.validation import PointSchema, CategorySchema, FieldSchema from where.validation import PointSchema, CategorySchema, FieldSchema
app = Flask(__name__) app = Flask(__name__)
# Endpoints:
@app.before_request
def create_local_db_session():
g.db_session = Session()
@app.after_request
def destroy_local_db_session(resp):
try:
g.db_session.commit()
except BaseException:
g.db_session.rollback()
raise
finally:
Session.remove()
return resp
@app.route('/') @app.route('/')
...@@ -26,31 +41,30 @@ def index(): ...@@ -26,31 +41,30 @@ def index():
@app.route('/test_data') @app.route('/test_data')
@with_session def test_data():
def test_data(session): g.db_session.query(Point).delete()
session.query(Point).delete() g.db_session.query(Field).delete()
session.query(Field).delete() g.db_session.query(Category).delete()
session.query(Category).delete()
# Water Fountain, the class. # Water Fountain, the class.
wf = Category() wf = Category()
wf.name = "Water Fountain" wf.name = "Water Fountain"
wf.icon = "https://karel.pw/water.png" wf.icon = "https://karel.pw/water.png"
session.add(wf) g.db_session.add(wf)
session.commit() g.db_session.commit()
# Building # Building
bd = Category() bd = Category()
bd.name = "Building" bd.name = "Building"
bd.icon = "https://hips.hearstapps.com/hmg-prod.s3.amazonaws.com/images/basket-building-news-photo-1572015168.jpg?resize=980:*" bd.icon = "https://hips.hearstapps.com/hmg-prod.s3.amazonaws.com/images/basket-building-news-photo-1572015168.jpg?resize=980:*"
session.add(bd) g.db_session.add(bd)
session.commit() g.db_session.commit()
# Radius (Really the simplest metric we can have for building size) # Radius (Really the simplest metric we can have for building size)
rd = Field() rd = Field()
rd.name = "Radius" rd.name = "Radius"
rd.slug = "radius" rd.slug = "radius"
rd.type = FieldType.FLOAT rd.type = FieldType.FLOAT
rd.category_id = bd.id rd.category_id = bd.id
session.add(rd) g.db_session.add(rd)
session.commit() g.db_session.commit()
# coldness # coldness
cd = Field() cd = Field()
...@@ -64,9 +78,9 @@ def test_data(session): ...@@ -64,9 +78,9 @@ def test_data(session):
fl.name = "Has Bottle Filler" fl.name = "Has Bottle Filler"
fl.type = FieldType.BOOLEAN fl.type = FieldType.BOOLEAN
fl.category_id = wf.id fl.category_id = wf.id
session.add(cd) g.db_session.add(cd)
session.add(fl) g.db_session.add(fl)
session.commit() g.db_session.commit()
# The johnson center # The johnson center
jc = Point() jc = Point()
...@@ -80,7 +94,7 @@ def test_data(session): ...@@ -80,7 +94,7 @@ def test_data(session):
"value": 2.0 "value": 2.0
} }
} }
session.add(jc) g.db_session.add(jc)
# A water fountain inside the JC # A water fountain inside the JC
fn = Point() fn = Point()
...@@ -100,153 +114,139 @@ def test_data(session): ...@@ -100,153 +114,139 @@ def test_data(session):
} }
} }
session.add(fn) g.db_session.add(fn)
session.commit() g.db_session.commit()
return redirect('/') return redirect('/')
@app.route('/category/<int:id>') @app.route('/category/<int:id>')
@with_session def get_category(id):
def get_category(session, id): return get_resource(Category, id)
return get_resource(session, Category, id)
@app.route('/category/<int:id>/children') @app.route('/category/<int:id>/children')
@with_session def get_category_children(data, id):
def get_category_children(data, session, id):
data = dict(request.args) data = dict(request.args)
data['parent_id'] = id data['parent_id'] = id
return search_resource(session, Point, data) return search_resource(Point, data)
@app.route('/point', methods=['GET']) @app.route('/point', methods=['GET'])
@use_args({'parent_id': fields.Int(), 'category_id': fields.Int(required=True)}) @use_args({'parent_id': fields.Int(), 'category_id': fields.Int(required=True)})
@with_session def search_point(args):
def search_point(session, args): return search_resource(Point, args)
return search_resource(session, Point, args)
@app.route('/point', methods=['POST']) @app.route('/point', methods=['POST'])
@use_args(PointSchema) @use_args(PointSchema)
@with_session def create_point(args):
def create_point(session, args): args['category'] = g.db_session.query(Category).get(args.pop('category_id'))
args['category'] = session.query(Category).get(args.pop('category_id')) return create_resource(Point, args, 'get_point')
return create_resource(session, Point, args, 'get_point')
@app.route('/point/<int:id>', methods=['GET']) @app.route('/point/<int:id>', methods=['GET'])
@with_session def get_point(id):
def get_point(session, id): return get_resource(Point, id)
return get_resource(session, Point, id)
@app.route('/point/<int:id>', methods=['DELETE']) @app.route('/point/<int:id>', methods=['DELETE'])
@with_session def del_point(id):
def del_point(session, id): return delete_resource(Point, id)
return delete_resource(session, Point, id)
@app.route('/point/<int:id>', methods=['PUT']) @app.route('/point/<int:id>', methods=['PUT'])
@with_session def edit_point(id):
def edit_point(session, id): return edit_resource(Point, id, request.get_json())
return edit_resource(session, Point, id, request.get_json())
@app.route('/point/<int:id>/children', methods=['GET']) @app.route('/point/<int:id>/children', methods=['GET'])
@with_session def get_point_children(id):
def get_point_children(session, id):
data = dict(request.args) data = dict(request.args)
data['parent_id'] = id data['parent_id'] = id
return search_resource(session, Point, data) return search_resource(Point, data)
# Helper functions: # Helper functions:
# TODO: Add helper functions for data validation
def create_resource(session, model_cls, data, get_function): def create_resource(model_cls, data, get_function):
''' '''
Create the resource specified by the given model class and initialized with the data Create the resource specified by the given model class and initialized with the data
dict, returning an appropriate JSON response. dict, returning an appropriate JSON response.
:param session: The sqlalchemy session
:param model_cls: The class of the model for this resource :param model_cls: The class of the model for this resource
:param data: The initial data for this resource stored as a dictionary :param data: The initial data for this resource stored as a dictionary
:param get_function: The name of the view function (as a string) that gets a single instance of this resource. This is used for the response Location header. :param get_function: The name of the view function (as a string) that gets a single instance of this resource. This is used for the response Location header.
:return: a Flask Response object :return: a Flask Response object
''' '''
resource = model_cls(**data) resource = model_cls(**data)
session.add(resource) g.db_session.add(resource)
session.commit() g.db_session.commit()
response = make_response(jsonify(resource.as_json()), 201) response = make_response(jsonify(resource.as_json()), 201)
response.headers['Location'] = url_for(get_function, id=resource.id) response.headers['Location'] = url_for(get_function, id=resource.id)
return response return response
def get_resource(session, model_cls, id): def get_resource(model_cls, id):
''' '''
Get a single resource of the specified model class by its ID. Get a single resource of the specified model class by its ID.
:param session: The sqlalchemy session
:param model_cls: The class of the model for this resource :param model_cls: The class of the model for this resource
:param id: The id of this resource :param id: The id of this resource
:return: a Flask Response object :return: a Flask Response object
''' '''
resource = session.query(model_cls).get(id) resource = g.db_session.query(model_cls).get(id)
resp = (None, 404) if resource is None else \ resp = (None, 404) if resource is None else \
(resource.as_json(), 200) (resource.as_json(), 200)
return make_response(jsonify(resp[0]), resp[1]) return make_response(jsonify(resp[0]), resp[1])
def edit_resource(session, model_cls, id, data): def edit_resource(model_cls, id, data):
''' '''
Modify the resource of the specified model class and id with the data from Modify the resource of the specified model class and id with the data from
data. Does not perform data validation. data. Does not perform data validation.
:param session: The sqlalchemy session
:param model_cls: The class of the model for this resource :param model_cls: The class of the model for this resource
:param id: The id of this resource :param id: The id of this resource
:param data: The new data for this resource stored as a dictionary :param data: The new data for this resource stored as a dictionary
:return: a Flask Response object :return: a Flask Response object
''' '''
resource = session.query(model_cls).get(id) resource = g.db_session.query(model_cls).get(id)
for attr in data: for attr in data:
setattr(resource, attr, data[attr]) setattr(resource, attr, data[attr])
session.commit() g.db_session.commit()
return make_response(jsonify(resource.as_json()), 200) return make_response(jsonify(resource.as_json()), 200)
def delete_resource(session, model_cls, id): def delete_resource(model_cls, id):
''' '''
Delete the resource of the specified model class and id and return the Delete the resource of the specified model class and id and return the
appropriate response. appropriate response.
:param session: The sqlalchemy session
:param model_cls: The class of the model for this resource :param model_cls: The class of the model for this resource
:param id: The id of this resource :param id: The id of this resource
:return: a Flask Response object :return: a Flask Response object
''' '''
resource = session.query(model_cls).get(id) resource = g.db_session.query(model_cls).get(id)
session.delete(resource) g.db_session.delete(resource)
session.commit() g.db_session.commit()
return make_response('', 204) return make_response('', 204)
def search_resource(session, model_cls, data): def search_resource(model_cls, data):
''' '''
Search the database for a list of instances of the specified model class Search the database for a list of instances of the specified model class
that have the attributes given in data and return the appropriate JSON that have the attributes given in data and return the appropriate JSON
response. Does not perform validation on search parameters. response. Does not perform validation on search parameters.
:param session: The sqlalchemy session
:param model_cls: The class of the model for this resource :param model_cls: The class of the model for this resource
:param data: A dictionary containing search parameters :param data: A dictionary containing search parameters
:return: a Flask Response object :return: a Flask Response object
''' '''
query = session.query(model_cls).filter_by(**data) query = g.db_session.query(model_cls).filter_by(**data)
resp = (None, 404) if query.first() is None else \ resp = (None, 404) if query.first() is None else \
(list(map(lambda m: m.as_json(), query.limit(100).all())), 200) (list(map(lambda m: m.as_json(), query.limit(100).all())), 200)
......
...@@ -8,36 +8,6 @@ from .field_types import FieldType ...@@ -8,36 +8,6 @@ from .field_types import FieldType
from .meta import Session, engine from .meta import Session, engine
@contextmanager
def session_context():
session = Session()
try:
yield session
session.commit()
except BaseException:
session.rollback()
raise
finally:
session.close()
def with_session(func):
"""
Decorator for convenience when building endpoints. The first argument to the
decorated function will be a safe-to-use, autocommitting Session instance.
:param func: the view function to wrap
:return: the wrapped function
"""
def wrapper(*args, **kwargs):
with session_context() as session:
return func(session, *args, **kwargs)
# Flask identifies endpoint handlers based on their name
wrapper.__name__ = func.__name__
return wrapper
@as_declarative() @as_declarative()
class Base(object): class Base(object):
pass pass
......
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker, scoped_session
engine = create_engine('sqlite:///db.sqlite3') # TODO configurable engine = create_engine('sqlite:///db.sqlite3') # TODO configurable
Session = sessionmaker(bind=engine) # Session is scoped to the current thread (which, in Flask, is the current request)
Session = scoped_session(sessionmaker(bind=engine))
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment