// File: 91/backend/internal/proxy/proxy.go (Go, 206 lines, 4.7 KiB)

package proxy
import (
	"context"
	"errors"
	"io"
	"net/http"
	"net/url"
	"sync"
	"time"

	"github.com/video-site/backend/internal/drives"
)
// streamURLWithHeader is an optional capability a drives.Drive may implement.
// Drives that do are handed the incoming request's headers when a stream link
// is resolved (getLink), and their links are cached per User-Agent
// (linkCacheKey); all other drives are resolved through Drive.StreamURL.
type streamURLWithHeader interface {
StreamURLWithHeader(ctx context.Context, fileID string, header http.Header) (*drives.StreamLink, error)
}
// Registry holds the Drive instances known to the proxy, keyed by drive ID.
// All methods are safe for concurrent use. Create one with NewRegistry.
type Registry struct {
	mu     sync.RWMutex
	drives map[string]drives.Drive
}

// NewRegistry returns an empty, ready-to-use Registry.
func NewRegistry() *Registry {
	rg := new(Registry)
	rg.drives = map[string]drives.Drive{}
	return rg
}

// Set registers d under id, replacing whatever was stored there before.
func (rg *Registry) Set(id string, d drives.Drive) {
	rg.mu.Lock()
	defer rg.mu.Unlock()
	rg.drives[id] = d
}

// Get looks up the drive registered under id; ok reports whether one exists.
func (rg *Registry) Get(id string) (drives.Drive, bool) {
	rg.mu.RLock()
	defer rg.mu.RUnlock()
	if found, present := rg.drives[id]; present {
		return found, true
	}
	return nil, false
}

// All returns a snapshot of every registered drive. Order is unspecified
// because it follows Go map iteration order.
func (rg *Registry) All() []drives.Drive {
	rg.mu.RLock()
	defer rg.mu.RUnlock()
	list := make([]drives.Drive, len(rg.drives))
	idx := 0
	for _, d := range rg.drives {
		list[idx] = d
		idx++
	}
	return list
}

// Remove unregisters the drive stored under id; unknown ids are a no-op.
func (rg *Registry) Remove(id string) {
	rg.mu.Lock()
	defer rg.mu.Unlock()
	delete(rg.drives, id)
}
// Proxy serves stream requests addressed by driveID + fileID by resolving the
// real upstream link of the drive and either redirecting to it or proxying it.
type Proxy struct {
	Registry *Registry

	// cache memoizes resolved links, guarded by cacheMu. Keys are
	// driveID + "/" + fileID, with the User-Agent appended for UA-bound links.
	cacheMu sync.Mutex
	cache   map[string]cachedLink

	http *http.Client
}

// cachedLink is a resolved stream link plus the time it was fetched.
type cachedLink struct {
	link    *drives.StreamLink
	fetched time.Time
}

