248 lines
6.6 KiB
Go
248 lines
6.6 KiB
Go
package admin
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"mime"
|
|
"net/http"
|
|
"net/url"
|
|
"path"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/aws/aws-sdk-go-v2/aws"
|
|
"github.com/aws/aws-sdk-go-v2/config"
|
|
"github.com/aws/aws-sdk-go-v2/credentials"
|
|
"github.com/aws/aws-sdk-go-v2/service/s3"
|
|
)
|
|
|
|
type ImageUploader struct {
|
|
client *s3.Client
|
|
bucket string
|
|
prefix string
|
|
publicBaseURL string
|
|
maxUploadBytes int64
|
|
}
|
|
|
|
// UploadedImage describes an image stored by ImageUploader.Upload. The struct
// tags define the JSON keys used when it is serialized.
type UploadedImage struct {
	Key         string `json:"key"`         // object key inside the bucket
	URL         string `json:"url"`         // public URL of the stored object
	Markdown    string `json:"markdown"`    // Markdown embed snippet for the image, built by Upload
	Filename    string `json:"filename"`    // sanitized filename
	Size        int64  `json:"size"`        // size in bytes as reported by Upload
	ContentType string `json:"contentType"` // MIME type recorded for the object
}
|
|
|
|
func NewImageUploader(ctx context.Context, cfg R2Config) (*ImageUploader, error) {
|
|
normalized, err := normalizeR2Config(cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
awsConfig, err := config.LoadDefaultConfig(ctx,
|
|
config.WithRegion(normalized.Region),
|
|
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
|
|
normalized.AccessKeyID,
|
|
normalized.SecretAccessKey,
|
|
"",
|
|
)),
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
client := s3.NewFromConfig(awsConfig, func(options *s3.Options) {
|
|
options.BaseEndpoint = aws.String(normalized.Endpoint)
|
|
options.UsePathStyle = true
|
|
})
|
|
|
|
return &ImageUploader{
|
|
client: client,
|
|
bucket: normalized.Bucket,
|
|
prefix: normalized.Prefix,
|
|
publicBaseURL: normalized.PublicBaseURL,
|
|
maxUploadBytes: normalized.MaxUploadBytes,
|
|
}, nil
|
|
}
|
|
|
|
func (u *ImageUploader) Upload(ctx context.Context, filename string, body io.Reader, size int64) (UploadedImage, error) {
|
|
if u == nil {
|
|
return UploadedImage{}, errors.New("R2 image uploader is not configured")
|
|
}
|
|
if size <= 0 {
|
|
return UploadedImage{}, errors.New("file is empty")
|
|
}
|
|
if u.maxUploadBytes > 0 && size > u.maxUploadBytes {
|
|
return UploadedImage{}, fmt.Errorf("file is too large: max %d bytes", u.maxUploadBytes)
|
|
}
|
|
|
|
limit := u.maxUploadBytes
|
|
if limit <= 0 {
|
|
limit = 20 * 1024 * 1024
|
|
}
|
|
data, err := io.ReadAll(io.LimitReader(body, limit+1))
|
|
if err != nil {
|
|
return UploadedImage{}, err
|
|
}
|
|
if int64(len(data)) > limit {
|
|
return UploadedImage{}, fmt.Errorf("file is too large: max %d bytes", limit)
|
|
}
|
|
if len(data) == 0 {
|
|
return UploadedImage{}, errors.New("file is empty")
|
|
}
|
|
|
|
head := data
|
|
if len(head) > 512 {
|
|
head = head[:512]
|
|
}
|
|
contentType := http.DetectContentType(head)
|
|
if !strings.HasPrefix(contentType, "image/") {
|
|
return UploadedImage{}, errors.New("only image files are supported")
|
|
}
|
|
|
|
key, cleanFilename, err := u.objectKey(filename, contentType)
|
|
if err != nil {
|
|
return UploadedImage{}, err
|
|
}
|
|
|
|
if _, err := u.client.PutObject(ctx, &s3.PutObjectInput{
|
|
Bucket: aws.String(u.bucket),
|
|
Key: aws.String(key),
|
|
Body: bytes.NewReader(data),
|
|
ContentType: aws.String(contentType),
|
|
}); err != nil {
|
|
return UploadedImage{}, err
|
|
}
|
|
|
|
publicURL := strings.TrimRight(u.publicBaseURL, "/") + "/" + escapePath(key)
|
|
return UploadedImage{
|
|
Key: key,
|
|
URL: publicURL,
|
|
Markdown: fmt.Sprintf("", cleanFilename, publicURL),
|
|
Filename: cleanFilename,
|
|
Size: size,
|
|
ContentType: contentType,
|
|
}, nil
|
|
}
|
|
|
|
func (u *ImageUploader) objectKey(filename string, contentType string) (string, string, error) {
|
|
cleanFilename := sanitizeFilename(filename)
|
|
ext := strings.ToLower(filepath.Ext(cleanFilename))
|
|
if ext == "" {
|
|
extensions, _ := mime.ExtensionsByType(contentType)
|
|
if len(extensions) > 0 {
|
|
ext = extensions[0]
|
|
cleanFilename += ext
|
|
}
|
|
}
|
|
token, err := randomHex(4)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
stem := strings.TrimSuffix(cleanFilename, filepath.Ext(cleanFilename))
|
|
keyName := fmt.Sprintf("%s-%s-%s%s", time.Now().Format("20060102-150405"), token, stem, ext)
|
|
return joinObjectPath(u.prefix, keyName), cleanFilename, nil
|
|
}
|
|
|
|
func normalizeR2Config(cfg R2Config) (R2Config, error) {
|
|
cfg.Endpoint = strings.TrimSpace(cfg.Endpoint)
|
|
cfg.Bucket = strings.Trim(strings.TrimSpace(cfg.Bucket), "/")
|
|
cfg.Prefix = strings.Trim(strings.TrimSpace(cfg.Prefix), "/")
|
|
cfg.AccessKeyID = strings.TrimSpace(cfg.AccessKeyID)
|
|
cfg.SecretAccessKey = strings.TrimSpace(cfg.SecretAccessKey)
|
|
cfg.Region = firstNonEmpty(cfg.Region, "auto")
|
|
cfg.PublicBaseURL = strings.TrimSpace(cfg.PublicBaseURL)
|
|
if cfg.MaxUploadBytes <= 0 {
|
|
cfg.MaxUploadBytes = 20 * 1024 * 1024
|
|
}
|
|
|
|
endpoint, err := url.Parse(cfg.Endpoint)
|
|
if err != nil || endpoint.Scheme == "" || endpoint.Host == "" {
|
|
return R2Config{}, errors.New("r2.endpoint must be a valid URL")
|
|
}
|
|
endpointPath := strings.Trim(endpoint.Path, "/")
|
|
endpoint.Path = ""
|
|
endpoint.RawPath = ""
|
|
endpoint.RawQuery = ""
|
|
endpoint.Fragment = ""
|
|
cfg.Endpoint = endpoint.String()
|
|
|
|
if cfg.Bucket == "" {
|
|
return R2Config{}, errors.New("r2.bucket is required")
|
|
}
|
|
bucketParts := strings.SplitN(cfg.Bucket, "/", 2)
|
|
cfg.Bucket = bucketParts[0]
|
|
if len(bucketParts) == 2 {
|
|
cfg.Prefix = joinObjectPath(bucketParts[1], cfg.Prefix)
|
|
}
|
|
if endpointPath != "" && endpointPath != cfg.Bucket {
|
|
cfg.Prefix = joinObjectPath(endpointPath, cfg.Prefix)
|
|
}
|
|
|
|
if cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
|
|
return R2Config{}, errors.New("r2 access keys are required")
|
|
}
|
|
if cfg.PublicBaseURL == "" {
|
|
return R2Config{}, errors.New("r2.publicBaseUrl is required")
|
|
}
|
|
if !strings.HasPrefix(cfg.PublicBaseURL, "http://") && !strings.HasPrefix(cfg.PublicBaseURL, "https://") {
|
|
cfg.PublicBaseURL = "https://" + cfg.PublicBaseURL
|
|
}
|
|
return cfg, nil
|
|
}
|
|
|
|
// sanitizeFilename reduces a client-supplied filename to a safe name. The
// result contains only ASCII letters, digits, '.', '-' and '_', and is never
// empty ("image" is the fallback stem).
//
// Behavior:
//   - Only the last path segment is kept, splitting on both '/' and '\',
//     because some clients send full Windows paths.
//   - The stem and the extension are cleaned separately, so a name made only
//     of unsupported characters, such as "日本.png", keeps its extension
//     ("image.png").
//   - The stem is capped at 100 bytes so object keys stay well within S3's
//     1024-byte key limit.
func sanitizeFilename(filename string) string {
	const maxStemLength = 100

	base := filename
	if i := strings.LastIndexAny(base, `/\`); i >= 0 {
		base = base[i+1:]
	}
	base = strings.TrimSpace(base)

	ext := filepath.Ext(base)
	stem := sanitizeFilenamePart(strings.TrimSuffix(base, ext))
	if len(stem) > maxStemLength {
		// The stem is pure ASCII here, so byte slicing cannot split a rune.
		stem = strings.TrimRight(stem[:maxStemLength], ".-")
	}
	if stem == "" {
		stem = "image"
	}

	ext = sanitizeFilenamePart(strings.TrimPrefix(ext, "."))
	if ext == "" {
		return stem
	}
	return stem + "." + ext
}

// sanitizeFilenamePart replaces each character outside [A-Za-z0-9._-] with
// '-'. It collapses runs of '-' and trims leading and trailing '.' and '-',
// which keeps keys readable and URL-safe.
func sanitizeFilenamePart(part string) string {
	var out strings.Builder
	out.Grow(len(part))
	lastDash := false
	for _, r := range part {
		switch {
		case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9', r == '.', r == '_':
			out.WriteRune(r)
			lastDash = false
		default: // includes '-' itself, so existing dashes collapse too
			if !lastDash {
				out.WriteByte('-')
				lastDash = true
			}
		}
	}
	return strings.Trim(out.String(), ".-")
}
|
|
|
|
// joinObjectPath strips surrounding slashes from each part, drops parts that
// end up empty, and joins the rest with path.Join (which also cleans the
// result). It returns "" when no non-empty part remains.
func joinObjectPath(parts ...string) string {
	var segments []string
	for i := range parts {
		if segment := strings.Trim(parts[i], "/"); segment != "" {
			segments = append(segments, segment)
		}
	}
	return path.Join(segments...)
}
|
|
|
|
// escapePath percent-escapes every "/"-separated segment of value with
// url.PathEscape and leaves the separators themselves untouched.
func escapePath(value string) string {
	var escaped strings.Builder
	for i, segment := range strings.Split(value, "/") {
		if i > 0 {
			escaped.WriteByte('/')
		}
		escaped.WriteString(url.PathEscape(segment))
	}
	return escaped.String()
}
|
|
|
|
// randomHex returns n bytes from crypto/rand encoded as lowercase hex, so the
// result is 2*n characters long.
//
// The parameter is named n rather than bytes to avoid shadowing the imported
// bytes package.
func randomHex(n int) (string, error) {
	buf := make([]byte, n)
	if _, err := rand.Read(buf); err != nil {
		return "", fmt.Errorf("reading random bytes: %w", err)
	}
	return hex.EncodeToString(buf), nil
}
|