前言

假设,我们现在正在下载一个十几个G的游戏或者电影,然后”啪“,停电了,等到再来电的时候,打开电脑,如果显示的是从0开始下载,你会是什么样的感觉。

电脑或者桌子表示害怕。

所以现在,断点续传,几乎是每个下载功能中都会有的存在,就是为了拯救你的电脑不被即将暴走的你砸烂。

原理

断点续传的原理其实很简单,它通过在header里两个参数实现的,客户端发请求时对应的是Range,服务器响应时对应的是Content-Range

Range

用于请求头中,指定第一个字节的位置和最后一个字节的位置,一般格式:

Range:(unit=first byte pos)-[last byte pos]

Range 头部的格式有以下几种情况:

Range: bytes=0-499 表示第 0-499 字节范围的内容
Range: bytes=500-999 表示第 500-999 字节范围的内容
Range: bytes=-500 表示最后 500 字节的内容
Range: bytes=500- 表示从第 500 字节开始到文件结束部分的内容
Range: bytes=0-0,-1 表示第一个和最后一个字节
Range: bytes=500-600,601-999 同时指定几个范围

Content-Range

Content-Range: bytes (unit first byte pos) - [last byte pos]/[entity legth]

例如:

Content-Range: bytes 0-499/22400

-499 是指当前发送的数据的范围,而 22400 则是文件的总大小。

而在响应完成后,返回的响应头内容也不同:

HTTP/1.1 200 Ok(不使用断点续传方式) 
HTTP/1.1 206 Partial Content(使用断点续传方式)

断点续传流程

HTTP1.1协议(RFC2616)中定义了断点续传相关的HTTP头 Range和Content-Range字段,一个最简单的断点续传实现大概如下:

客户端下载一个1024K的文件,已经下载了其中512K
网络中断,客户端请求续传,因此需要在HTTP头中申明本次需要续传的片段:Range:bytes=512000-,这个头通知服务端从文件的512K位置开始传输文件
服务端收到断点续传请求,从文件的512K位置开始传输,并且在HTTP头中增加:Content-Range:bytes 512000-1023999/1024000,并且此时服务端返回的HTTP状态码应该是206,而不是200。

细节-文件一致性

在RFC2616中也有相应的定义,比如实现Last-Modified来标识文件的最后修改时间,这样即可判断出续传文件时是否已经发生过改动。同时RFC2616中还定义有一个ETag的头,可以使用ETag头来放置文件的唯一标识,比如文件的MD5值。

终端在发起续传请求时应该在HTTP头中申明If-Match 或者If-Modified-Since 字段,帮助服务端判别文件变化。

另外RFC2616中同时定义有一个If-Range头,终端如果在续传是使用If-Range。If-Range中的内容可以为最初收到的ETag头或者是Last-Modfied中的最后修改时候。服务端在收到续传请求时,通过If-Range中的内容进行校验,校验一致时返回206的续传回应,不一致时服务端则返回200回应,回应的内容为新的文件的全部数据。

golang标准库实现断点续传

在golang中,net/http标准库是一个很强大的库,内部是给实现了断点续传的功能的。

http.ServeFile(w ResponseWriter, r *Request, name string)

值得注意的是,如果直接使用该方法,在文件一致的情况下,也就是Last-Modfied或者ETag一致的情况下,他会返回一个304状态码,并不是200。

如果直接使用的话会很简单,将所需要传输的文件路径传进去即可。

源码解析

首先呢,在该方法中是分割了文件的路径,然后自定义的Dir类型实现接口来进行操作。

这里比较重要的是fs.Openf.Stat,第一个,是通过Dir类型实现的FileSystem接口抽象出来的方法。第二个,他并不会真的打开文件,而是获取文件的一些信息,接下来就是针对于这些信息进行校验,确保正确与健壮性。

serveContent

这个func首先验证了last-modified,并且处理了请求头range。随后根据其type来进行一些校验。

