package cdn

import (
	"bytes"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"testing"
)

// cdnEnvKeys lists the environment variables these tests set or clear to
// control how NewClientFromConfigFile builds a Client.
var cdnEnvKeys = []string{
	"VOLT_CDN_BLOBS_URL",
	"VOLT_CDN_MANIFESTS_URL",
	"BUNNY_API_KEY",
	"BUNNY_STORAGE_ZONE",
	"BUNNY_REGION",
}

// ── Helpers ──────────────────────────────────────────────────────────────────

// testHash returns the lowercase hex SHA-256 digest of data, the digest form
// these tests use as blob keys (requested as "/sha256:<hash>").
func testHash(data []byte) string {
	sum := sha256.Sum256(data)
	return hex.EncodeToString(sum[:])
}

// unsetEnv removes each key from the process environment for the duration of
// the test. t.Setenv records the previous value (or its absence) and restores
// it during cleanup, so clearing variables no longer leaks into later tests.
func unsetEnv(t *testing.T, keys ...string) {
	t.Helper()
	for _, key := range keys {
		t.Setenv(key, "") // registers restore-on-cleanup for key
		if err := os.Unsetenv(key); err != nil {
			t.Fatalf("unset %s: %v", key, err)
		}
	}
}

// ── TestNewClientFromEnv ─────────────────────────────────────────────────────

func TestNewClientFromEnv(t *testing.T) {
	// t.Setenv restores whatever was there before once the test finishes.
	t.Setenv("VOLT_CDN_BLOBS_URL", "https://blobs.example.com")
	t.Setenv("VOLT_CDN_MANIFESTS_URL", "https://manifests.example.com")
	t.Setenv("BUNNY_API_KEY", "test-api-key-123")
	t.Setenv("BUNNY_STORAGE_ZONE", "test-zone")
	t.Setenv("BUNNY_REGION", "la")

	// Use a non-existent config file so we rely purely on env.
	c, err := NewClientFromConfigFile("/nonexistent/config.yaml")
	if err != nil {
		t.Fatalf("NewClientFromConfigFile: %v", err)
	}

	if c.BlobsBaseURL != "https://blobs.example.com" {
		t.Errorf("BlobsBaseURL = %q, want %q", c.BlobsBaseURL, "https://blobs.example.com")
	}
	if c.ManifestsBaseURL != "https://manifests.example.com" {
		t.Errorf("ManifestsBaseURL = %q, want %q", c.ManifestsBaseURL, "https://manifests.example.com")
	}
	if c.StorageAPIKey != "test-api-key-123" {
		t.Errorf("StorageAPIKey = %q, want %q", c.StorageAPIKey, "test-api-key-123")
	}
	if c.StorageZoneName != "test-zone" {
		t.Errorf("StorageZoneName = %q, want %q", c.StorageZoneName, "test-zone")
	}
	if c.Region != "la" {
		t.Errorf("Region = %q, want %q", c.Region, "la")
	}
}

func TestNewClientDefaults(t *testing.T) {
	// Clear all relevant env vars; they are restored after the test.
	unsetEnv(t, cdnEnvKeys...)

	c, err := NewClientFromConfigFile("/nonexistent/config.yaml")
	if err != nil {
		t.Fatalf("NewClientFromConfigFile: %v", err)
	}

	if c.BlobsBaseURL != DefaultBlobsURL {
		t.Errorf("BlobsBaseURL = %q, want default %q", c.BlobsBaseURL, DefaultBlobsURL)
	}
	if c.ManifestsBaseURL != DefaultManifestsURL {
		t.Errorf("ManifestsBaseURL = %q, want default %q", c.ManifestsBaseURL, DefaultManifestsURL)
	}
	if c.Region != DefaultRegion {
		t.Errorf("Region = %q, want default %q", c.Region, DefaultRegion)
	}
}

