import Vapor
import SwifQL

/// Rejects requests that do not carry a valid `Authorization: Bearer <token>` header.
///
/// On success the matching `User` is attached to the request, and downstream handlers
/// can read it with `request.authorizedUser()`.
///
/// Use it on your router like this
/// ```swift
/// let protectedRoute = router.grouped(AuthMiddleware())
/// ```
class AuthMiddleware: Middleware {
    func respond(to request: Request, chainingTo next: Responder) throws -> Future<Response> {
        return try request.requireToken().flatMap { try next.respond(to: $0) }
    }
}

extension Request {
    /// Parses the bearer token from the `Authorization` header.
    /// It then resolves the token to a user and attaches that user to this request.
    ///
    /// - Throws: `Abort(.unauthorized)` in any of these cases:
    ///   - the header is missing;
    ///   - the header does not start with the `Bearer` scheme;
    ///   - the token after the scheme is empty.
    /// - Returns: A future of this same request, with the user attached.
    fileprivate func requireToken() throws -> Future<Request> {
        guard
            let header = http.headers.firstValue(name: .authorization),
            let token = Request.bearerToken(from: header)
        else {
            throw Abort(.unauthorized, reason: "Invalid authorization token")
        }
        return try checkToken(token: token).map { try self.putUserIntoHeaders($0) }
    }

    /// Extracts the credential from an `Authorization` header value that uses the `Bearer` scheme.
    ///
    /// The scheme must appear at the very start of the value; it is matched
    /// case-insensitively, as auth-scheme names are in RFC 7235.
    /// The earlier `contains` + `replacingOccurrences` approach accepted the scheme
    /// anywhere in the value. It also stripped every occurrence of "Bearer " from the token.
    /// - Parameter header: The raw `Authorization` header value.
    /// - Returns: The trimmed token, or `nil` if the scheme is not `Bearer` or the token is empty.
    fileprivate static func bearerToken(from header: String) -> String? {
        let scheme = "bearer "
        let trimmed = header.trimmingCharacters(in: .whitespaces)
        guard trimmed.prefix(scheme.count).lowercased() == scheme else { return nil }
        let token = trimmed.dropFirst(scheme.count).trimmingCharacters(in: .whitespaces)
        return token.isEmpty ? nil : token
    }
}

extension Request {
    /// Looks up `token` in the `UserToken` table and loads the `User` whose `id` equals the token row's `userId`.
    ///
    /// - Parameter token: The bearer token sent by the client.
    /// - Returns: A future of the matching `User`.
    ///   The future fails with `Abort(.unauthorized)` when no row matches.
    fileprivate func checkToken(token: String) throws -> Future<User> {
        return SwifQL
            .select(User.table.*)
            .from(UserToken.table)
            .join(.inner, User.table, on: \User.id == \UserToken.userId)
            .where(\UserToken.token == token)
            .execute(on: self, as: .psql)
            .first(decoding: User.self)
            .unwrap(or: Abort(.unauthorized, reason: "Invalid auth credentials"))
    }
}

extension Request {
    /// Name of the internal header that carries the authenticated user
    /// from `AuthMiddleware` to `authorizedUser()`.
    fileprivate var headerAuthDataKey: String {
        return "_authData"
    }

    /// Stores `user` as a JSON string in this request's headers, so `authorizedUser()` can read it later.
    ///
    /// This uses `replaceOrAdd` rather than `add`. A client can send its own `_authData` header.
    /// With `add`, that spoofed value stays first and is the one `authorizedUser()` decodes,
    /// which lets a client impersonate any user. Replacing leaves only the server-set value.
    /// - Throws: Any error from `JSONEncoder`, or `Abort(.internalServerError)` if the data isn't valid UTF-8.
    /// - Returns: This same request, for chaining.
    fileprivate func putUserIntoHeaders(_ user: User) throws -> Request {
        let encodedUserData = try JSONEncoder().encode(user)
        guard let jsonString = String(data: encodedUserData, encoding: .utf8) else {
            throw Abort(.internalServerError, reason: "Session encoding error")
        }
        http.headers.replaceOrAdd(name: headerAuthDataKey, value: jsonString)
        return self
    }
}

extension Request {
    /// Returns the user that `AuthMiddleware` attached to this request.
    ///
    /// Only call this from routes grouped behind `AuthMiddleware`. The user travels in a
    /// request header, so on an unprotected route the value would come straight from the client.
    /// - Throws: `Abort(.unauthorized)` if no user data is present.
    ///   Also throws any `JSONDecoder` error if the stored JSON cannot be decoded as `User`.
    @discardableResult
    func authorizedUser() throws -> User {
        guard
            let json = http.headers[headerAuthDataKey].first,
            let data = json.data(using: .utf8)
        else {
            throw Abort(.unauthorized, reason: "Invalid auth credentials")
        }
        return try JSONDecoder().decode(User.self, from: data)
    }
}