289 lines
9.3 KiB
Go
289 lines
9.3 KiB
Go
package main
|
||
|
||
import (
|
||
"context"
|
||
"flag"
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
"time"
|
||
|
||
"weatherstation/core/internal/data"
|
||
)
|
||
|
||
// V6 导出工具:
|
||
// - 以 imdroid_mix 为基线 b_t(+k)
|
||
// - 残差 e_t(+k) 优先使用“上一轮 V6 的预测误差”,冷启动时退回使用 mix 的历史预测误差
|
||
// - out(+1) = max(0, base1 + 1.0*e1)
|
||
// - out(+2) = max(0, base2 + 0.5*e2)
|
||
// - out(+3) = max(0, base3 + (1/3)*e3)
|
||
// - 仅生成 SQL 与日志,不写库
|
||
|
||
// Provider identifiers used when reading baseline forecasts and when
// labelling the rows this tool generates.
const (
	// baseProvider names the forecast provider used as the baseline and as
	// the cold-start source of previous predictions.
	baseProvider = "imdroid_mix"

	// outProvider is the provider value written into the generated SQL rows.
	outProvider = "imdroid_V6"
)
|
||
|
||
// v6Out pairs a forecast time with a rainfall value.
//
// NOTE(review): nothing in the visible code references this type; confirm
// whether it is still needed.
type v6Out struct {
	// FT is the forecast (validation) time of the value.
	FT time.Time
	// Rain is the rainfall amount; presumably millimetres — confirm.
	Rain float64
}
|
||
|
||
func main() {
|
||
var stationsCSV, startStr, endStr, sqlOut, logOut, tzName string
|
||
flag.StringVar(&stationsCSV, "stations", "", "逗号分隔的 station_id 列表,例如: RS485-000001,RS485-000002")
|
||
flag.StringVar(&startStr, "start", "", "开始时间,格式: 2006-01-02 15:00 或 2006-01-02(按整点对齐)")
|
||
flag.StringVar(&endStr, "end", "", "结束时间,格式: 2006-01-02 15:00 或 2006-01-02(不包含该时刻)")
|
||
flag.StringVar(&sqlOut, "sql", "v6_output.sql", "输出 SQL 文件路径")
|
||
flag.StringVar(&logOut, "log", "v6_output.log", "输出日志文件路径")
|
||
flag.StringVar(&tzName, "tz", "Asia/Shanghai", "时区,例如 Asia/Shanghai")
|
||
flag.Parse()
|
||
|
||
if stationsCSV == "" || startStr == "" || endStr == "" {
|
||
fmt.Println("用法: v6-export --stations RS485-XXXXXX --start '2024-08-01 00:00' --end '2024-08-02 00:00' --sql out.sql --log out.log")
|
||
os.Exit(2)
|
||
}
|
||
|
||
if err := os.MkdirAll(filepath.Dir(sqlOut), 0755); err != nil && filepath.Dir(sqlOut) != "." {
|
||
log.Fatalf("create sql dir: %v", err)
|
||
}
|
||
if err := os.MkdirAll(filepath.Dir(logOut), 0755); err != nil && filepath.Dir(logOut) != "." {
|
||
log.Fatalf("create log dir: %v", err)
|
||
}
|
||
lf, err := os.Create(logOut)
|
||
if err != nil {
|
||
log.Fatalf("open log file: %v", err)
|
||
}
|
||
defer lf.Close()
|
||
logger := log.New(io.MultiWriter(os.Stdout, lf), "", log.LstdFlags)
|
||
|
||
sf, err := os.Create(sqlOut)
|
||
if err != nil {
|
||
logger.Fatalf("open sql file: %v", err)
|
||
}
|
||
defer sf.Close()
|
||
fmt.Fprintf(sf, "-- V6 Export generated at %s\nBEGIN;\n", time.Now().Format(time.RFC3339))
|
||
|
||
loc, _ := time.LoadLocation(tzName)
|
||
if loc == nil {
|
||
loc = time.FixedZone("CST", 8*3600)
|
||
}
|
||
parse := func(s string) (time.Time, error) {
|
||
for _, ly := range []string{"2006-01-02 15:04", "2006-01-02 15", "2006-01-02"} {
|
||
if t, err := time.ParseInLocation(ly, s, loc); err == nil {
|
||
return t, nil
|
||
}
|
||
}
|
||
return time.Time{}, fmt.Errorf("invalid time: %s", s)
|
||
}
|
||
start, err := parse(startStr)
|
||
if err != nil {
|
||
logger.Fatalf("parse start: %v", err)
|
||
}
|
||
end, err := parse(endStr)
|
||
if err != nil {
|
||
logger.Fatalf("parse end: %v", err)
|
||
}
|
||
start = start.Truncate(time.Hour)
|
||
end = end.Truncate(time.Hour)
|
||
if !end.After(start) {
|
||
logger.Fatalf("end 必须大于 start")
|
||
}
|
||
|
||
stations := splitStations(stationsCSV)
|
||
ctx := context.Background()
|
||
|
||
for _, st := range stations {
|
||
logger.Printf("V6 导出 站点=%s 窗口=%s→%s", st, start.Format("2006-01-02 15:04"), end.Format("2006-01-02 15:04"))
|
||
// 维护一个“按验证时刻”的 V6 预测缓存:key=forecast_time,value=预测雨量
|
||
v6AtTime := make(map[time.Time]float64)
|
||
|
||
for t := start; t.Before(end); t = t.Add(time.Hour) {
|
||
res := computeV6AtHour(ctx, st, t, v6AtTime, logger)
|
||
if res.skipped {
|
||
logger.Printf("skip station=%s issued=%s: %s", st, t.Format("2006-01-02 15:04"), res.reason)
|
||
continue
|
||
}
|
||
// 写 SQL
|
||
for _, row := range res.sqlRows {
|
||
fmt.Fprintln(sf, row)
|
||
}
|
||
// 更新缓存:将本次的 +1/+2/+3 结果写入对应的验证时刻键
|
||
v6AtTime[t.Add(1*time.Hour)] = res.out[0]
|
||
v6AtTime[t.Add(2*time.Hour)] = res.out[1]
|
||
v6AtTime[t.Add(3*time.Hour)] = res.out[2]
|
||
|
||
logger.Printf("V6 %s issued=%s base=[%.3f,%.3f,%.3f] actual=%.3f prev=[%.3f,%.3f,%.3f] out=[%.3f,%.3f,%.3f] src=[%s,%s,%s]",
|
||
st, t.Format("2006-01-02 15:04"), res.base[0], res.base[1], res.base[2], res.actual,
|
||
res.prev[0], res.prev[1], res.prev[2], res.out[0], res.out[1], res.out[2],
|
||
res.src[0], res.src[1], res.src[2])
|
||
}
|
||
}
|
||
fmt.Fprintln(sf, "COMMIT;")
|
||
logger.Printf("完成,SQL: %s 日志: %s", sqlOut, logOut)
|
||
}
|
||
|
||
// v6Result is what computeV6AtHour reports for one issue hour. Index 0, 1 and
// 2 of the array fields correspond to the +1h, +2h and +3h leads.
type v6Result struct {
	// base holds the baseline provider's rain values for the three leads.
	base [3]float64
	// prev holds the earlier predictions used to compute the residuals.
	prev [3]float64
	// src names where each prev value came from: "V6" or the baseline provider.
	src [3]string
	// out holds the rain values emitted for the three leads.
	out [3]float64
	// actual is the observed rain for the hour ending at the issue time.
	actual float64
	// sqlRows holds one upsert statement per lead.
	sqlRows []string
	// skipped is true when no output was produced; reason says why.
	skipped bool
	reason  string
}
|
||
|
||
func computeV6AtHour(ctx context.Context, stationID string, issued time.Time, v6AtTime map[time.Time]float64, logger *log.Logger) v6Result {
|
||
var res v6Result
|
||
|
||
// 读取基线:当期小时桶内 mix 最新 issued 的 +1/+2/+3
|
||
baseIssued, ok, err := data.ResolveIssuedAtInBucket(ctx, stationID, baseProvider, issued)
|
||
if err != nil || !ok {
|
||
res.skipped, res.reason = true, fmt.Sprintf("base issued missing: %v ok=%v", err, ok)
|
||
return res
|
||
}
|
||
pts, err := data.ForecastRainAtIssued(ctx, stationID, baseProvider, baseIssued)
|
||
if err != nil || len(pts) < 3 {
|
||
res.skipped, res.reason = true, fmt.Sprintf("base points insufficient: %v len=%d", err, len(pts))
|
||
return res
|
||
}
|
||
ft1, ft2, ft3 := issued.Add(time.Hour), issued.Add(2*time.Hour), issued.Add(3*time.Hour)
|
||
base1, base2, base3 := pickRain(pts, ft1), pickRain(pts, ft2), pickRain(pts, ft3)
|
||
res.base = [3]float64{base1, base2, base3}
|
||
|
||
// 实况:刚结束一小时 [t-1,t)
|
||
actual, okA, err := data.FetchActualHourlyRain(ctx, stationID, issued.Add(-time.Hour), issued)
|
||
if err != nil || !okA {
|
||
res.skipped, res.reason = true, fmt.Sprintf("actual missing: %v ok=%v", err, okA)
|
||
return res
|
||
}
|
||
res.actual = actual
|
||
|
||
// 前一预测(优先 V6 缓存,否则退回 mix 历史)
|
||
// +1:需要 (t-1) 发布、验证时刻 t 的预测值
|
||
vPrev1, src1, ok1 := prevForValidation(ctx, stationID, issued, 1, v6AtTime)
|
||
vPrev2, src2, ok2 := prevForValidation(ctx, stationID, issued, 2, v6AtTime)
|
||
vPrev3, src3, ok3 := prevForValidation(ctx, stationID, issued, 3, v6AtTime)
|
||
if !(ok1 && ok2 && ok3) {
|
||
// 若冷启动,允许个别 lead 不可用时跳过;也可以只输出可用的 lead,这里采取全量可用才输出
|
||
res.skipped, res.reason = true, fmt.Sprintf("prev missing leads: h1=%v h2=%v h3=%v", ok1, ok2, ok3)
|
||
return res
|
||
}
|
||
res.prev = [3]float64{vPrev1, vPrev2, vPrev3}
|
||
res.src = [3]string{src1, src2, src3}
|
||
|
||
// 残差与输出
|
||
e1 := actual - vPrev1
|
||
e2 := actual - vPrev2
|
||
e3 := actual - vPrev3
|
||
cand1 := base1 + 1.0*e1
|
||
cand2 := base2 + 0.5*e2
|
||
cand3 := base3 + (1.0/3.0)*e3
|
||
var out1, out2, out3 float64
|
||
if cand1 < 0 {
|
||
out1 = base1
|
||
} else {
|
||
out1 = cand1
|
||
}
|
||
if cand2 < 0 {
|
||
out2 = base2
|
||
} else {
|
||
out2 = cand2
|
||
}
|
||
if cand3 < 0 {
|
||
out3 = base3
|
||
} else {
|
||
out3 = cand3
|
||
}
|
||
res.out = [3]float64{out1, out2, out3}
|
||
|
||
// 生成 SQL(仅雨量 upsert)
|
||
rows := make([]string, 0, 3)
|
||
rows = append(rows, insertRainSQL(stationID, outProvider, issued, ft1, toX1000(out1)))
|
||
rows = append(rows, insertRainSQL(stationID, outProvider, issued, ft2, toX1000(out2)))
|
||
rows = append(rows, insertRainSQL(stationID, outProvider, issued, ft3, toX1000(out3)))
|
||
res.sqlRows = rows
|
||
return res
|
||
}
|
||
|
||
// prevForValidation 返回用于“验证时刻=issued+0h”的上一预测:优先使用 V6 的缓存;如无则退回 mix 的历史。
|
||
func prevForValidation(ctx context.Context, stationID string, issued time.Time, lead int, v6AtTime map[time.Time]float64) (float64, string, bool) {
|
||
// 需要的验证时刻
|
||
vt := issued // 验证在 t
|
||
// 先看 V6 缓存:我们在前面会把 V6 的结果按 forecast_time 存入 map
|
||
if v, ok := v6AtTime[vt]; ok {
|
||
return v, "V6", true
|
||
}
|
||
// 否则退回 mix 历史:在 (t-lead) 的小时桶内,取最新 issued 的 +lead
|
||
prevBucket := issued.Add(-time.Duration(lead) * time.Hour)
|
||
iss, ok, err := data.ResolveIssuedAtInBucket(ctx, stationID, baseProvider, prevBucket)
|
||
if err != nil || !ok {
|
||
return 0, "", false
|
||
}
|
||
pts, err := data.ForecastRainAtIssued(ctx, stationID, baseProvider, iss)
|
||
if err != nil || len(pts) < lead {
|
||
return 0, "", false
|
||
}
|
||
// 直接用验证时刻 vt 精确匹配
|
||
if v := pickRain(pts, vt); v >= 0 {
|
||
return v, baseProvider, true
|
||
}
|
||
// 或退回位置索引
|
||
switch lead {
|
||
case 1:
|
||
return toMM(pts[0].RainMMx1000), baseProvider, true
|
||
case 2:
|
||
if len(pts) >= 2 {
|
||
return toMM(pts[1].RainMMx1000), baseProvider, true
|
||
}
|
||
case 3:
|
||
if len(pts) >= 3 {
|
||
return toMM(pts[2].RainMMx1000), baseProvider, true
|
||
}
|
||
}
|
||
return 0, "", false
|
||
}
|
||
|
||
// splitStations breaks a comma-separated list into station IDs, trimming
// surrounding whitespace and dropping empty entries. The result is never nil.
func splitStations(s string) []string {
	fields := strings.Split(s, ",")
	ids := make([]string, 0, len(fields))
	for i := range fields {
		id := strings.TrimSpace(fields[i])
		if id == "" {
			continue
		}
		ids = append(ids, id)
	}
	return ids
}
|
||
|
||
func pickRain(points []data.PredictPoint, ft time.Time) float64 {
|
||
for _, p := range points {
|
||
if p.ForecastTime.Equal(ft) {
|
||
return toMM(p.RainMMx1000)
|
||
}
|
||
}
|
||
return -1
|
||
}
|
||
|
||
// toMM converts a value stored as thousandths into whole units by dividing
// by 1000.
func toMM(vx1000 int32) float64 {
	const scale = 1000.0
	return float64(vx1000) / scale
}
|
||
// toX1000 converts mm to thousandths, rounding to the nearest integer with
// halves rounded away from zero.
//
// Float-to-integer conversion in Go truncates toward zero, so adding +0.5
// alone rounds negative inputs the wrong way (-1.0 would become -999). The
// offset is therefore applied in the direction of the sign.
func toX1000(mm float64) int32 {
	scaled := mm * 1000
	if scaled < 0 {
		return int32(scaled - 0.5)
	}
	return int32(scaled + 0.5)
}
|
||
// clamp0 returns v, or 0 when v is negative.
func clamp0(v float64) float64 {
	floored := v
	if v < 0 {
		floored = 0
	}
	return floored
}
|
||
|
||
func insertRainSQL(stationID, provider string, issued, ft time.Time, rainX1000 int32) string {
|
||
return fmt.Sprintf(
|
||
"INSERT INTO forecast_hourly (station_id, provider, issued_at, forecast_time, rain_mm_x1000) VALUES ('%s','%s','%s','%s',%d) ON CONFLICT (station_id, provider, issued_at, forecast_time) DO UPDATE SET rain_mm_x1000=EXCLUDED.rain_mm_x1000;",
|
||
escapeSQL(stationID), provider, issued.Format(time.RFC3339), ft.Format(time.RFC3339), rainX1000,
|
||
)
|
||
}
|
||
|
||
// escapeSQL doubles every single quote in s so the result can be placed
// inside a single-quoted SQL string literal.
func escapeSQL(s string) string {
	const quote, doubled = "'", "''"
	return strings.ReplaceAll(s, quote, doubled)
}
|