func TestNewClientFromConfig(t *testing.T) {
	c := NewClientFromConfig("https://b.example.com", "https://m.example.com", "key", "zone")
	if c.BlobsBaseURL != "https://b.example.com" {
		t.Errorf("BlobsBaseURL = %q", c.BlobsBaseURL)
	}
	if c.StorageAPIKey != "key" {
		t.Errorf("StorageAPIKey = %q", c.StorageAPIKey)
	}
}

// ── TestPullBlob (integrity) ─────────────────────────────────────────────────

func TestPullBlobIntegrity(t *testing.T) {
	content := []byte("hello stellarium blob")
	hash := testHash(content)

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		expectedPath := "/sha256:" + hash
		if r.URL.Path != expectedPath {
			http.NotFound(w, r)
			return
		}
		w.WriteHeader(http.StatusOK)
		w.Write(content)
	}))
	defer srv.Close()

	c := NewClientFromConfig(srv.URL, "", "", "")
	c.HTTPClient = srv.Client()

	data, err := c.PullBlob(hash)
	if err != nil {
		t.Fatalf("PullBlob: %v", err)
	}
	if !bytes.Equal(data, content) {
		t.Errorf("PullBlob data = %q, want %q", data, content)
	}
}

func TestPullBlobHashVerification(t *testing.T) {
	content := []byte("original content")
	hash := testHash(content)

	// Serve tampered content that doesn't match the hash.
	tampered := []byte("tampered content!!!")

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write(tampered)
	}))
	defer srv.Close()

	c := NewClientFromConfig(srv.URL, "", "", "")
	c.HTTPClient = srv.Client()

	_, err := c.PullBlob(hash)
	if err == nil {
		t.Fatal("PullBlob should fail on tampered content, got nil error")
	}
	if !contains(err.Error(), "integrity check failed") {
		t.Errorf("expected integrity error, got: %v", err)
	}
}

func TestPullBlobNotFound(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.NotFound(w, r)
	}))
	defer srv.Close()

	c := NewClientFromConfig(srv.URL, "", "", "")
	c.HTTPClient = srv.Client()

	_, err := c.PullBlob("abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd")
	if err == nil {
		t.Fatal("PullBlob should fail on 404")
	}
	if !contains(err.Error(), "HTTP 404") {
		t.Errorf("expected HTTP 404 error, got: %v", err)
	}
}

// ── TestPullManifest ─────────────────────────────────────────────────────────

func TestPullManifest(t *testing.T) {
	manifest := Manifest{
		Name:      "test-image",
		CreatedAt: "2024-01-01T00:00:00Z",
		Objects: map[string]string{
			"usr/bin/hello": "aabbccdd",
			"etc/config":    "eeff0011",
		},
	}
	manifestJSON, err := json.Marshal(manifest)
	if err != nil {
		t.Fatalf("marshal manifest fixture: %v", err)
	}

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v2/public/test-image/latest.json" {
			http.NotFound(w, r)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		w.Write(manifestJSON)
	}))
	defer srv.Close()

	c := NewClientFromConfig("", srv.URL, "", "")
	c.HTTPClient = srv.Client()

	m, err := c.PullManifest("test-image")
	if err != nil {
		t.Fatalf("PullManifest: %v", err)
	}
	if m.Name != "test-image" {
		t.Errorf("Name = %q, want %q", m.Name, "test-image")
	}
	if len(m.Objects) != 2 {
		t.Errorf("Objects count = %d, want 2", len(m.Objects))
	}
}

func TestPullManifestNotFound(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.NotFound(w, r)
	}))
	defer srv.Close()

	c := NewClientFromConfig("", srv.URL, "", "")
	c.HTTPClient = srv.Client()

	_, err := c.PullManifest("nonexistent")
	if err == nil {
		t.Fatal("PullManifest should fail on 404")
	}
	if !contains(err.Error(), "not found") {
		t.Errorf("expected 'not found' error, got: %v", err)
	}
}

// ── TestBlobExists ───────────────────────────────────────────────────────────