// New builds a Proxy that looks drives up in r.
func New(r *Registry) *Proxy {
	// A zero Timeout means no overall deadline: a stream may legitimately run
	// for a long time. Upstream requests are bounded by the incoming request's
	// context instead.
	client := &http.Client{}
	return &Proxy{
		Registry: r,
		cache:    make(map[string]cachedLink),
		http:     client,
	}
}
// getLink resolves a stream link for fileID on drive d, reusing a recently
// resolved link when possible.
//
// A cached link is reused for at most 30 seconds, and never once its own
// Expires time has passed. Drives implementing streamURLWithHeader receive the
// client's request headers; every other drive is resolved via Drive.StreamURL.
//
// Whenever a freshly resolved link is stored, entries that are no longer
// reusable are pruned. This keeps the cache bounded by recent traffic instead
// of growing forever. A drive that reports success but returns a nil link is
// treated as an error, so no nil link is ever cached or handed to callers.
func (p *Proxy) getLink(ctx context.Context, d drives.Drive, driveID, fileID string, header http.Header) (*drives.StreamLink, error) {
	// Entries older than this are refreshed even if the link is still valid.
	const maxAge = 30 * time.Second

	// reusable reports whether c may still be served at instant now.
	reusable := func(c cachedLink, now time.Time) bool {
		return now.Sub(c.fetched) < maxAge && now.Before(c.link.Expires)
	}

	key := linkCacheKey(d, driveID, fileID, header)

	p.cacheMu.Lock()
	if c, ok := p.cache[key]; ok && reusable(c, time.Now()) {
		p.cacheMu.Unlock()
		return c.link, nil
	}
	p.cacheMu.Unlock()

	// Resolve outside the lock: this may be a slow network call to the drive.
	var (
		link *drives.StreamLink
		err  error
	)
	if h, ok := d.(streamURLWithHeader); ok {
		link, err = h.StreamURLWithHeader(ctx, fileID, header)
	} else {
		link, err = d.StreamURL(ctx, fileID)
	}
	if err != nil {
		return nil, err
	}
	if link == nil {
		// Caching nil would make the next cache hit panic on link.Expires, and
		// callers dereference link.URL.
		return nil, errors.New("drive returned no stream link")
	}

	p.cacheMu.Lock()
	now := time.Now()
	// Drop entries that can no longer be served; deleting during range is safe.
	for k, c := range p.cache {
		if !reusable(c, now) {
			delete(p.cache, k)
		}
	}
	p.cache[key] = cachedLink{link: link, fetched: now}
	p.cacheMu.Unlock()
	return link, nil
}
// linkCacheKey returns the link-cache key "<driveID>/<fileID>". For drives that
// resolve links from request headers (streamURLWithHeader), the key also gets
// "|ua=<User-Agent>", so UA-bound links are never shared between clients that
// send different User-Agents.
func linkCacheKey(d drives.Drive, driveID, fileID string, header http.Header) string {
	base := driveID + "/" + fileID
	if _, headerAware := d.(streamURLWithHeader); !headerAware {
		return base
	}
	return base + "|ua=" + header.Get("User-Agent")
}
// ServeStream serves fileID from the drive registered as driveID.
//
// It responds with 404 when driveID is unknown and with 502 when the stream
// link cannot be resolved. Drives for which shouldRedirect reports true get a
// 302 to the upstream URL; all others have the content proxied through p.serve.
func (p *Proxy) ServeStream(w http.ResponseWriter, r *http.Request, driveID, fileID string) {
	drive, found := p.Registry.Get(driveID)
	if !found {
		notFound := errDriveNotFound
		http.Error(w, notFound.Error(), notFound.Code)
		return
	}

	link, err := p.getLink(r.Context(), drive, driveID, fileID, r.Header)
	switch {
	case err != nil:
		http.Error(w, err.Error(), http.StatusBadGateway)
	case shouldRedirect(drive):
		redirect(w, r, link)
	default:
		p.serve(w, r, link)
	}
}
// shouldRedirect reports whether clients should be sent straight to the
// upstream URL rather than having the bytes proxied. Only drives whose Kind is
// "p115" are redirected.
func shouldRedirect(d drives.Drive) bool {
	switch d.Kind() {
	case "p115":
		return true
	default:
		return false
	}
}
// redirect answers with a 302 Found pointing at link.URL.
func redirect(w http.ResponseWriter, r *http.Request, link *drives.StreamLink) {
	hdr := w.Header()
	// Do not leak our page URL as the Referer of the follow-up upstream request.
	hdr.Set("Referrer-Policy", "no-referrer")
	// Links carry an expiry, so this redirect must never be reused from a cache.
	hdr.Set("Cache-Control", "max-age=0, no-cache, no-store, must-revalidate")
	http.Redirect(w, r, link.URL, http.StatusFound)
}
// serve proxies link to the client.
//
// It replays r.Method against link.URL, sending the headers attached to the
// link plus the client's Range header. It then relays the upstream status, a
// fixed set of entity headers, and the body.
//
// Only successful (2xx, including 206) responses are marked cacheable. Any
// other upstream status, such as a link that was rejected or expired early, is
// sent with "no-store". Otherwise a browser could keep replaying the failure
// for the full max-age.
func (p *Proxy) serve(w http.ResponseWriter, r *http.Request, link *drives.StreamLink) {
	// Build the upstream request.
	u, err := url.Parse(link.URL)
	if err != nil {
		http.Error(w, "bad upstream url", http.StatusBadGateway)
		return
	}
	// Bound the upstream request by the client's request context, so a client
	// disconnect cancels it (the shared client has no overall timeout).
	req, err := http.NewRequestWithContext(r.Context(), r.Method, u.String(), nil)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// Copy the headers the drive attached to the link onto the upstream request.
	for k, vs := range link.Headers {
		for _, v := range vs {
			req.Header.Add(k, v)
		}
	}
	// Pass the client's Range through so seeking works.
	if rng := r.Header.Get("Range"); rng != "" {
		req.Header.Set("Range", rng)
	}

	resp, err := p.http.Do(req)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	// Relay only the entity headers the player needs.
	dst := w.Header()
	for _, k := range []string{
		"Content-Type", "Content-Length", "Content-Range",
		"Accept-Ranges", "Last-Modified", "Etag",
	} {
		if v := resp.Header.Get(k); v != "" {
			dst.Set(k, v)
		}
	}
	if resp.StatusCode >= 200 && resp.StatusCode < 300 {
		dst.Set("Cache-Control", "private, max-age=300")
	} else {
		dst.Set("Cache-Control", "no-store")
	}
	w.WriteHeader(resp.StatusCode)
	// The status line is already sent, so a copy error (most commonly the
	// client going away mid-stream) cannot be reported; ignore it deliberately.
	_, _ = io.Copy(w, resp.Body)
}
// ServeLocal serves a local teaser file from disk via http.ServeFile, which
// handles Content-Type, Range and conditional requests.
// NOTE(review): path is handed straight to http.ServeFile; callers must make
// sure it is not built from untrusted input — TODO confirm at call sites.
func (p *Proxy) ServeLocal(w http.ResponseWriter, r *http.Request, path string) {
http.ServeFile(w, r, path)
}
// httpError is an error that also carries the HTTP status code to respond with.
type httpError struct {
	Code int    // HTTP status code
	Msg  string // text returned by Error
}

// Compile-time check that *httpError satisfies the error interface.
var _ error = (*httpError)(nil)

// Error implements the error interface by returning the message text.
func (e *httpError) Error() string {
	return e.Msg
}

// errDriveNotFound is reported when a request names a drive ID that has not
// been registered.
var errDriveNotFound = &httpError{
	Msg:  "drive not found",
	Code: http.StatusNotFound,
}