package client

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"comlink"

	"github.com/recws-org/recws"
	"github.com/spf13/viper"
)

// downloadInProgressExtension is appended to a file's name while it is being
// downloaded; it is removed by a rename once the download completed.
const downloadInProgressExtension = ".tmp"

type (
	// Progress counts the bytes of a download and reports the count on a
	// websocket. It implements io.Writer so it can sit behind an io.TeeReader on
	// the download stream.
	Progress struct {
		ws  *recws.RecConn
		msg comlink.Message
		// mu guards Sent when one goroutine writes while another reports. It may
		// be nil for a Progress that is only used from a single goroutine.
		mu *sync.Mutex

		Filename string `json:"filename"`
		// Total is the expected size in bytes. Download sets it to 0 when the
		// server did not announce a length.
		Total uint64 `json:"total"`
		Sent  uint64 `json:"sent"`
	}

	// Downloader fetches files from the configured server into the local media
	// directory and remembers the running downloads so they can be cancelled.
	Downloader struct {
		ws *recws.RecConn
		// dls maps a download's file name (the message payload) to the function
		// that cancels it. It must only be accessed while holding dlsLock.
		dls     map[string]context.CancelFunc
		dlsLock *sync.Mutex
	}

	// ctxReader wraps a reader so that reads fail with ctx.Err() once ctx is done.
	ctxReader struct {
		ctx context.Context
		r   io.Reader
	}
)

// NewDownloader returns a Downloader that reports progress on ws.
func NewDownloader(ws *recws.RecConn) *Downloader {
	return &Downloader{
		ws:      ws,
		dls:     make(map[string]context.CancelFunc),
		dlsLock: &sync.Mutex{},
	}
}

// Cancel aborts the running download whose file name equals msg.Payload. It
// does nothing if no such download is in progress.
func (d *Downloader) Cancel(msg comlink.Message) {
	d.dlsLock.Lock()
	cancel, ok := d.dls[string(msg.Payload)]
	d.dlsLock.Unlock()

	// Call cancel outside the lock; it does not need the map.
	if ok {
		cancel()
	}
}

// Download fetches the file named by msg.Payload from the server's /uploads/
// path and stores it in the "pp.media" directory.
//
// The body is streamed to disk rather than buffered in memory. It is written to
// a temporary file with the downloadInProgressExtension suffix, which is
// renamed over the final name only after the whole body was received. This way
// an existing file is never replaced by a partial one. On failure the temporary
// file is removed.
//
// While the download runs, progress is reported on the websocket once a second,
// and once more after a successful download. Cancel with the same payload
// aborts the download.
func (d *Downloader) Download(msg comlink.Message) (err error) {
	base := string(msg.Payload)
	// The payload arrives over the network; refuse names that would escape the
	// media directory (absolute paths, "..", empty names).
	if !isLocalPath(base) {
		return fmt.Errorf("download %q: invalid file name", base)
	}
	filename := filepath.Join(viper.GetString("pp.media"), base)
	tmpFilename := filename + downloadInProgressExtension

	// Register the cancel func before any slow work so Cancel can abort every
	// step. However the download ends, drop the registration and release the
	// context's resources.
	ctx, cancel := context.WithCancel(context.Background())
	d.dlsLock.Lock()
	d.dls[base] = cancel
	d.dlsLock.Unlock()
	defer func() {
		d.dlsLock.Lock()
		delete(d.dls, base)
		d.dlsLock.Unlock()
		cancel()
	}()

	// Request the data; binding the request to ctx lets Cancel interrupt a
	// blocked connect or read, not just the next read call.
	u := url.URL{
		Scheme: viper.GetString("server.scheme"),
		Host:   viper.GetString("server.host"),
		Path:   "/uploads/" + base,
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
	if err != nil {
		return err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	// http.Client does not treat error statuses as errors; without this check an
	// error page would be saved as the file.
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("download %q: unexpected status %s", base, resp.Status)
	}

	out, err := os.Create(tmpFilename)
	if err != nil {
		return err
	}
	// From here on, any error leaves a partial file behind; remove it
	// (best-effort, the original error is what the caller needs to see).
	defer func() {
		if err != nil {
			_ = os.Remove(tmpFilename)
		}
	}()

	// ContentLength is -1 when unknown; report 0 instead of a wrapped-around uint64.
	total := resp.ContentLength
	if total < 0 {
		total = 0
	}
	pr := &Progress{
		ws:       d.ws,
		msg:      msg.WithType(comlink.Progress),
		mu:       &sync.Mutex{},
		Filename: base,
		Total:    uint64(total),
	}

	// Report the current progress once a second while the copy runs.
	ticker := time.NewTicker(time.Second)
	done := make(chan struct{})
	stopped := make(chan struct{})
	go func() {
		defer close(stopped)
		for {
			select {
			case <-ticker.C:
				pr.snapshot().ReportProgress()
			case <-done:
				return
			}
		}
	}()

	_, copyErr := io.Copy(out, io.TeeReader(ctxReader{ctx: ctx, r: resp.Body}, pr))

	// Stop the reporter and wait for it to exit, so it can no longer write to the
	// websocket concurrently with the final report below.
	ticker.Stop()
	close(done)
	<-stopped

	// Close explicitly (not deferred) so the data is flushed before Rename and a
	// failed close is treated as a failed download.
	closeErr := out.Close()
	if copyErr != nil {
		return copyErr
	}
	if closeErr != nil {
		return closeErr
	}

	if err = os.Rename(tmpFilename, filename); err != nil {
		return err
	}

	// Report once more so the server / managers know the download is complete.
	pr.snapshot().ReportProgress()
	return nil
}

// isLocalPath reports whether name is a non-empty relative path that, once
// cleaned, stays inside the directory it is joined to.
func isLocalPath(name string) bool {
	if name == "" || filepath.IsAbs(name) {
		return false
	}
	clean := filepath.Clean(name)
	return clean != "." && clean != ".." &&
		!strings.HasPrefix(clean, ".."+string(filepath.Separator))
}

// Write adds len(b) to the number of sent bytes. It never fails and does not
// report by itself; reporting is done by ReportProgress.
func (p *Progress) Write(b []byte) (int, error) {
	if p.mu != nil {
		p.mu.Lock()
		defer p.mu.Unlock()
	}
	p.Sent += uint64(len(b))
	return len(b), nil
}

// snapshot returns a copy of p taken under p.mu, so it is safe to call while
// another goroutine is writing to p.
func (p *Progress) snapshot() Progress {
	if p.mu != nil {
		p.mu.Lock()
		defer p.mu.Unlock()
	}
	return *p
}

// ReportProgress prints the progress and sends it as JSON on the websocket.
// The value receiver copies p, so callers racing with Write should call it on a
// snapshot (p.snapshot().ReportProgress()).
func (p Progress) ReportProgress() {
	j, err := json.Marshal(p)
	if err != nil {
		fmt.Fprintf(os.Stderr, "marshal progress for %q: %v\n", p.Filename, err)
		return
	}
	fmt.Printf("Received %d of %d\n", p.Sent, p.Total)
	// Progress reports are best-effort: a lost report is superseded by the next
	// one, so a websocket write error is deliberately ignored here.
	_ = p.ws.WriteJSON(p.msg.WithPayload(j))
}

// Read reads from the wrapped reader unless the context is already done, in
// which case it returns the context's error.
func (r ctxReader) Read(p []byte) (int, error) {
	if err := r.ctx.Err(); err != nil {
		return 0, err
	}
	return r.r.Read(p)
}