mirror of
https://github.com/nianzhibai/91.git
synced 2026-06-15 00:44:30 +08:00
206 lines
4.7 KiB
Go
206 lines
4.7 KiB
Go
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/video-site/backend/internal/drives"
|
|
)
|
|
|
|
type streamURLWithHeader interface {
|
|
StreamURLWithHeader(ctx context.Context, fileID string, header http.Header) (*drives.StreamLink, error)
|
|
}
|
|
|
|
// Registry 管理多个 Drive 实例
|
|
type Registry struct {
|
|
mu sync.RWMutex
|
|
drives map[string]drives.Drive
|
|
}
|
|
|
|
func NewRegistry() *Registry {
|
|
return &Registry{drives: make(map[string]drives.Drive)}
|
|
}
|
|
|
|
func (r *Registry) Set(id string, d drives.Drive) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.drives[id] = d
|
|
}
|
|
|
|
func (r *Registry) Get(id string) (drives.Drive, bool) {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
d, ok := r.drives[id]
|
|
return d, ok
|
|
}
|
|
|
|
func (r *Registry) All() []drives.Drive {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
out := make([]drives.Drive, 0, len(r.drives))
|
|
for _, d := range r.drives {
|
|
out = append(out, d)
|
|
}
|
|
return out
|
|
}
|
|
|
|
func (r *Registry) Remove(id string) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
delete(r.drives, id)
|
|
}
|
|
|
|
// Proxy 根据 driveID + fileID 反向代理到真实网盘直链
|
|
type Proxy struct {
|
|
Registry *Registry
|
|
// linkCache key: driveID + "/" + fileID (+ User-Agent for UA-bound links)
|
|
cacheMu sync.Mutex
|
|
cache map[string]cachedLink
|
|
http *http.Client
|
|
}
|
|
|
|
type cachedLink struct {
|
|
link *drives.StreamLink
|
|
fetched time.Time
|
|
}
|
|
|
|
func New(r *Registry) *Proxy {
|
|
return &Proxy{
|
|
Registry: r,
|
|
cache: make(map[string]cachedLink),
|
|
http: &http.Client{
|
|
Timeout: 0, // 流式不设超时
|
|
},
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) getLink(ctx context.Context, d drives.Drive, driveID, fileID string, header http.Header) (*drives.StreamLink, error) {
|
|
key := linkCacheKey(d, driveID, fileID, header)
|
|
|
|
p.cacheMu.Lock()
|
|
if c, ok := p.cache[key]; ok {
|
|
// 缓存 30 秒,且不超过 link.Expires
|
|
if time.Since(c.fetched) < 30*time.Second && time.Now().Before(c.link.Expires) {
|
|
p.cacheMu.Unlock()
|
|
return c.link, nil
|
|
}
|
|
}
|
|
p.cacheMu.Unlock()
|
|
|
|
var (
|
|
link *drives.StreamLink
|
|
err error
|
|
)
|
|
if h, ok := d.(streamURLWithHeader); ok {
|
|
link, err = h.StreamURLWithHeader(ctx, fileID, header)
|
|
} else {
|
|
link, err = d.StreamURL(ctx, fileID)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
p.cacheMu.Lock()
|
|
p.cache[key] = cachedLink{link: link, fetched: time.Now()}
|
|
p.cacheMu.Unlock()
|
|
return link, nil
|
|
}
|
|
|
|
func linkCacheKey(d drives.Drive, driveID, fileID string, header http.Header) string {
|
|
key := driveID + "/" + fileID
|
|
if _, ok := d.(streamURLWithHeader); ok {
|
|
key += "|ua=" + header.Get("User-Agent")
|
|
}
|
|
return key
|
|
}
|
|
|
|
func (p *Proxy) ServeStream(w http.ResponseWriter, r *http.Request, driveID, fileID string) {
|
|
d, ok := p.Registry.Get(driveID)
|
|
if !ok {
|
|
http.Error(w, errDriveNotFound.Error(), errDriveNotFound.Code)
|
|
return
|
|
}
|
|
|
|
link, err := p.getLink(r.Context(), d, driveID, fileID, r.Header)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
return
|
|
}
|
|
if shouldRedirect(d) {
|
|
redirect(w, r, link)
|
|
return
|
|
}
|
|
p.serve(w, r, link)
|
|
}
|
|
|
|
func shouldRedirect(d drives.Drive) bool {
|
|
return d.Kind() == "p115"
|
|
}
|
|
|
|
func redirect(w http.ResponseWriter, r *http.Request, link *drives.StreamLink) {
|
|
w.Header().Set("Referrer-Policy", "no-referrer")
|
|
w.Header().Set("Cache-Control", "max-age=0, no-cache, no-store, must-revalidate")
|
|
http.Redirect(w, r, link.URL, http.StatusFound)
|
|
}
|
|
|
|
func (p *Proxy) serve(w http.ResponseWriter, r *http.Request, link *drives.StreamLink) {
|
|
// 构造上游请求
|
|
u, err := url.Parse(link.URL)
|
|
if err != nil {
|
|
http.Error(w, "bad upstream url", http.StatusBadGateway)
|
|
return
|
|
}
|
|
req, err := http.NewRequestWithContext(r.Context(), r.Method, u.String(), nil)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
// 复制上游请求头
|
|
for k, vs := range link.Headers {
|
|
for _, v := range vs {
|
|
req.Header.Add(k, v)
|
|
}
|
|
}
|
|
// 透传 Range
|
|
if rng := r.Header.Get("Range"); rng != "" {
|
|
req.Header.Set("Range", rng)
|
|
}
|
|
|
|
resp, err := p.http.Do(req)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// 透传响应头
|
|
for _, k := range []string{
|
|
"Content-Type", "Content-Length", "Content-Range",
|
|
"Accept-Ranges", "Last-Modified", "Etag",
|
|
} {
|
|
if v := resp.Header.Get(k); v != "" {
|
|
w.Header().Set(k, v)
|
|
}
|
|
}
|
|
w.Header().Set("Cache-Control", "private, max-age=300")
|
|
w.WriteHeader(resp.StatusCode)
|
|
_, _ = io.Copy(w, resp.Body)
|
|
}
|
|
|
|
// ServeLocal 服务本地 teaser 文件
|
|
func (p *Proxy) ServeLocal(w http.ResponseWriter, r *http.Request, path string) {
|
|
http.ServeFile(w, r, path)
|
|
}
|
|
|
|
var errDriveNotFound = &httpError{Code: http.StatusNotFound, Msg: "drive not found"}
|
|
|
|
type httpError struct {
|
|
Code int
|
|
Msg string
|
|
}
|
|
|
|
func (e *httpError) Error() string { return e.Msg }
|