这里是比较重要的地方,根据range的存在与否和range的长度等,来进行不同的操作。总之原理还是遵照规范来的,感兴趣的小伙伴可以去细品一下源码。

最后是针对head方法进行校验

ps: 这似乎并不太合理,因为要处理的事情都已经处理完了。最后再处理head方法有点不太好。

大体源码流程就是如此。end

GoMatrix

集成到框架中来,因为我不太想使用304状态码,所以重写了fs.go相关的代码,根目录新增fs.go

package GoMatrix

import (
    "errors"
    "fmt"
    "io"
    "io/fs"
    "log"
    "mime"
    "mime/multipart"
    "net/http"
    "net/textproto"
    "net/url"
    "os"
    "path"
    "path/filepath"
    "sort"
    "strconv"
    "strings"
    "time"
)

// 解决304状态码,改造了http/fs源码

var htmlReplacer = strings.NewReplacer(
    "&", "&",
    "<", "<",
    ">", ">",
    `"`, "&#34;",
    "'", "'",
)
var errNoOverlap = errors.New("invalid range: failed to overlap")


type (
    Dir string
    condResult int
    dirEntryDirs []fs.DirEntry
    fileInfoDirs []fs.FileInfo
    countingWriter int64
)

const (
    condNone condResult = iota
    condTrue
    condFalse
)

const sniffLen = 512

type anyDirs interface {
    len() int
    name(i int) string
    isDir(i int) bool
}

type httpRange struct {
    start, length int64
}

func (d dirEntryDirs) len() int          { return len(d) }
func (d dirEntryDirs) isDir(i int) bool  { return d[i].IsDir() }
func (d dirEntryDirs) name(i int) string { return d[i].Name() }

func (d fileInfoDirs) len() int          { return len(d) }
func (d fileInfoDirs) isDir(i int) bool  { return d[i].IsDir() }
func (d fileInfoDirs) name(i int) string { return d[i].Name() }


func (d Dir) Open(name string) (http.File, error) {
    if filepath.Separator != '/' && strings.ContainsRune(name, filepath.Separator) {
        return nil, errors.New("http: invalid character in file path")
    }
    dir := string(d)
    if dir == "" {
        dir = "."
    }
    fullName := filepath.Join(dir, filepath.FromSlash(path.Clean("/"+name)))
    f, err := os.Open(fullName)
    if err != nil {
        return nil, err
    }
    return f, nil
}

func (c *Context)ServeFile(name string)  {
    if containsDotDot(c.Req.URL.Path) {
        http.Error(c.Writer, "invalid URL path", http.StatusBadRequest)
        return
    }
    dir, file := filepath.Split(name)
    c.serveFile(Dir(dir), file, false)
}

func (c *Context)serveFile(fs http.FileSystem, name string, redirect bool) {
    const indexPage = "/index.html"
    if strings.HasSuffix(c.Req.URL.Path, indexPage) {
        c.localRedirect("./")
        return
    }
    f, err := fs.Open(name)
    if err != nil {
        msg, code := toHTTPError(err)
        http.Error(c.Writer,msg,code)
        return
    }
    defer f.Close()
    d, err := f.Stat()
    if err != nil {
        msg, code := toHTTPError(err)
        http.Error(c.Writer, msg, code)
        return
    }

    if redirect {
        url := c.Req.URL.Path
        if d.IsDir() {
            if url[len(url)-1] != '/' {
                c.localRedirect(path.Base(url)+"/")
                return
            }
        } else {
            if url[len(url)-1] == '/' {
                c.localRedirect("../"+path.Base(url))
                return
            }
        }
    }
    if d.IsDir() {
        url := c.Req.URL.Path
        if url == "" || url[len(url)-1] != '/' {
            c.localRedirect(path.Base(url)+"/")
            return
        }

        index := strings.TrimSuffix(name, "/") + indexPage
        ff, err := fs.Open(index)
        if err == nil {
            defer ff.Close()
            dd, err := ff.Stat()
            if err == nil {
                name = index
                d = dd
                f = ff
            }
        }
    }

    if d.IsDir() {
        if c.checkIfModifiedSince(d.ModTime()) == condFalse {
            c.writeNotModified()
            return
        }
        c.setLastModified(d.ModTime())
        c.dirList(f)
        return
    }
    sizeFunc := func() (int64, error) { return d.Size(), nil }
    c.serveContent(sizeFunc, f, d.Name(), d.ModTime())
}