func TestBlobExists(t *testing.T) {
	existingHash := "aabbccddee112233aabbccddee112233aabbccddee112233aabbccddee112233"

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodHead {
			t.Errorf("expected HEAD, got %s", r.Method)
		}
		if r.URL.Path == "/sha256:"+existingHash {
			w.WriteHeader(http.StatusOK)
		} else {
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer srv.Close()

	c := NewClientFromConfig(srv.URL, "", "", "")
	c.HTTPClient = srv.Client()

	exists, err := c.BlobExists(existingHash)
	if err != nil {
		t.Fatalf("BlobExists: %v", err)
	}
	if !exists {
		t.Error("BlobExists = false, want true")
	}

	exists, err = c.BlobExists("0000000000000000000000000000000000000000000000000000000000000000")
	if err != nil {
		t.Fatalf("BlobExists: %v", err)
	}
	if exists {
		t.Error("BlobExists = true, want false")
	}
}

// ── TestPushBlob ─────────────────────────────────────────────────────────────

func TestPushBlob(t *testing.T) {
	content := []byte("push me to CDN")
	hash := testHash(content)

	var receivedKey string
	var receivedBody []byte

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPut {
			t.Errorf("expected PUT, got %s", r.Method)
		}
		receivedKey = r.Header.Get("AccessKey")
		var err error
		receivedBody, err = readAll(r.Body)
		if err != nil {
			t.Errorf("read body: %v", err)
		}
		w.WriteHeader(http.StatusCreated)
	}))
	defer srv.Close()

	// The real push URL targets bunnycdn.com, so route every request to the
	// test server with a rewriting transport. A dedicated http.Client is used
	// so the server's shared srv.Client() is not mutated.
	c := &Client{
		BlobsBaseURL:    srv.URL,
		StorageAPIKey:   "test-key-456",
		StorageZoneName: "test-zone",
		Region:          "ny",
		HTTPClient: &http.Client{
			Transport: &rewriteTransport{
				inner:     srv.Client().Transport,
				targetURL: srv.URL,
			},
		},
	}

	if err := c.PushBlob(hash, content); err != nil {
		t.Fatalf("PushBlob: %v", err)
	}

	if receivedKey != "test-key-456" {
		t.Errorf("AccessKey header = %q, want %q", receivedKey, "test-key-456")
	}
	if !bytes.Equal(receivedBody, content) {
		t.Errorf("body = %q, want %q", receivedBody, content)
	}
}

func TestPushBlobHashMismatch(t *testing.T) {
	content := []byte("some content")
	wrongHash := "0000000000000000000000000000000000000000000000000000000000000000"

	c := &Client{
		StorageAPIKey:   "key",
		StorageZoneName: "zone",
		HTTPClient:      &http.Client{},
	}

	err := c.PushBlob(wrongHash, content)
	if err == nil {
		t.Fatal("PushBlob should fail on hash mismatch")
	}
	if !contains(err.Error(), "hash mismatch") {
		t.Errorf("expected hash mismatch error, got: %v", err)
	}
}

func TestPushBlobNoAPIKey(t *testing.T) {
	c := &Client{
		StorageAPIKey:   "",
		StorageZoneName: "zone",
		HTTPClient:      &http.Client{},
	}

	err := c.PushBlob("abc", []byte("data"))
	if err == nil {
		t.Fatal("PushBlob should fail without API key")
	}
	if !contains(err.Error(), "StorageAPIKey not configured") {
		t.Errorf("expected 'not configured' error, got: %v", err)
	}
}

// ── TestExpandEnv ────────────────────────────────────────────────────────────

func TestExpandEnv(t *testing.T) {
	t.Setenv("TEST_CDN_VAR", "expanded-value")

	result := expandEnv("${TEST_CDN_VAR}")
	if result != "expanded-value" {
		t.Errorf("expandEnv = %q, want %q", result, "expanded-value")
	}

	// No expansion when no pattern.
	result = expandEnv("plain-string")
	if result != "plain-string" {
		t.Errorf("expandEnv = %q, want %q", result, "plain-string")
	}
}

