// Package main implements a minimal forward proxy listening on :12345.
//
// Plain HTTP requests that use an absolute-form target are re-issued upstream
// and the upstream response is streamed back. CONNECT requests are turned into
// raw TCP tunnels between the client and the requested host:port.
package main

import (
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"
)

// TIMEOUT bounds upstream dials, whole upstream HTTP exchanges, and the proxy
// server's own read/write deadlines.
const TIMEOUT = 10 * time.Second

// hopHeaders are the hop-by-hop headers of RFC 7230 §6.1, plus the
// non-standard Proxy-Connection. They describe one connection only, so a
// proxy must not forward them.
var hopHeaders = []string{
	"Connection",
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}

// proxyClient is shared by all plain-HTTP proxy requests so that upstream
// connections are pooled instead of re-created per request.
var proxyClient = &http.Client{
	Timeout: TIMEOUT,
	// A proxy relays 3xx responses to its client; it must not follow them itself.
	CheckRedirect: func(*http.Request, []*http.Request) error {
		return http.ErrUseLastResponse
	},
}

// removeHopHeaders deletes hop-by-hop headers from h. This includes any extra
// header names listed in h's Connection header.
func removeHopHeaders(h http.Header) {
	for _, field := range h["Connection"] {
		for _, name := range strings.Split(field, ",") {
			if name = strings.TrimSpace(name); name != "" {
				h.Del(name)
			}
		}
	}
	for _, name := range hopHeaders {
		h.Del(name)
	}
}

// handleReq routes CONNECT requests to tunnelReq and every other method to proxyReq.
func handleReq(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodConnect {
		tunnelReq(w, r)
		return
	}
	proxyReq(w, r)
}

// tunnelReq serves a CONNECT request. It hijacks the client connection, dials
// r.URL.Host, answers 200 (or 502 if the dial fails), and then copies bytes in
// both directions until either side closes.
func tunnelReq(w http.ResponseWriter, r *http.Request) {
	fmt.Printf("recv tunnel req: %s\n", r.URL.String())
	hj, ok := w.(http.Hijacker)
	if !ok {
		http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError)
		return
	}
	conn, bufrw, err := hj.Hijack()
	if err != nil {
		fmt.Println("hijack err:", err.Error())
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	defer conn.Close()

	// net/http documents that a hijacked conn may still carry the server's
	// read/write deadlines. Clear them, otherwise every tunnel is cut off once
	// ReadTimeout/WriteTimeout elapses.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		fmt.Println("clear deadline err:", err)
		return
	}

	connDest, err := net.DialTimeout("tcp", r.URL.Host, TIMEOUT)
	if err != nil {
		fmt.Println("dial err:", err.Error())
		// bufrw buffers writes; without Flush the client never sees this status.
		bufrw.WriteString("HTTP/1.1 502 Bad Gateway\r\nContent-Length: 0\r\nConnection: close\r\n\r\n")
		if err := bufrw.Flush(); err != nil {
			fmt.Println("write response err:", err)
		}
		return
	}
	defer connDest.Close()

	bufrw.WriteString("HTTP/1.1 200 Connection Established\r\n\r\n")
	if err := bufrw.Flush(); err != nil {
		fmt.Println("write response err:", err)
		return
	}

	// The client may have pipelined data (e.g. a TLS ClientHello) right after
	// the CONNECT header. That data is already in bufrw's read buffer and would
	// be lost by reading conn directly, so forward it first.
	if n := bufrw.Reader.Buffered(); n > 0 {
		pending, _ := bufrw.Reader.Peek(n) // n bytes are buffered, so Peek cannot fail
		if _, err := connDest.Write(pending); err != nil {
			fmt.Println("forward buffered err:", err)
			return
		}
	}

	var wg sync.WaitGroup
	wg.Add(2)
	// relay copies src to dst, then closes both connections so that the
	// opposite direction's blocked Read returns and its goroutine exits too.
	relay := func(dst io.Writer, src io.Reader, label string) {
		defer wg.Done()
		if _, err := io.Copy(dst, src); err != nil {
			fmt.Println("copy err:", err)
		}
		conn.Close()
		connDest.Close()
		fmt.Println(label)
	}
	go relay(connDest, conn, "src -> dest close")
	go relay(conn, connDest, "dest -> src close")
	wg.Wait()
	fmt.Println("disconnect tunnel")
}

// proxyReq forwards a non-CONNECT proxy request to its absolute target URL and
// streams the upstream status, headers, and body back to w.
//
// Origin-form requests and requests that cannot be rebuilt get 400.
// Upstream failures get 502 instead of an empty 200.
func proxyReq(w http.ResponseWriter, r *http.Request) {
	fmt.Printf("recv proxy req: %s\n", r.URL.String())
	// A forward proxy needs "GET http://host/path"; "GET /path" names no upstream.
	if !r.URL.IsAbs() || r.URL.Host == "" {
		http.Error(w, "proxy request must use an absolute URL", http.StatusBadRequest)
		return
	}

	// Bind the upstream call to the inbound request so a client disconnect cancels it.
	req, err := http.NewRequestWithContext(r.Context(), r.Method, r.URL.String(), r.Body)
	if err != nil {
		fmt.Println("create request err: ", err)
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	// Clone so that stripping hop-by-hop headers leaves the inbound request intact.
	req.Header = r.Header.Clone()
	removeHopHeaders(req.Header)
	// Keep the declared length; otherwise a non-empty body would be sent chunked.
	req.ContentLength = r.ContentLength

	resp, err := proxyClient.Do(req)
	if err != nil {
		fmt.Println("do client err: ", err)
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	removeHopHeaders(resp.Header)
	for k, v := range resp.Header {
		w.Header()[k] = v
	}
	w.WriteHeader(resp.StatusCode)
	// Stream the body rather than buffering it entirely in memory.
	if _, err := io.Copy(w, resp.Body); err != nil {
		fmt.Println("copy body err:", err)
		return
	}
	fmt.Println("end proxy req")
}

// main serves the proxy on :12345. It exits with status 1 if the listener fails.
func main() {
	s := &http.Server{
		Addr:         ":12345",
		Handler:      http.HandlerFunc(handleReq),
		ReadTimeout:  TIMEOUT,
		WriteTimeout: TIMEOUT,
	}
	if err := s.ListenAndServe(); err != nil {
		fmt.Println("ListenAndServe: ", err)
		os.Exit(1)
	}
}