func (c *Context)localRedirect(newPath string) {
    if q := c.Req.URL.RawQuery; q != "" {
        newPath += "?" + q
    }
    c.SetHeader("Location", newPath)
    c.Status(http.StatusMovedPermanently)
}

func (c *Context)serveContent(sizeFunc func() (int64, error), content io.ReadSeeker, name string, modtime time.Time,) {
    c.setLastModified(modtime)
    done, rangeReq := c.checkPreconditions(modtime)
    if done {
        return
    }
    code := http.StatusOK
    ctypes, haveType := c.Writer.Header()["Content-Type"]
    var ctype string
    if !haveType {
        ctype = mime.TypeByExtension(filepath.Ext(name))
        if ctype == "" {
            // read a chunk to decide between utf-8 text and binary
            var buf [sniffLen]byte
            n, _ := io.ReadFull(content, buf[:])
            ctype = http.DetectContentType(buf[:n])
            _, err := content.Seek(0, io.SeekStart) // rewind to output whole file
            if err != nil {
                http.Error(c.Writer, "seeker can't seek", http.StatusInternalServerError)
                return
            }
        }
        c.SetHeader("Content-Type", ctype)
    } else if len(ctypes) > 0 {
        ctype = ctypes[0]
    }
    size, err := sizeFunc()
    if err != nil {
        http.Error(c.Writer, err.Error(), http.StatusInternalServerError)
        return
    }
    sendSize := size
    var sendContent io.Reader = content
    if size >= 0 {
        ranges, err := parseRange(rangeReq, size)
        if err != nil {
            if err == errNoOverlap {
                c.SetHeader("Content-Range", fmt.Sprintf("bytes */%d", size))
            }
            http.Error(c.Writer, err.Error(), http.StatusRequestedRangeNotSatisfiable)
            return
        }
        if sumRangesSize(ranges) > size {
            ranges = nil
        }
        switch {
        case len(ranges) == 1:
            ra := ranges[0]
            if _, err := content.Seek(ra.start, io.SeekStart); err != nil {
                http.Error(c.Writer, err.Error(), http.StatusRequestedRangeNotSatisfiable)
                return
            }
            sendSize = ra.length
            code = http.StatusPartialContent
            c.SetHeader("Content-Range", ra.contentRange(size))
        case len(ranges) > 1:
            sendSize = rangesMIMESize(ranges, ctype, size)
            code = http.StatusPartialContent
            pr, pw := io.Pipe()
            mw := multipart.NewWriter(pw)
            c.SetHeader("Content-Type", "multipart/byteranges; boundary="+mw.Boundary())
            sendContent = pr
            defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish.
            go func() {
                for _, ra := range ranges {
                    part, err := mw.CreatePart(ra.mimeHeader(ctype, size))
                    if err != nil {
                        pw.CloseWithError(err)
                        return
                    }
                    if _, err := content.Seek(ra.start, io.SeekStart); err != nil {
                        pw.CloseWithError(err)
                        return
                    }
                    if _, err := io.CopyN(part, content, ra.length); err != nil {
                        pw.CloseWithError(err)
                        return
                    }
                }
                mw.Close()
                pw.Close()
            }()
        }
        c.SetHeader("Accept-Ranges", "bytes")
        if c.GetHeader("Content-Encoding") == "" {
            c.SetHeader("Content-Length", strconv.FormatInt(sendSize, 10))
        }
    }
    c.Status(code)
    if c.Req.Method != "HEAD" {
        io.CopyN(c.Writer, sendContent, sendSize)
    }
}

func parseRange(s string, size int64) ([]httpRange, error) {
    if s == "" {
        return nil, nil
    }
    const b = "bytes="
    if !strings.HasPrefix(s, b) {
        return nil, fmt.Errorf("invalid range")
    }
    var ranges []httpRange
    noOverlap := false
    for _, ra := range strings.Split(s[len(b):], ",") {
        ra = textproto.TrimString(ra)
        if ra == "" {
            continue
        }
        i := strings.Index(ra, "-")
        if i < 0 {
            return nil, fmt.Errorf("invalid range")
        }
        start, end := textproto.TrimString(ra[:i]), textproto.TrimString(ra[i+1:])
        var r httpRange
        if start == "" {
            if end == "" || end[0] == '-' {
                return nil, fmt.Errorf("invalid range")
            }
            i, err := strconv.ParseInt(end, 10, 64)
            if i < 0 || err != nil {
                return nil, fmt.Errorf("invalid range")
            }
            if i > size {
                i = size
            }
            r.start = size - i
            r.length = size - r.start
        } else {
            i, err := strconv.ParseInt(start, 10, 64)
            if err != nil || i < 0 {
                return nil, fmt.Errorf("invalid range")
            }
            if i >= size {
                noOverlap = true
                continue
            }
            r.start = i
            if end == "" {
                r.length = size - r.start
            } else {
                i, err := strconv.ParseInt(end, 10, 64)
                if err != nil || r.start > i {
                    return nil, fmt.Errorf("invalid range")
                }
                if i >= size {
                    i = size - 1
                }
                r.length = i - r.start + 1
            }
        }
        ranges = append(ranges, r)
    }
    if noOverlap && len(ranges) == 0 {
        return nil, fmt.Errorf("invalid range: failed to overlap")
    }
    return ranges, nil
}

func (c *Context)setLastModified(modtime time.Time) {
    if !isZeroTime(modtime) {
        c.SetHeader("Last-Modified", modtime.UTC().Format(TimeFormat))
    }
}

func isZeroTime(t time.Time) bool {
    var unixEpochTime = time.Unix(0, 0)
    return t.IsZero() || t.Equal(unixEpochTime)
}

func isSlashRune(r rune) bool { return r == '/' || r == '\\' }

func containsDotDot(v string) bool {
    if !strings.Contains(v, "..") {
        return false
    }
    for _, ent := range strings.FieldsFunc(v, isSlashRune) {
        if ent == ".." {
            return true
        }
    }
    return false
}

func (c *Context)checkIfModifiedSince(modtime time.Time) condResult {
    if c.Req.Method != "GET" && c.Req.Method != "HEAD" {
        return condNone
    }
    ims := c.GetHeader("If-Modified-Since")
    if ims == "" || isZeroTime(modtime) {
        return condNone
    }
    t, err := http.ParseTime(ims)
    if err != nil {
        return condNone
    }
    // The Last-Modified header truncates sub-second precision so
    // the modtime needs to be truncated too.
    modtime = modtime.Truncate(time.Second)
    if modtime.Before(t) || modtime.Equal(t) {
        return condFalse
    }
    return condTrue
}

func (c *Context)writeNotModified() {
    h := c.Writer.Header()
    delete(h, "Content-Type")
    delete(h, "Content-Length")
    if h.Get("Etag") != "" {
        delete(h, "Last-Modified")
    }
    c.Status(http.StatusOK)
}

func logf(r *http.Request, format string, args ...interface{}) {
    s, _ := r.Context().Value(http.ServerContextKey).(*http.Server)
    if s != nil && s.ErrorLog != nil {
        s.ErrorLog.Printf(format, args...)
    } else {
        log.Printf(format, args...)
    }
}

func (c *Context)dirList(f http.File) {
    var dirs anyDirs
    var err error
    if d, ok := f.(fs.ReadDirFile); ok {
        var list dirEntryDirs
        list, err = d.ReadDir(-1)
        dirs = list
    } else {
        var list fileInfoDirs
        list, err = f.Readdir(-1)
        dirs = list
    }

    if err != nil {
        logf(c.Req, "http: error reading directory: %v", err)
        http.Error(c.Writer, "Error reading directory", http.StatusInternalServerError)
        return
    }
    sort.Slice(dirs, func(i, j int) bool { return dirs.name(i) < dirs.name(j) })

    c.SetHeader("Content-Type", "text/html; charset=utf-8")
    fmt.Fprintf(c.Writer, "<pre>\n")
    for i, n := 0, dirs.len(); i < n; i++ {
        name := dirs.name(i)
        if dirs.isDir(i) {
            name += "/"
        }
        url := url.URL{Path: name}
        fmt.Fprintf(c.Writer, "<a href=\"%s\">%s</a>\n", url.String(), htmlReplacer.Replace(name))
    }
    fmt.Fprintf(c.Writer, "</pre>\n")
}

func (c *Context)checkPreconditions(modtime time.Time) (done bool, rangeHeader string) {
    rangeHeader = c.GetHeader("Range")
    if rangeHeader != "" && c.checkIfRange(modtime) == condFalse {
        rangeHeader = ""
    }
    return false, rangeHeader
}

func scanETag(s string) (etag string, remain string) {
    s = textproto.TrimString(s)
    start := 0
    if strings.HasPrefix(s, "W/") {
        start = 2
    }
    if len(s[start:]) < 2 || s[start] != '"' {
        return "", ""
    }
    for i := start + 1; i < len(s); i++ {
        c := s[i]
        switch {
        case c == 0x21 || c >= 0x23 && c <= 0x7E || c >= 0x80:
        case c == '"':
            return s[:i+1], s[i+1:]
        default:
            return "", ""
        }
    }
    return "", ""
}

func etagStrongMatch(a, b string) bool {
    return a == b && a != "" && a[0] == '"'
}

func (c *Context)checkIfRange(modtime time.Time) condResult {
    if c.Req.Method != "GET" && c.Req.Method != "HEAD" {
        return condNone
    }
    ir := c.GetHeader("If-Range")
    if ir == "" {
        return condNone
    }
    etag, _ := scanETag(ir)
    if etag != "" {
        if etagStrongMatch(etag, c.GetHeader("Etag")) {
            return condTrue
        } else {
            return condFalse
        }
    }
    if modtime.IsZero() {
        return condFalse
    }
    t, err := http.ParseTime(ir)
    if err != nil {
        return condFalse
    }
    if t.Unix() == modtime.Unix() {
        return condTrue
    }
    return condFalse
}

func sumRangesSize(ranges []httpRange) (size int64) {
    for _, ra := range ranges {
        size += ra.length
    }
    return
}

func (r httpRange) contentRange(size int64) string {
    return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size)
}

func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) {
    var w countingWriter
    mw := multipart.NewWriter(&w)
    for _, ra := range ranges {
        mw.CreatePart(ra.mimeHeader(contentType, contentSize))
        encSize += ra.length
    }
    mw.Close()
    encSize += int64(w)
    return
}

func (w *countingWriter) Write(p []byte) (n int, err error) {
    *w += countingWriter(len(p))
    return len(p), nil
}

func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader {
    return textproto.MIMEHeader{
        "Content-Range": {r.contentRange(size)},
        "Content-Type":  {contentType},
    }
}

func toHTTPError(err error) (msg string, httpStatus int) {
    if errors.Is(err, fs.ErrNotExist) {
        return "404 page not found", http.StatusNotFound
    }
    if errors.Is(err, fs.ErrPermission) {
        return "403 Forbidden", http.StatusForbidden
    }
    // Default:
    return "500 Internal Server Error", http.StatusInternalServerError
}

基本和标准库源码无差

然后在context.go新增个格式化的常量,用作Last-Modified时的验证

const TimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT"

源地址