// ── TestConfigFile ───────────────────────────────────────────────────────────

func TestConfigFileLoading(t *testing.T) {
	// Clear env vars so config file values are used; restored after the test.
	unsetEnv(t, cdnEnvKeys...)
	t.Setenv("MY_API_KEY", "from-env-ref")

	configContent := `cdn:
  blobs_url: "https://custom-blobs.example.com"
  manifests_url: "https://custom-manifests.example.com"
  storage_api_key: "${MY_API_KEY}"
  storage_zone: "my-zone"
  region: "sg"
`
	// t.TempDir is removed automatically when the test ends.
	configPath := filepath.Join(t.TempDir(), "volt-config.yaml")
	if err := os.WriteFile(configPath, []byte(configContent), 0o600); err != nil {
		t.Fatalf("write config: %v", err)
	}

	c, err := NewClientFromConfigFile(configPath)
	if err != nil {
		t.Fatalf("NewClientFromConfigFile: %v", err)
	}

	if c.BlobsBaseURL != "https://custom-blobs.example.com" {
		t.Errorf("BlobsBaseURL = %q", c.BlobsBaseURL)
	}
	if c.ManifestsBaseURL != "https://custom-manifests.example.com" {
		t.Errorf("ManifestsBaseURL = %q", c.ManifestsBaseURL)
	}
	if c.StorageAPIKey != "from-env-ref" {
		t.Errorf("StorageAPIKey = %q, want %q", c.StorageAPIKey, "from-env-ref")
	}
	if c.StorageZoneName != "my-zone" {
		t.Errorf("StorageZoneName = %q", c.StorageZoneName)
	}
	if c.Region != "sg" {
		t.Errorf("Region = %q", c.Region)
	}
}

// ── Test Helpers ─────────────────────────────────────────────────────────────

// contains reports whether substr occurs within s.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}

// searchString reports whether substr occurs within s.
//
// Deprecated: retained in case other test files in the package call it; use
// strings.Contains.
func searchString(s, substr string) bool {
	return strings.Contains(s, substr)
}

// readAll reads r until EOF and returns everything read. It compares against
// the io.EOF sentinel (via io.ReadAll) rather than matching the error text, so
// an unrelated error whose message happens to be "EOF" is still reported.
func readAll(r interface{ Read([]byte) (int, error) }) ([]byte, error) {
	return io.ReadAll(r)
}

// rewriteTransport is an http.RoundTripper that redirects every request to a
// fixed server (scheme and host taken from targetURL), leaving the path,
// query, headers, and body untouched.
type rewriteTransport struct {
	inner     http.RoundTripper // performs the request; nil means http.DefaultTransport
	targetURL string            // e.g. an httptest.Server URL like "http://127.0.0.1:port"
}

// RoundTrip sends a rewritten copy of req to the target server. The
// http.RoundTripper contract forbids modifying the caller's request, so the
// URL is changed on a clone.
func (t *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	scheme := "http"
	if s, _, ok := strings.Cut(t.targetURL, "://"); ok && s != "" {
		scheme = s
	}
	// Drop any path component so only "host[:port]" ends up in URL.Host.
	host, _, _ := strings.Cut(stripScheme(t.targetURL), "/")

	out := req.Clone(req.Context())
	out.URL.Scheme = scheme
	out.URL.Host = host

	transport := t.inner
	if transport == nil {
		transport = http.DefaultTransport
	}
	return transport.RoundTrip(out)
}

// stripScheme returns rawURL with its leading "scheme://" prefix removed, or
// rawURL unchanged when it contains no "://" separator.
func stripScheme(rawURL string) string {
	if _, rest, ok := strings.Cut(rawURL, "://"); ok {
		return rest
	}
	return rawURL
}

// findIndex returns the byte index of the first occurrence of substr in s, or
// -1 if substr is absent.
//
// Deprecated: retained in case other test files in the package call it; use
// strings.Index.
func findIndex(s, substr string) int {
	return strings.Index(s, substr)
}