diff --git a/cmd/imdroidmix/main.go b/cmd/imdroidmix/main.go new file mode 100644 index 0000000..ba22415 --- /dev/null +++ b/cmd/imdroidmix/main.go @@ -0,0 +1,421 @@ +package main + +import ( + "context" + "database/sql" + "flag" + "fmt" + "log" + "math" + "os" + "sort" + "strings" + "time" + + "weatherstation/internal/config" + "weatherstation/internal/database" +) + +var ( + providerOrder = []string{"open-meteo", "caiyun", "imdroid"} + defaultWeight = []float64{0.4, 0.3, 0.3} + timeLayout = "2006-01-02 15:04:05" +) + +func main() { + var issuedInput string + var issuedRange string + var issuedStep int + var stationFilter string + var apply bool + + flag.StringVar(&issuedInput, "issued", "", "单个发布时间,例如 2025-10-05 20:00:00") + flag.StringVar(&issuedRange, "issued_range", "", "发布时间范围,格式 \"开始,结束\",例如 \"2025-10-05 20:00:00,2025-10-05 23:00:00\"") + flag.IntVar(&issuedStep, "issued_step_hours", 1, "发布时间遍历步长(小时)") + flag.StringVar(&stationFilter, "station", "", "可选,指定 station_id,仅处理该站点") + flag.BoolVar(&apply, "apply", false, "直接写入 forecast_hourly(默认仅计算)") + flag.Parse() + + if issuedInput == "" && issuedRange == "" { + fmt.Fprintln(os.Stderr, "需要提供 --issued 或 --issued_range 参数") + os.Exit(1) + } + if issuedInput != "" && issuedRange != "" { + fmt.Fprintln(os.Stderr, "不能同时提供 --issued 与 --issued_range") + os.Exit(1) + } + + // 预先加载配置(也用于校验配置文件是否存在) + _ = config.GetConfig() + + loc, err := time.LoadLocation("Asia/Shanghai") + if err != nil { + log.Printf("加载 Asia/Shanghai 时区失败,退回使用本地时区: %v", err) + loc = time.Local + } + + issuedTimes, err := buildIssuedTimes(loc, issuedInput, issuedRange, issuedStep) + if err != nil { + fmt.Fprintf(os.Stderr, "处理发布时间失败: %v\n", err) + os.Exit(1) + } + if len(issuedTimes) == 0 { + fmt.Fprintln(os.Stderr, "发布时间列表为空") + return + } + + ctx := context.Background() + db := database.GetDB() + defer database.Close() + + stations, err := loadStations(ctx, db, stationFilter) + if err != nil { + fmt.Fprintf(os.Stderr, "加载站点列表失败: %v\n", err) + os.Exit(1) + } + + if len(stations) == 0 { + fmt.Println("没有找到匹配的站点或数据。") + return + } + + type summaryEntry struct { + stationID string + issuedAt time.Time + targets [3]time.Time + weighted [3]float64 + hasSource bool + } + + var sqlBuilder strings.Builder + var summaries []summaryEntry + + sort.Strings(stations) + for _, issuedAt := range issuedTimes { + targetTimes := []time.Time{ + issuedAt.Add(1 * time.Hour), + issuedAt.Add(2 * time.Hour), + issuedAt.Add(3 * time.Hour), + } + for _, stationID := range stations { + matrix, hasSource, err := loadProviderMatrix(ctx, db, stationID, issuedAt, targetTimes) + if err != nil { + fmt.Fprintf(os.Stderr, "站点 %s 读取数据失败: %v\n", stationID, err) + continue + } + weighted := applyWeights(matrix, defaultWeight) + + if hasSource && len(targetTimes) == 3 { + sqlBuilder.WriteString(renderSQL(stationID, issuedAt, targetTimes, weighted)) + } + + if apply && hasSource { + if err := writeForecast(ctx, db, stationID, issuedAt, targetTimes, weighted); err != nil { + fmt.Fprintf(os.Stderr, "写入站点 %s 失败: %v\n", stationID, err) + } + } + + var targetsArr [3]time.Time + if len(targetTimes) == 3 { + targetsArr = [3]time.Time{targetTimes[0], targetTimes[1], targetTimes[2]} + } + summaries = append(summaries, summaryEntry{ + stationID: stationID, + issuedAt: issuedAt, + targets: targetsArr, + weighted: weighted, + hasSource: hasSource, + }) + } + } + + if sqlBuilder.Len() > 0 { + label := buildIssuedLabel(issuedTimes) + path, err := writeSQLFile(label, sqlBuilder.String()) + if err != nil { + fmt.Fprintf(os.Stderr, "写入 SQL 文件失败: %v\n", err) + } else { + fmt.Printf("SQL 已保存至 %s\n", path) + } + } + + for _, item := range summaries { + targetSlice := []time.Time{item.targets[0], item.targets[1], item.targets[2]} + if err := printMixedSummaryFromDB(ctx, db, item.stationID, item.issuedAt, targetSlice, item.weighted, item.hasSource); err != nil { + fmt.Fprintf(os.Stderr, "读取混合结果失败 station=%s issued=%s: %v\n", item.stationID, item.issuedAt.Format(timeLayout), err) + } + } +} + +func buildIssuedTimes(loc *time.Location, single, rangeInput string, stepHours int) ([]time.Time, error) { + if single == "" && rangeInput == "" { + return nil, fmt.Errorf("未提供发布时间") + } + if single != "" { + t, err := time.ParseInLocation(timeLayout, strings.TrimSpace(single), loc) + if err != nil { + return nil, fmt.Errorf("解析发布时间失败: %w", err) + } + return []time.Time{t}, nil + } + + parts := strings.Split(rangeInput, ",") + if len(parts) != 2 { + return nil, fmt.Errorf("issued_range 格式应为 \"开始,结束\"") + } + + start, err := time.ParseInLocation(timeLayout, strings.TrimSpace(parts[0]), loc) + if err != nil { + return nil, fmt.Errorf("解析开始时间失败: %w", err) + } + end, err := time.ParseInLocation(timeLayout, strings.TrimSpace(parts[1]), loc) + if err != nil { + return nil, fmt.Errorf("解析结束时间失败: %w", err) + } + if end.Before(start) { + return nil, fmt.Errorf("结束时间早于开始时间") + } + if stepHours <= 0 { + stepHours = 1 + } + + var list []time.Time + for current := start; !current.After(end); current = current.Add(time.Duration(stepHours) * time.Hour) { + list = append(list, current) + } + return list, nil +} + +func buildIssuedLabel(times []time.Time) string { + if len(times) == 0 { + return "none" + } + if len(times) == 1 { + return times[0].Format("20060102_150405") + } + first := times[0].Format("20060102_150405") + last := times[len(times)-1].Format("20060102_150405") + return fmt.Sprintf("%s-%s", first, last) +} + +func loadStations(ctx context.Context, db *sql.DB, filter string) ([]string, error) { + const baseQuery = ` +SELECT station_id +FROM stations +WHERE latitude IS NOT NULL + AND longitude IS NOT NULL` + + if filter != "" { + row := db.QueryRowContext(ctx, baseQuery+" AND station_id = $1", filter) + if err := row.Scan(&filter); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + return []string{filter}, nil + } + + rows, err := db.QueryContext(ctx, baseQuery) + if err != nil { + return nil, err + } + defer rows.Close() + + var stations []string + for rows.Next() { + var id string + if err := rows.Scan(&id); err != nil { + return nil, err + } + stations = append(stations, id) + } + return stations, rows.Err() +} + +func loadProviderMatrix(ctx context.Context, db *sql.DB, stationID string, issuedAt time.Time, targets []time.Time) (map[string][3]float64, bool, error) { + result := make(map[string][3]float64, len(providerOrder)) + for _, p := range providerOrder { + result[p] = [3]float64{} + } + + issuedUpper := issuedAt.Add(1 * time.Minute) + + query := ` +SELECT provider, forecast_time, COALESCE(rain_mm_x1000, 0) +FROM forecast_hourly +WHERE station_id = $1 + AND issued_at >= $2 + AND issued_at < $3 + AND provider IN ('open-meteo', 'caiyun', 'imdroid') + AND forecast_time IN ($4, $5, $6)` + + rows, err := db.QueryContext(ctx, query, stationID, issuedAt, issuedUpper, targets[0], targets[1], targets[2]) + if err != nil { + return nil, false, err + } + defer rows.Close() + + hasSource := false + for rows.Next() { + var provider string + var ft time.Time + var rain int64 + if err := rows.Scan(&provider, &ft, &rain); err != nil { + return nil, false, err + } + + offset := int(math.Round(ft.Sub(issuedAt).Hours())) + if offset < 1 || offset > len(targets) { + continue + } + + hasSource = true + data := result[provider] + data[offset-1] = float64(rain) / 1000.0 + result[provider] = data + } + return result, hasSource, rows.Err() +} + +func applyWeights(matrix map[string][3]float64, weight []float64) [3]float64 { + var out [3]float64 + for hour := 0; hour < 3; hour++ { + var sum float64 + for idx, provider := range providerOrder { + sum += matrix[provider][hour] * weight[idx] + } + out[hour] = sum + } + return out +} + +func renderSQL(stationID string, issuedAt time.Time, targets []time.Time, weighted [3]float64) string { + const sqlTemplate = ` +INSERT INTO forecast_hourly ( + station_id, provider, issued_at, forecast_time, + rain_mm_x1000, temp_c_x100, humidity_pct, wind_speed_ms_x1000, + wind_gust_ms_x1000, wind_dir_deg, precip_prob_pct, pressure_hpa_x100 +) VALUES ( + '%s', 'imdroid_mix', '%s'::timestamptz, '%s'::timestamptz, + %d, 0, 0, 0, 0, 0, 0, 0 +) +ON CONFLICT (station_id, provider, issued_at, forecast_time) +DO UPDATE SET + rain_mm_x1000 = EXCLUDED.rain_mm_x1000, + temp_c_x100 = EXCLUDED.temp_c_x100, + humidity_pct = EXCLUDED.humidity_pct, + wind_speed_ms_x1000 = EXCLUDED.wind_speed_ms_x1000, + wind_gust_ms_x1000 = EXCLUDED.wind_gust_ms_x1000, + wind_dir_deg = EXCLUDED.wind_dir_deg, + precip_prob_pct = EXCLUDED.precip_prob_pct, + pressure_hpa_x100 = EXCLUDED.pressure_hpa_x100; +` + var b strings.Builder + fmt.Fprintf(&b, "-- SQL 输出 (默认未执行) station=%s\n", stationID) + for idx, ft := range targets { + rainInt := int64(math.Round(weighted[idx] * 1000)) + fmt.Fprintf(&b, sqlTemplate, stationID, issuedAt.Format(time.RFC3339), ft.Format(time.RFC3339), rainInt) + } + b.WriteByte('\n') + return b.String() +} + +func printMixedSummaryFromDB(ctx context.Context, db *sql.DB, stationID string, issuedAt time.Time, targets []time.Time, expected [3]float64, hasSource bool) error { + query := ` +SELECT forecast_time, COALESCE(rain_mm_x1000, 0) +FROM forecast_hourly +WHERE station_id = $1 + AND provider = 'imdroid_mix' + AND issued_at = $2 + AND forecast_time IN ($3, $4, $5) +ORDER BY forecast_time` + + rows, err := db.QueryContext(ctx, query, stationID, issuedAt, targets[0], targets[1], targets[2]) + if err != nil { + return err + } + defer rows.Close() + + type entry struct { + ft time.Time + rain float64 + } + var list []entry + for rows.Next() { + var ft time.Time + var rain int64 + if err := rows.Scan(&ft, &rain); err != nil { + return err + } + list = append(list, entry{ft: ft, rain: float64(rain) / 1000.0}) + } + if err := rows.Err(); err != nil { + return err + } + + if len(list) == 0 { + fmt.Printf("imdroid_mix station=%s issued=%s => [", stationID, issuedAt.Format(timeLayout)) + for idx, ft := range targets { + if idx > 0 { + fmt.Print(", ") + } + fmt.Printf("{hour:%d, dt:\"%s\", rain_mm:%.3f}", idx+1, ft.Format(timeLayout), expected[idx]) + } + label := "computed" + if !hasSource { + label += ", no_source" + } + fmt.Printf("] (%s)\n", label) + return nil + } + + fmt.Printf("imdroid_mix station=%s issued=%s => [", stationID, issuedAt.Format(timeLayout)) + for idx, item := range list { + if idx > 0 { + fmt.Print(", ") + } + fmt.Printf("{hour:%d, dt:\"%s\", rain_mm:%.3f}", idx+1, item.ft.Format(timeLayout), item.rain) + } + fmt.Println("] (db)") + return nil +} + +func writeSQLFile(label string, sqlContent string) (string, error) { + if !strings.HasSuffix(sqlContent, "\n") { + sqlContent += "\n" + } + runKey := time.Now().Format("20060102_150405") + filename := fmt.Sprintf("imdroid_mix_%s_%s.sql", label, runKey) + if err := os.WriteFile(filename, []byte(sqlContent), 0o644); err != nil { + return "", err + } + return filename, nil +} + +func writeForecast(ctx context.Context, db *sql.DB, stationID string, issuedAt time.Time, targets []time.Time, weighted [3]float64) error { + for idx, ft := range targets { + rain := int64(math.Round(weighted[idx] * 1000)) + _, err := db.ExecContext(ctx, ` +INSERT INTO forecast_hourly ( + station_id, provider, issued_at, forecast_time, + rain_mm_x1000, temp_c_x100, humidity_pct, wind_speed_ms_x1000, + wind_gust_ms_x1000, wind_dir_deg, precip_prob_pct, pressure_hpa_x100 +) VALUES ($1, 'imdroid_mix', $2, $3, $4, 0, 0, 0, 0, 0, 0, 0) +ON CONFLICT (station_id, provider, issued_at, forecast_time) +DO UPDATE SET + rain_mm_x1000 = EXCLUDED.rain_mm_x1000, + temp_c_x100 = EXCLUDED.temp_c_x100, + humidity_pct = EXCLUDED.humidity_pct, + wind_speed_ms_x1000 = EXCLUDED.wind_speed_ms_x1000, + wind_gust_ms_x1000 = EXCLUDED.wind_gust_ms_x1000, + wind_dir_deg = EXCLUDED.wind_dir_deg, + precip_prob_pct = EXCLUDED.precip_prob_pct, + pressure_hpa_x100 = EXCLUDED.pressure_hpa_x100 +`, stationID, issuedAt, ft, rain) + if err != nil { + return err + } + } + return nil +}