Skip to content

Instantly share code, notes, and snippets.

@stevenferrer
Forked from Hunsin/router.go
Created November 19, 2021 14:58
Show Gist options
  • Select an option

  • Save stevenferrer/2a22f555bcf54d923d26190deb42f259 to your computer and use it in GitHub Desktop.

Select an option

Save stevenferrer/2a22f555bcf54d923d26190deb42f259 to your computer and use it in GitHub Desktop.
This package wraps julienschmidt's httprouter, adding support for features such as middleware and sub-/group routing under a shared path prefix. Written in Go (Golang).
package router
import (
	"context"
	"net/http"

	"github.com/julienschmidt/httprouter"
)
// middleware decorates an httprouter.Handle, returning the handle that should
// be registered in its place. Router applies its middlewares once, at the time
// a handle is registered (see Handle), not on every request.
type middleware func(httprouter.Handle) httprouter.Handle
// Router is an http.Handler that wraps httprouter.Router with additional
// features: middlewares and grouping of routes under a shared path prefix.
// Routers derived via Group share the same underlying httprouter.Router, so a
// handle registered on any of them is served by ServeHTTP of every one of them.
type Router struct {
// middlewares wrap each handle registered on this Router; see Handle for the
// order in which they are applied.
middlewares []middleware
// path is the prefix joined in front of every path registered on this Router
// ("" for a Router returned by NewRouter).
path string
// router is the underlying httprouter.Router, shared with every group derived
// from this Router.
router *httprouter.Router
}
// NewRouter creates a root *Router: no path prefix, no middlewares, and a
// freshly initialized *httprouter.Router underneath it.
func NewRouter() *Router {
	root := new(Router)
	root.router = httprouter.New()
	return root
}
// joinPath returns path prefixed with the router's own path.
//
// It panics if the joined path is empty or does not start with '/'. The
// explicit length check matters: for the root Router (path "") and an empty
// path, indexing the joined string would otherwise fail with an opaque
// index-out-of-range panic instead of the descriptive message below.
func (r *Router) joinPath(path string) string {
	full := r.path + path
	if len(full) == 0 || full[0] != '/' {
		panic("path should start with '/' in path '" + path + "'.")
	}
	return full
}
// Group returns a new *Router that shares r's underlying httprouter.Router but
// registers handles under r's path joined with path, wrapped by m in addition
// to r's middlewares. It should be used for handles which have the same path
// prefix or common middlewares.
//
// A single trailing '/' is dropped from path, so "/foo/" and "/foo" produce the
// same group. An empty path (or "/") yields a group with r's own prefix, which
// is handy for attaching middlewares to a subset of routes; this also works on
// the root Router. Otherwise the joined prefix must start with '/', or Group
// panics.
//
// m is placed in front of r's middlewares. Handle wraps in slice order, so r's
// middlewares end up outermost and run before the group's own. The resulting
// slice is freshly allocated, so it never aliases the caller's m or r's slice.
func (r *Router) Group(path string, m ...middleware) *Router {
	if n := len(path); n > 0 && path[n-1] == '/' {
		path = path[:n-1]
	}

	// Unlike joinPath, an empty prefix is legal here: it only occurs for a
	// group of the root Router created with "" or "/".
	prefix := r.path + path
	if len(prefix) > 0 && prefix[0] != '/' {
		panic("path should start with '/' in path '" + path + "'.")
	}

	mws := make([]middleware, 0, len(m)+len(r.middlewares))
	mws = append(mws, m...)
	mws = append(mws, r.middlewares...)

	return &Router{
		middlewares: mws,
		path:        prefix,
		router:      r.router,
	}
}
// Use adds middlewares to r and returns r for chaining.
//
// The new middlewares are placed in front of the existing ones. Handle wraps
// in slice order, so previously added middlewares stay outermost and run
// first. Only handles registered after this call are affected, and groups
// created earlier keep their own middleware slice. A fresh slice is built so
// r never shares a backing array with the caller's m.
func (r *Router) Use(m ...middleware) *Router {
	mws := make([]middleware, 0, len(m)+len(r.middlewares))
	mws = append(mws, m...)
	r.middlewares = append(mws, r.middlewares...)
	return r
}
// Handle registers handle for method on path (relative to r's prefix),
// wrapped by r's middlewares.
//
// Middlewares are applied in slice order: r.middlewares[0] wraps handle first,
// so the last element ends up outermost and runs first on each request.
func (r *Router) Handle(method, path string, handle httprouter.Handle) {
	wrapped := handle
	for i := 0; i < len(r.middlewares); i++ {
		wrapped = r.middlewares[i](wrapped)
	}
	r.router.Handle(method, r.joinPath(path), wrapped)
}
// GET registers handle for GET requests on path.
// It is shorthand for r.Handle(http.MethodGet, path, handle).
func (r *Router) GET(path string, handle httprouter.Handle) {
	r.Handle(http.MethodGet, path, handle)
}
// HEAD registers handle for HEAD requests on path.
// It is shorthand for r.Handle(http.MethodHead, path, handle).
func (r *Router) HEAD(path string, handle httprouter.Handle) {
	r.Handle(http.MethodHead, path, handle)
}
// OPTIONS registers handle for OPTIONS requests on path.
// It is shorthand for r.Handle(http.MethodOptions, path, handle).
func (r *Router) OPTIONS(path string, handle httprouter.Handle) {
	r.Handle(http.MethodOptions, path, handle)
}
// POST registers handle for POST requests on path.
// It is shorthand for r.Handle(http.MethodPost, path, handle).
func (r *Router) POST(path string, handle httprouter.Handle) {
	r.Handle(http.MethodPost, path, handle)
}
// PUT registers handle for PUT requests on path.
// It is shorthand for r.Handle(http.MethodPut, path, handle).
func (r *Router) PUT(path string, handle httprouter.Handle) {
	r.Handle(http.MethodPut, path, handle)
}
// PATCH registers handle for PATCH requests on path.
// It is shorthand for r.Handle(http.MethodPatch, path, handle).
func (r *Router) PATCH(path string, handle httprouter.Handle) {
	r.Handle(http.MethodPatch, path, handle)
}
// DELETE registers handle for DELETE requests on path.
// It is shorthand for r.Handle(http.MethodDelete, path, handle).
func (r *Router) DELETE(path string, handle httprouter.Handle) {
	r.Handle(http.MethodDelete, path, handle)
}
// Handler is an adapter that registers a standard http.Handler for method and
// path, wrapped by r's middlewares like any other handle.
//
// When the matched route captures parameters (":name" or "*name" segments),
// they are attached to the request context under httprouter.ParamsKey. The
// handler can then read them with httprouter.ParamsFromContext, mirroring
// httprouter.Router.Handler. Previously the parameters were silently dropped.
// This requires an httprouter release that exports ParamsKey (v1.3.0+).
func (r *Router) Handler(method, path string, handler http.Handler) {
	handle := func(w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
		// Only derive a new request when there is something to attach.
		if len(ps) > 0 {
			req = req.WithContext(context.WithValue(req.Context(), httprouter.ParamsKey, ps))
		}
		handler.ServeHTTP(w, req)
	}
	r.Handle(method, path, handle)
}
// HandlerFunc registers an ordinary function with the http.HandlerFunc
// signature for method and path by routing it through Handler.
func (r *Router) HandlerFunc(method, path string, handler http.HandlerFunc) {
	var h http.Handler = handler
	r.Handler(method, path, h)
}
// Static serves files from the root directory under path, which must end with
// "/*filepath" (for example "/static/*filepath"); otherwise Static panics.
//
// Each request has the full route prefix stripped before it reaches
// http.FileServer(http.Dir(root)). The prefix is r's path joined with path
// minus "*filepath". With path "/static/*filepath", a request for
// "/static/css/a.css" is served from root/css/a.css. Only GET is registered.
func (r *Router) Static(path, root string) {
	const suffix = "/*filepath"

	cut := len(path) - len(suffix)
	if cut < 0 || path[cut:] != suffix {
		panic("path should end with '/*filepath' in path '" + path + "'.")
	}

	// Keep the '/' that precedes "*filepath", so the stripped prefix ends
	// with a slash (e.g. "/static/").
	prefix := r.joinPath(path[:cut+1])
	files := http.StripPrefix(prefix, http.FileServer(http.Dir(root)))
	r.Handler("GET", path, files)
}
// File registers a GET handle on path that always responds with the file
// called name, delegating to http.ServeFile.
func (r *Router) File(path, name string) {
	serve := func(w http.ResponseWriter, req *http.Request) {
		http.ServeFile(w, req, name)
	}
	r.HandlerFunc(http.MethodGet, path, serve)
}
// ServeHTTP implements http.Handler by dispatching the request to the
// underlying httprouter.Router. That router is shared by r and every group
// derived from it, so routes registered through any of them are reachable here.
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	inner := r.router
	inner.ServeHTTP(w, req)
}
package router
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/julienschmidt/httprouter"
)
// TestHandle checks that a handle registered with Handle is reachable.
func TestHandle(t *testing.T) {
	rt := NewRouter()
	rt.Handle("GET", "/", func(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
		w.WriteHeader(http.StatusTeapot)
	})

	req := httptest.NewRequest("GET", "/", nil)
	rec := httptest.NewRecorder()
	rt.ServeHTTP(rec, req)

	if rec.Code != http.StatusTeapot {
		t.Error("Test Handle failed")
	}
}
// TestHandler checks that an http.Handler registered with Handler is reachable.
func TestHandler(t *testing.T) {
	var teapot http.Handler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusTeapot)
	})

	rt := NewRouter()
	rt.Handler("GET", "/", teapot)

	rec := httptest.NewRecorder()
	rt.ServeHTTP(rec, httptest.NewRequest("GET", "/", nil))

	if rec.Code != http.StatusTeapot {
		t.Error("Test Handler failed")
	}
}
// TestHandlerFunc checks that a plain function registered with HandlerFunc is
// reachable.
func TestHandlerFunc(t *testing.T) {
	rt := NewRouter()
	rt.HandlerFunc("GET", "/", func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusTeapot)
	})

	rec := httptest.NewRecorder()
	rt.ServeHTTP(rec, httptest.NewRequest("GET", "/", nil))

	if rec.Code != http.StatusTeapot {
		t.Error("Test HandlerFunc failed")
	}
}
// TestMethod checks that every per-method shortcut registers its handle under
// the matching HTTP method.
func TestMethod(t *testing.T) {
	rt := NewRouter()
	teapot := func(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
		w.WriteHeader(http.StatusTeapot)
	}

	routes := []struct {
		method   string
		path     string
		register func(string, httprouter.Handle)
	}{
		{"DELETE", "/delete", rt.DELETE},
		{"GET", "/get", rt.GET},
		{"HEAD", "/head", rt.HEAD},
		{"OPTIONS", "/options", rt.OPTIONS},
		{"PATCH", "/patch", rt.PATCH},
		{"POST", "/post", rt.POST},
		{"PUT", "/put", rt.PUT},
	}

	for _, rc := range routes {
		rc.register(rc.path, teapot)
	}

	for _, rc := range routes {
		rec := httptest.NewRecorder()
		rt.ServeHTTP(rec, httptest.NewRequest(rc.method, rc.path, nil))
		if rec.Code != http.StatusTeapot {
			t.Errorf("Path %s not registered", rc.path)
		}
	}
}
// TestGroup checks that routes registered on (nested) groups are served under
// the combined prefix, including an empty path on the group itself.
func TestGroup(t *testing.T) {
	rt := NewRouter()
	foo := rt.Group("/foo")
	bar := rt.Group("/bar")
	baz := foo.Group("/baz")

	teapot := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusTeapot)
	}
	foo.HandlerFunc("GET", "", teapot)
	foo.HandlerFunc("GET", "/group", teapot)
	bar.HandlerFunc("GET", "/group", teapot)
	baz.HandlerFunc("GET", "/group", teapot)

	for _, p := range []string{"/foo", "/foo/group", "/foo/baz/group", "/bar/group"} {
		rec := httptest.NewRecorder()
		rt.ServeHTTP(rec, httptest.NewRequest("GET", p, nil))
		if rec.Code != http.StatusTeapot {
			t.Errorf("Grouped path %s not registered", p)
		}
	}
}
func TestMiddleware(t *testing.T) {
var use, group bool
router := NewRouter().Use(func(next httprouter.Handle) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
use = true
next(w, r, ps)
}
})
foo := router.Group("/foo", func(next httprouter.Handle) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
group = true
next(w, r, ps)
}
})
foo.HandlerFunc("GET", "/bar", func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusTeapot)
})
r := httptest.NewRequest("GET", "/foo/bar", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, r)
if !use {
t.Error("Middleware registered by Use() under \"/\" not touched")
}
if !group {
t.Error("Middleware registered by Group() under \"/foo\" not touched")
}
}
// TestStatic checks that Static serves files from the given root directory.
// The fixtures live in the working directory because that directory is used as
// the served root; they are removed when the test returns.
func TestStatic(t *testing.T) {
	files := []string{"temp_1", "temp_2"}
	strs := []string{"test content", "static contents"}
	for i, name := range files {
		// Deferring inside the loop is intentional: the fixtures must survive
		// until the requests below have been served.
		defer os.Remove(name)
		if err := ioutil.WriteFile(name, []byte(strs[i]), 0644); err != nil {
			t.Fatalf("creating fixture %s: %v", name, err)
		}
	}

	pwd, err := os.Getwd()
	if err != nil {
		t.Fatalf("getting working directory: %v", err)
	}

	router := NewRouter()
	router.Static("/*filepath", pwd)
	for i, name := range files {
		r := httptest.NewRequest("GET", "/"+name, nil)
		w := httptest.NewRecorder()
		router.ServeHTTP(w, r)
		// The recorder buffers the body, so no Close (and no defer in the loop)
		// is needed.
		if got := w.Body.String(); got != strs[i] {
			t.Errorf("Test Static failed: %s: got %q, want %q", name, got, strs[i])
		}
	}
}
// TestFile checks that File serves the contents of the named file. The fixture
// is created in the system temp directory rather than the package directory
// and is removed when the test returns.
func TestFile(t *testing.T) {
	const str = "test_content"

	f, err := ioutil.TempFile("", "router_test_file_*")
	if err != nil {
		t.Fatalf("creating temp file: %v", err)
	}
	name := f.Name()
	defer os.Remove(name)

	_, err = f.WriteString(str)
	// Always close; report the first of the write/close errors.
	if cerr := f.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		t.Fatalf("writing temp file %s: %v", name, err)
	}

	router := NewRouter()
	router.File("/file", name)
	r := httptest.NewRequest("GET", "/file", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, r)

	if got := w.Body.String(); got != str {
		t.Errorf("Test File failed: got %q, want %q", got, str)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment