Commit 52088fba authored by Zach Perkins's avatar Zach Perkins

Added scoped session so we don't have to decorate every view function

parent c21f4dda
from flask import Flask, redirect, jsonify, abort, request, url_for, make_response
from flask import Flask, redirect, jsonify, abort, request, url_for, make_response, g
from webargs.flaskparser import use_args
from webargs import fields
from where.model import with_session, Point, Category, Field
from where.model import Session, Point, Category, Field
from where.model.field_types import FieldType
from where.validation import PointSchema, CategorySchema, FieldSchema
app = Flask(__name__)
# Endpoints:
@app.before_request
def create_local_db_session():
g.db_session = Session()
@app.after_request
def destroy_local_db_session(resp):
try:
g.db_session.commit()
except BaseException:
g.db_session.rollback()
raise
finally:
Session.remove()
return resp
@app.route('/')
......@@ -26,31 +41,30 @@ def index():
@app.route('/test_data')
@with_session
def test_data(session):
session.query(Point).delete()
session.query(Field).delete()
session.query(Category).delete()
def test_data():
g.db_session.query(Point).delete()
g.db_session.query(Field).delete()
g.db_session.query(Category).delete()
# Water Fountain, the class.
wf = Category()
wf.name = "Water Fountain"
wf.icon = "https://karel.pw/water.png"
session.add(wf)
session.commit()
g.db_session.add(wf)
g.db_session.commit()
# Building
bd = Category()
bd.name = "Building"
bd.icon = "https://hips.hearstapps.com/hmg-prod.s3.amazonaws.com/images/basket-building-news-photo-1572015168.jpg?resize=980:*"
session.add(bd)
session.commit()
g.db_session.add(bd)
g.db_session.commit()
# Radius (Really the simplest metric we can have for building size)
rd = Field()
rd.name = "Radius"
rd.slug = "radius"
rd.type = FieldType.FLOAT
rd.category_id = bd.id
session.add(rd)
session.commit()
g.db_session.add(rd)
g.db_session.commit()
# coldness
cd = Field()
......@@ -64,9 +78,9 @@ def test_data(session):
fl.name = "Has Bottle Filler"
fl.type = FieldType.BOOLEAN
fl.category_id = wf.id
session.add(cd)
session.add(fl)
session.commit()
g.db_session.add(cd)
g.db_session.add(fl)
g.db_session.commit()
# The johnson center
jc = Point()
......@@ -80,7 +94,7 @@ def test_data(session):
"value": 2.0
}
}
session.add(jc)
g.db_session.add(jc)
# A water fountain inside the JC
fn = Point()
......@@ -100,153 +114,139 @@ def test_data(session):
}
}
session.add(fn)
session.commit()
g.db_session.add(fn)
g.db_session.commit()
return redirect('/')
@app.route('/category/<int:id>')
@with_session
def get_category(session, id):
return get_resource(session, Category, id)
def get_category(id):
return get_resource(Category, id)
@app.route('/category/<int:id>/children')
@with_session
def get_category_children(data, session, id):
def get_category_children(data, id):
data = dict(request.args)
data['parent_id'] = id
return search_resource(session, Point, data)
return search_resource(Point, data)
@app.route('/point', methods=['GET'])
@use_args({'parent_id': fields.Int(), 'category_id': fields.Int(required=True)})
@with_session
def search_point(session, args):
return search_resource(session, Point, args)
def search_point(args):
return search_resource(Point, args)
@app.route('/point', methods=['POST'])
@use_args(PointSchema)
@with_session
def create_point(session, args):
args['category'] = session.query(Category).get(args.pop('category_id'))
return create_resource(session, Point, args, 'get_point')
def create_point(args):
args['category'] = g.db_session.query(Category).get(args.pop('category_id'))
return create_resource(Point, args, 'get_point')
@app.route('/point/<int:id>', methods=['GET'])
@with_session
def get_point(session, id):
return get_resource(session, Point, id)
def get_point(id):
return get_resource(Point, id)
@app.route('/point/<int:id>', methods=['DELETE'])
@with_session
def del_point(session, id):
return delete_resource(session, Point, id)
def del_point(id):
return delete_resource(Point, id)
@app.route('/point/<int:id>', methods=['PUT'])
@with_session
def edit_point(session, id):
return edit_resource(session, Point, id, request.get_json())
def edit_point(id):
return edit_resource(Point, id, request.get_json())
@app.route('/point/<int:id>/children', methods=['GET'])
@with_session
def get_point_children(session, id):
def get_point_children(id):
data = dict(request.args)
data['parent_id'] = id
return search_resource(session, Point, data)
return search_resource(Point, data)
# Helper functions:
# TODO: Add helper functions for data validation
def create_resource(session, model_cls, data, get_function):
def create_resource(model_cls, data, get_function):
'''
Create the resource specified by the given model class and initialized with the data
dict, returning an appropriate JSON response.
:param session: The sqlalchemy session
:param model_cls: The class of the model for this resource
:param data: The initial data for this resource stored as a dictionary
:param get_function: The name of the view function (as a string) that gets a single instance of this resource. This is used for the response Location header.
:return: a Flask Response object
'''
resource = model_cls(**data)
session.add(resource)
session.commit()
g.db_session.add(resource)
g.db_session.commit()
response = make_response(jsonify(resource.as_json()), 201)
response.headers['Location'] = url_for(get_function, id=resource.id)
return response
def get_resource(session, model_cls, id):
def get_resource(model_cls, id):
'''
Get a single resource of the specified model class by its ID.
:param session: The sqlalchemy session
:param model_cls: The class of the model for this resource
:param id: The id of this resource
:return: a Flask Response object
'''
resource = session.query(model_cls).get(id)
resource = g.db_session.query(model_cls).get(id)
resp = (None, 404) if resource is None else \
(resource.as_json(), 200)
return make_response(jsonify(resp[0]), resp[1])
def edit_resource(session, model_cls, id, data):
def edit_resource(model_cls, id, data):
'''
Modify the resource of the specified model class and id with the data from
data. Does not perform data validation.
:param session: The sqlalchemy session
:param model_cls: The class of the model for this resource
:param id: The id of this resource
:param data: The new data for this resource stored as a dictionary
:return: a Flask Response object
'''
resource = session.query(model_cls).get(id)
resource = g.db_session.query(model_cls).get(id)
for attr in data:
setattr(resource, attr, data[attr])
session.commit()
g.db_session.commit()
return make_response(jsonify(resource.as_json()), 200)
def delete_resource(session, model_cls, id):
def delete_resource(model_cls, id):
'''
Delete the resource of the specified model class and id and return the
appropriate response.
:param session: The sqlalchemy session
:param model_cls: The class of the model for this resource
:param id: The id of this resource
:return: a Flask Response object
'''
resource = session.query(model_cls).get(id)
session.delete(resource)
session.commit()
resource = g.db_session.query(model_cls).get(id)
g.db_session.delete(resource)
g.db_session.commit()
return make_response('', 204)
def search_resource(session, model_cls, data):
def search_resource(model_cls, data):
'''
Search the database for a list of instances of the specified model class
that have the attributes given in data and return the appropriate JSON
response. Does not perform validation on search parameters.
:param session: The sqlalchemy session
:param model_cls: The class of the model for this resource
:param data: A dictionary containing search parameters
:return: a Flask Response object
'''
query = session.query(model_cls).filter_by(**data)
query = g.db_session.query(model_cls).filter_by(**data)
resp = (None, 404) if query.first() is None else \
(list(map(lambda m: m.as_json(), query.limit(100).all())), 200)
......
......@@ -8,36 +8,6 @@ from .field_types import FieldType
from .meta import Session, engine
@contextmanager
def session_context():
session = Session()
try:
yield session
session.commit()
except BaseException:
session.rollback()
raise
finally:
session.close()
def with_session(func):
"""
Decorator for convenience when building endpoints. The first argument to the
decorated function will be a safe-to-use, autocommitting Session instance.
:param func: the view function to wrap
:return: the wrapped function
"""
def wrapper(*args, **kwargs):
with session_context() as session:
return func(session, *args, **kwargs)
# Flask identifies endpoint handlers based on their name
wrapper.__name__ = func.__name__
return wrapper
@as_declarative()
class Base(object):
pass
......
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import sessionmaker, scoped_session
engine = create_engine('sqlite:///db.sqlite3') # TODO configurable
Session = sessionmaker(bind=engine)
# Session is scoped to the current thread (which, in Flask, is the current request)
Session = scoped_session(sessionmaker(bind=engine))
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment