import pygame
import sys
import json
import random
import math
import os
from pygame.locals import *

pygame.init()
WIDTH, HEIGHT = 1000, 640
screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("成绩走势图 Demo")
clock = pygame.time.Clock()
# Subject names, the title, the tooltip and the key hints are Chinese, and
# plain Arial has no CJK glyphs (they render as empty boxes). SysFont accepts a
# comma-separated list of family names and uses the first one installed, so try
# common CJK-capable fonts (Windows, macOS, Linux) first and keep Arial as the
# last resort; if none match, pygame falls back to its default font.
_UI_FONTS = ("microsoftyahei,simhei,pingfangsc,hiraginosansgb,"
             "notosanscjksc,notosanssc,wenquanyimicrohei,arialunicodems,arial")
FONT = pygame.font.SysFont(_UI_FONTS, 16)
FONT_B = pygame.font.SysFont(_UI_FONTS, 18, bold=True)

# Colour palette (RGB tuples).
BG = (24, 30, 40)             # window background
AXIS_COLOR = (200, 200, 200)  # plot border
GRID_COLOR = (70, 80, 95)     # grid lines
TEXT_COLOR = (230, 230, 230)  # labels, legend, title and hints
TOOLTIP_BG = (20, 24, 30)     # hover tooltip fill

# Gap in pixels between the window edge and the plot area, given in
# top / right / bottom / left order. The bottom and left margins are wider
# because they hold the tick labels, the title and the key hints.
MARGIN_TOP, MARGIN_RIGHT, MARGIN_BOTTOM, MARGIN_LEFT = 40, 40, 60, 80

# JSON file used by the save (S) and load (L) keys.
DATA_FILE = "scores_data.json"

# Sample data: each subject maps to a list of (x, score) pairs. Here x is the
# exam index (1, 2, 3, ...); it could equally be a number derived from a date.
data_series = {
    subject: list(enumerate(scores, start=1))
    for subject, scores in (
        ("语文", (82, 85, 78, 88, 90)),
        ("数学", (75, 80, 85, 87, 92)),
        ("英语", (88, 82, 84, 86, 85)),
    )
}

# Line/marker colour per subject (RGB); subjects missing here are drawn grey.
colors = {
    "语文": (255, 160, 100),
    "数学": (120, 200, 255),
    "英语": (160, 255, 140),
}
# Drawing and legend order: the insertion order of data_series.
series_order = list(data_series)

# View transform layered on top of the data -> plot mapping (see
# world_to_screen): the plot is zoomed about its centre by `scale`, a
# dimensionless factor where 1.0 means "fit the data", then shifted by the pan
# offsets, which are in screen pixels.
offset_x, offset_y = 0.0, 0.0
scale = 1.0

# Mouse-drag state: whether the left / middle button is panning, where the
# drag began, and the pan offsets at that moment.
dragging = mid_dragging = False
drag_start = offset_start = (0, 0)


def save_data(filename=DATA_FILE):
    """Write ``data_series`` to *filename* as pretty-printed UTF-8 JSON.

    Point tuples are stored as JSON arrays. Subject names are written verbatim
    (``ensure_ascii=False``) rather than as ``\\u`` escapes.

    The JSON is serialised first, written to a sibling ``.tmp`` file, and then
    moved over *filename* with ``os.replace``. An error partway through
    therefore never leaves a truncated save file behind. Write errors are
    reported on stdout instead of propagating into the event loop.
    """
    payload = json.dumps(data_series, ensure_ascii=False, indent=2)
    tmp_name = os.fspath(filename) + ".tmp"
    try:
        with open(tmp_name, "w", encoding="utf-8") as f:
            f.write(payload)
        os.replace(tmp_name, filename)
    except OSError as err:
        print("Failed to save", filename, "-", err)
        # Best-effort cleanup of a partially written temp file; the original
        # save file (if any) is untouched either way.
        try:
            os.remove(tmp_name)
        except OSError:
            pass
        return
    print("Saved to", filename)


def _as_point(item):
    """Convert one ``[x, score]`` JSON pair into an ``(x, score)`` tuple.

    Raises ValueError (wrong length) or TypeError (not iterable, or the values
    are not numbers) when *item* is not a pair of numbers.
    """
    x, y = item
    if not (isinstance(x, (int, float)) and isinstance(y, (int, float))):
        raise TypeError(f"point {item!r} is not a pair of numbers")
    return (x, y)


def load_data(filename=DATA_FILE):
    """Replace ``data_series`` and ``series_order`` with the contents of a JSON file.

    The file must hold a JSON object that maps series names to lists of
    ``[x, score]`` pairs, which is the format ``save_data`` writes. Pairs are
    converted back to tuples, so loaded data has the same shape as the built-in
    sample.

    A missing, unreadable, malformed or wrongly shaped file is reported on
    stdout and the current data is left unchanged. The program keeps running
    instead of raising out of the event loop.
    """
    global data_series, series_order
    try:
        with open(filename, "r", encoding="utf-8") as f:
            raw = json.load(f)
    except FileNotFoundError:
        print("No data file", filename)
        return
    except (OSError, ValueError) as err:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print("Failed to load", filename, "-", err)
        return

    # Validate everything before touching the globals, so bad data can never
    # leave the chart half-updated or crash a later draw call.
    try:
        loaded = {name: [_as_point(p) for p in points]
                  for name, points in raw.items()}
    except (AttributeError, TypeError, ValueError) as err:
        print("Invalid data in", filename, "-", err)
        return

    data_series = loaded
    series_order = list(data_series.keys())
    print("Loaded", filename)


def data_bounds():
    """Return the padded data range ``(xmin, xmax, ymin, ymax)`` of all series.

    Each axis is widened on both sides by 10% of its span, where the span is
    at least 1 unit. The y padding is clipped to the 0..100 score scale. It is
    never clipped so far that a data point falls outside the range, so
    out-of-scale scores (e.g. from a loaded file) stay visible.

    The result always satisfies ``xmin < xmax`` and ``ymin < ymax``, so
    ``world_to_screen`` never divides by zero or flips an axis. With no points
    at all, ``(0, 1, 0, 100)`` is returned.
    """
    xs = [x for series in data_series.values() for x, _ in series]
    ys = [y for series in data_series.values() for _, y in series]
    if not xs:
        return (0, 1, 0, 100)
    xmin, xmax = min(xs), max(xs)
    ymin, ymax = min(ys), max(ys)
    pad_x = 0.1 * max(1, xmax - xmin)
    pad_y = 0.1 * max(1, ymax - ymin)
    # Clamp only the padding to the score scale, never the data itself.
    lo = min(ymin, max(0, ymin - pad_y))
    hi = max(ymax, min(100, ymax + pad_y))
    return (xmin - pad_x, xmax + pad_x, lo, hi)


def world_to_screen(x, y, bounds, plot_rect):
    """Map a data point ``(x, y)`` to integer screen pixel coordinates.

    First the point is placed linearly inside ``plot_rect``: ``bounds``
    ``(xmin, xmax, ymin, ymax)`` spans the rect, and y grows upwards, so ymin
    maps to the bottom edge. Then the global view transform is applied: the
    point is scaled by ``scale`` about the centre of ``plot_rect`` and shifted
    by ``(offset_x, offset_y)`` pixels. Coordinates are truncated with
    ``int()``.
    """
    xmin, xmax, ymin, ymax = bounds
    left, top = plot_rect.x, plot_rect.y
    width, height = plot_rect.width, plot_rect.height

    # Fit-to-rect position, before any zoom or pan.
    fit_x = left + ((x - xmin) / (xmax - xmin)) * width
    fit_y = top + height - ((y - ymin) / (ymax - ymin)) * height

    # Zoom about the plot centre, then pan.
    centre_x = left + width / 2
    centre_y = top + height / 2
    view_x = centre_x + (fit_x - centre_x) * scale + offset_x
    view_y = centre_y + (fit_y - centre_y) * scale + offset_y
    return int(view_x), int(view_y)


def draw_axes(bounds, plot_rect):
    """Draw the plot background, grid lines, tick labels and border.

    Vertical grid lines sit on whole-number x values (the exam indices) inside
    ``bounds``. They are thinned so that at most about 12 labels are drawn.
    Horizontal grid lines mark scores 0..100 every 10.

    Grid lines whose screen position falls outside ``plot_rect`` (e.g. after
    panning or zooming) are skipped, together with their labels, so they never
    spill into the margins over the legend, title or hints.

    Args:
        bounds: ``(xmin, xmax, ymin, ymax)`` data range, as from data_bounds().
        plot_rect: pygame.Rect of the plotting area on screen.
    """
    xmin, xmax, ymin, ymax = bounds
    left, right = plot_rect.left, plot_rect.right
    top, bottom = plot_rect.top, plot_rect.bottom

    # Plot background.
    pygame.draw.rect(screen, (20, 24, 30), plot_rect)

    # Vertical grid on integer x values, so ticks line up with the data points.
    first_tick = math.ceil(xmin)
    last_tick = math.floor(xmax)
    step = max(1, math.ceil((last_tick - first_tick + 1) / 12))
    for t in range(first_tick, last_tick + 1, step):
        xpix, _ = world_to_screen(t, ymin, bounds, plot_rect)
        if not left <= xpix <= right:
            continue
        pygame.draw.line(screen, GRID_COLOR, (xpix, top), (xpix, bottom))
        label = FONT.render(str(t), True, TEXT_COLOR)
        screen.blit(label, (xpix - label.get_width() // 2, bottom + 6))

    # Horizontal grid by score (0..100 every 10).
    for val in range(0, 101, 10):
        _, ypix = world_to_screen(xmin, val, bounds, plot_rect)
        if not top <= ypix <= bottom:
            continue
        pygame.draw.line(screen, GRID_COLOR, (left, ypix), (right, ypix))
        label = FONT.render(str(val), True, TEXT_COLOR)
        screen.blit(label, (left - 48, ypix - 8))

    # Border.
    pygame.draw.rect(screen, AXIS_COLOR, plot_rect, 1)


def draw_series(bounds, plot_rect, mouse_pos):
    """Draw every series as a polyline with point markers, and find the hovered point.

    Series are drawn in ``series_order``. Drawing is clipped to ``plot_rect``,
    so panned or zoomed lines never paint over the axis labels, legend or hint
    text. The previous clip area is restored afterwards. Points are connected
    in ascending x order, so a series loaded in arbitrary order still draws as
    a trend line.

    Args:
        bounds: ``(xmin, xmax, ymin, ymax)`` data range, as from data_bounds().
        plot_rect: pygame.Rect of the plotting area.
        mouse_pos: ``(x, y)`` cursor position in screen pixels.

    Returns:
        ``(name, x, y, sx, sy)`` for the point inside ``plot_rect`` that is
        closest to ``mouse_pos``. ``x, y`` are data coordinates and ``sx, sy``
        are screen coordinates. Returns None when no point is visible.
    """
    mx, my = mouse_pos
    nearest = None
    nearest_dist = math.inf
    previous_clip = screen.get_clip()
    screen.set_clip(plot_rect)
    try:
        for name in series_order:
            series = data_series.get(name, [])
            if not series:
                continue
            color = colors.get(name, (200, 200, 200))
            pts = [(*world_to_screen(x, y, bounds, plot_rect), x, y)
                   for x, y in sorted(series, key=lambda p: p[0])]
            if len(pts) >= 2:
                pygame.draw.lines(screen, color, False,
                                  [(sx, sy) for sx, sy, _, _ in pts], 3)
            for sx, sy, x, y in pts:
                pygame.draw.circle(screen, (0, 0, 0), (sx, sy), 6)
                pygame.draw.circle(screen, color, (sx, sy), 4)
                # Points clipped away are invisible, so they must not be hoverable.
                if not plot_rect.collidepoint(sx, sy):
                    continue
                d = math.hypot(mx - sx, my - sy)
                if d < nearest_dist:
                    nearest_dist = d
                    nearest = (name, x, y, sx, sy)
    finally:
        screen.set_clip(previous_clip)
    return nearest


def draw_legend():
    """Draw a one-row legend across the top of the window.

    Each entry in ``series_order`` gets an 18x12 colour swatch followed by its
    name. Entries start just right of the left margin and are spaced 120 px
    apart. Unknown series fall back to grey.
    """
    top = 8
    for index, name in enumerate(series_order):
        left = MARGIN_LEFT + 8 + index * 120
        swatch = colors.get(name, (200, 200, 200))
        pygame.draw.rect(screen, swatch, (left, top, 18, 12))
        caption = FONT.render(name, True, TEXT_COLOR)
        screen.blit(caption, (left + 26, top - 2))


# Mouse-wheel zoom: multiplicative step per notch, clamped so the chart can
# neither shrink to a dot nor grow without bound (huge pixel coordinates).
_ZOOM_STEP = 1.1
_MIN_SCALE = 0.2
_MAX_SCALE = 50.0


def _add_demo_points():
    """Append one random score to every series, at the next x index.

    The score is drawn from a Gaussian (mean 80, sd 8), truncated to int and
    clamped to 40..100. An empty series starts at x = 1.
    """
    for series in data_series.values():
        next_x = max(x for x, _ in series) + 1 if series else 1
        new_y = max(40, min(100, int(random.gauss(80, 8))))
        series.append((next_x, new_y))


def _handle_event(ev):
    """Apply one pygame event to the global view and interaction state.

    Returns False when the event asks the program to quit (window close or
    Esc), and True otherwise.
    """
    global offset_x, offset_y, scale, dragging, drag_start, offset_start, mid_dragging
    if ev.type == QUIT:
        return False
    if ev.type == KEYDOWN:
        if ev.key == K_ESCAPE:
            return False
        if ev.key == K_a:
            _add_demo_points()
        elif ev.key == K_s:
            save_data()
        elif ev.key == K_l:
            load_data()
        elif ev.key == K_r:
            scale = 1.0
            offset_x = 0.0
            offset_y = 0.0
    elif ev.type == MOUSEBUTTONDOWN:
        if ev.button in (1, 2):
            # Left or middle button starts a pan; remember where it began.
            if ev.button == 1:
                dragging = True
            else:
                mid_dragging = True
            drag_start = ev.pos
            offset_start = (offset_x, offset_y)
        elif ev.button == 4:  # wheel up -> zoom in
            scale = min(_MAX_SCALE, scale * _ZOOM_STEP)
        elif ev.button == 5:  # wheel down -> zoom out
            scale = max(_MIN_SCALE, scale / _ZOOM_STEP)
    elif ev.type == MOUSEBUTTONUP:
        if ev.button == 1:
            dragging = False
        elif ev.button == 2:
            mid_dragging = False
    elif ev.type == MOUSEMOTION and (dragging or mid_dragging):
        offset_x = offset_start[0] + (ev.pos[0] - drag_start[0])
        offset_y = offset_start[1] + (ev.pos[1] - drag_start[1])
    return True


def _draw_tooltip(nearest, mouse_pos):
    """Draw the hover tooltip for *nearest* when the cursor is within 12 px of it.

    *nearest* is the ``(name, x, y, sx, sy)`` tuple returned by draw_series,
    or None. The box opens up and to the right of the cursor. It flips to the
    left near the right window edge and is clamped to stay fully on screen.
    """
    if not nearest:
        return
    name, x, y, sx, sy = nearest
    mx, my = mouse_pos
    if math.hypot(mx - sx, my - sy) >= 12:
        return
    surf = FONT.render(f"{name}  第 {int(x)} 次  分数: {y}", True, TEXT_COLOR)
    w, h = surf.get_size()
    pad = 6
    tx = mx + 12
    if tx + w + pad > WIDTH:
        tx = mx - 12 - w
    tx = max(pad, min(tx, WIDTH - w - pad))
    ty = max(pad, min(my - 12 - h, HEIGHT - h - pad))
    box = (tx - pad, ty - pad, w + 2 * pad, h + 2 * pad)
    pygame.draw.rect(screen, TOOLTIP_BG, box)
    pygame.draw.rect(screen, (180, 180, 180), box, 1)
    screen.blit(surf, (tx, ty))


def main():
    """Run the interactive score-trend chart until the window is closed or Esc is pressed.

    Each frame is capped at 60 FPS. It processes input events, then redraws
    the axes, series, legend, title and key hints, and the hover tooltip.
    Finally it flips the display. When the loop ends, pygame is shut down and
    the process exits through sys.exit().
    """
    plot_rect = pygame.Rect(MARGIN_LEFT, MARGIN_TOP,
                            WIDTH - MARGIN_LEFT - MARGIN_RIGHT,
                            HEIGHT - MARGIN_TOP - MARGIN_BOTTOM)
    # The title and hint never change, so render them once instead of per frame.
    title_surf = FONT_B.render("成绩走势图", True, TEXT_COLOR)
    hint = "A: 加点  S: 保存  L: 读取  R: 重置  鼠标滚轮缩放  拖动平移"
    hint_surf = FONT.render(hint, True, TEXT_COLOR)

    running = True
    while running:
        clock.tick(60)
        mouse_pos = pygame.mouse.get_pos()
        # Process every queued event even after a quit request, then draw one
        # last frame before leaving the loop.
        for ev in pygame.event.get():
            if not _handle_event(ev):
                running = False

        screen.fill(BG)
        bounds = data_bounds()
        draw_axes(bounds, plot_rect)
        nearest = draw_series(bounds, plot_rect, mouse_pos)
        draw_legend()
        screen.blit(title_surf, (MARGIN_LEFT, HEIGHT - 48))
        screen.blit(hint_surf, (MARGIN_LEFT + 160, HEIGHT - 46))
        _draw_tooltip(nearest, mouse_pos)

        pygame.display.flip()

    pygame.quit()
    sys.exit()


# Start the interactive chart only when run as a script, not when imported.
if __name__ == "__main__":
    main()
