package main import ( "bufio" "errors" "flag" "fmt" "io" "log" "net" "net/http" "net/url" "os" "os/signal" "runtime" "strings" "syscall" "time" "github.com/hashicorp/yamux" "github.com/polvi/sni" ss "github.com/shadowsocks/shadowsocks-go/shadowsocks" "golang.org/x/net/proxy" ) const ( connectTimeout = 30 * time.Second ) var ( dialer = &net.Dialer{ Timeout: connectTimeout, } ciph *ss.Cipher proxyDial = dialer.Dial localPort uint remotePort uint ) // httpProxy is a HTTP/HTTPS connect proxy. type httpProxy struct { host string haveAuth bool username string password string forward proxy.Dialer } func (s *httpProxy) Dial(network, addr string) (net.Conn, error) { // Dial and create the https client connection. c, err := s.forward.Dial("tcp", s.host) if err != nil { return nil, err } // HACK. http.ReadRequest also does this. reqURL, err := url.Parse("http://" + addr) if err != nil { c.Close() return nil, err } reqURL.Scheme = "" req, err := http.NewRequest("CONNECT", reqURL.String(), nil) if err != nil { c.Close() return nil, err } req.Close = false if s.haveAuth { req.SetBasicAuth(s.username, s.password) } err = req.Write(c) if err != nil { c.Close() return nil, err } resp, err := http.ReadResponse(bufio.NewReader(c), req) if err != nil { if resp != nil { resp.Body.Close() } c.Close() return nil, err } resp.Body.Close() if resp.StatusCode != 200 { c.Close() err = fmt.Errorf("connect server using proxy error %s", resp.Status) return nil, err } return c, nil } func newClientProxyFunc(srvAddrs []string) func(net.Conn) { return func(conn net.Conn) { localConn := conn defer localConn.Close() sniHost, localConn, err := sniServerNameFromConn(localConn) if err != nil { log.Printf("sniServerNameFromConn from %s error: %v", conn.RemoteAddr(), err) return } log.Printf("sniServernameFromConn got %s from %s", sniHost, localConn.RemoteAddr()) var remoteConn net.Conn connCh := make(chan net.Conn) closeCh := make(chan struct{}) connectTimer := time.NewTimer(2 * connectTimeout) delayTimer := time.NewTimer(0) LOOP: for _, srvAddr := range srvAddrs { select { case <-delayTimer.C: delayTimer.Reset(time.Second) log.Printf("try serverConnect %s from %s", srvAddr, localConn.RemoteAddr()) go serverConnect(localConn.RemoteAddr().String(), srvAddr, connCh, closeCh) case remoteConn = <-connCh: delayTimer.Stop() close(closeCh) break LOOP } } if remoteConn == nil { select { case <-connectTimer.C: log.Printf("remote conn not got, local conn: %s", localConn.RemoteAddr()) close(closeCh) return case remoteConn = <-connCh: close(closeCh) } } connectTimer.Stop() defer remoteConn.Close() go relay(localConn, remoteConn) relay(remoteConn, localConn) } } func serverConnect(localAddr, remoteAddr string, connCh chan<- net.Conn, closeCh <-chan struct{}) { remoteConn, err := proxyDial("tcp", remoteAddr) if err != nil { log.Printf("dial from %s to %s error: %v", localAddr, remoteAddr, err) return } remoteConn = ss.NewConn(remoteConn, ciph.Copy()) sess, _ := yamux.Client(remoteConn, yamux.DefaultConfig()) if _, err := sess.Ping(); err != nil { log.Printf("sess.Ping from %s to %s error: %v", localAddr, remoteAddr, err) sess.Close() return } stream, err := sess.Open() if err != nil { log.Printf("sess.Open from %s to %s error: %v", localAddr, remoteAddr, err) sess.Close() return } select { case <-closeCh: sess.Close() case connCh <- stream: } } func sniServerNameFromConn(conn net.Conn) (string, net.Conn, error) { host, conn, err := sni.ServerNameFromConn(conn) if err != nil { return "", nil, fmt.Errorf("invalid sni: %v", err) } if !strings.Contains(host, ":") { host = fmt.Sprintf("%s:%d", host, remotePort) } if !strings.Contains(host, ".google.com:") { return "", nil, fmt.Errorf("not allowed host: %s", host) } return host, conn, nil } func serverProxyFunc(localConn net.Conn) { defer localConn.Close() localConn = ss.NewConn(localConn, ciph.Copy()) sess, _ := yamux.Server(localConn, yamux.DefaultConfig()) defer sess.Close() localStream, err := sess.Accept() if err != nil { log.Printf("sess.Accept from %s error: %v", localConn.RemoteAddr(), err) return } remoteAddr, localStream, err := sniServerNameFromConn(localStream) if err != nil { log.Printf("sniServerNameFromConn from %s error: %v", localConn.RemoteAddr(), err) return } log.Printf("sniConnect got %s from %s", remoteAddr, localConn.RemoteAddr()) remoteConn, err := dialer.Dial("tcp", remoteAddr) if err != nil { log.Printf("dial from %s to %s error: %v", localConn.RemoteAddr(), remoteAddr, err) return } defer remoteConn.Close() if c, ok := remoteConn.(*net.TCPConn); ok { c.SetNoDelay(true) } go relay(localStream, remoteConn) relay(remoteConn, localStream) } func sigQuitHandle() { ch := make(chan os.Signal, 5) signal.Notify(ch, syscall.SIGQUIT) for range ch { buf := make([]byte, 1024*1024) n := runtime.Stack(buf, true) buf = buf[:n] log.Printf("SIGQUIT got, stack:\n%s", buf) } } func relay(src, dst net.Conn) { log.Printf("relay %s <-> %s start", src.RemoteAddr(), dst.RemoteAddr()) n, err := io.Copy(dst, src) log.Printf("relay %s <-> %s %d bytes, error: %v", src.RemoteAddr(), dst.RemoteAddr(), n, err) src.Close() dst.Close() } type ssProxyDialer struct { server string ciph *ss.Cipher upstreamDialer proxy.Dialer } func (s *ssProxyDialer) Dial(network, addr string) (net.Conn, error) { rawaddr, err := ss.RawAddr(addr) if err != nil { return nil, err } conn, err := s.upstreamDialer.Dial("tcp", s.server) if err != nil { return nil, err } if c, ok := conn.(*net.TCPConn); ok { c.SetNoDelay(true) } ssConn := ss.NewConn(conn, s.ciph.Copy()) if _, err := ssConn.Write(rawaddr); err != nil { conn.Close() return nil, err } log.Printf("dial %s via shadowsocks %s", addr, s.server) return ssConn, nil } func ssDial(u *url.URL, d proxy.Dialer) (proxy.Dialer, error) { if u.User == nil { return nil, errors.New("no shadowsocks method") } method := u.User.Username() p, _ := u.User.Password() ciph, err := ss.NewCipher(method, p) if err != nil { return nil, err } log.Printf("using proxy ss://%s:%s@%s", method, p, u.Host) return &ssProxyDialer{ server: u.Host, ciph: ciph, upstreamDialer: d, }, nil } func init() { proxy.RegisterDialerType("http", func(u *url.URL, forward proxy.Dialer) (d proxy.Dialer, err error) { s := new(httpProxy) s.host = u.Host s.forward = forward if u.User != nil { s.haveAuth = true s.username = u.User.Username() s.password, _ = u.User.Password() } return s, nil }) proxy.RegisterDialerType("ss", ssDial) } func main() { log.SetFlags(log.Lmicroseconds) go sigQuitHandle() var ( proxyStr string shadowPass string ) flag.StringVar(&proxyStr, "proxy", "", "Set proxy url") flag.UintVar(&localPort, "port", 5228, "Local port") flag.UintVar(&remotePort, "remote-port", 5228, "Remote port") flag.StringVar(&shadowPass, "key", "fcmproxy", "Encrypt key") flag.Parse() if proxyStr != "" { u, err := url.Parse(proxyStr) if err != nil { log.Fatalf("invalid proxy url: %v", err) } d, err := proxy.FromURL(u, dialer) if err != nil { log.Fatalf("parse proxy failed: %v", err) } proxyDial = d.Dial } ciph, _ = ss.NewCipher("rc4-md5", shadowPass) srvAddr := flag.Args() var proxyFunc func(net.Conn) if len(srvAddr) > 0 { proxyFunc = newClientProxyFunc(srvAddr) log.Printf("working on client, servers: %v", srvAddr) } else { proxyFunc = serverProxyFunc log.Printf("working on server") } l, err := net.Listen("tcp", fmt.Sprintf(":%d", localPort)) if err != nil { log.Fatalf("listen: %v", err) } for { conn, err := l.Accept() if err != nil { log.Fatalf("accept: %v", err) } if c, ok := conn.(*net.TCPConn); ok { c.SetNoDelay(true) } go proxyFunc(conn) } }