node build fixed

This commit is contained in:
ra_ma
2025-09-20 14:08:38 +01:00
parent c6ebbe069d
commit 3d298fa434
1516 changed files with 535727 additions and 2 deletions

View File

@@ -0,0 +1,111 @@
package util
import (
"fmt"
"io"
)
// CachedReadSeeker wraps an io.ReadSeekCloser and caches bytes as they are read.
// It implements io.ReadSeeker, allowing seeking within the already-cached
// range without hitting the underlying reader again.
// Additional reads beyond the cache will append to the cache automatically.
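//
// A minimal usage sketch (hypothetical file source; *os.File implements io.ReadSeekCloser):
//
//	f, _ := os.Open("video.mkv")
//	crs := NewCachedReadSeeker(f)
//	defer crs.Close()
//	buf := make([]byte, 512)
//	_, _ = crs.Read(buf)             // read from the source, cached
//	_, _ = crs.Seek(0, io.SeekStart) // rewind
//	_, _ = crs.Read(buf)             // served from the cache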
type CachedReadSeeker struct {
	src   io.ReadSeekCloser // underlying source
	cache []byte            // bytes read so far
	pos   int64             // current read position
	size  int64             // total size of the source, -1 until discovered by SeekEnd
}
func (c *CachedReadSeeker) Close() error {
return c.src.Close()
}
var _ io.ReadSeekCloser = (*CachedReadSeeker)(nil)
// NewCachedReadSeeker constructs a new CachedReadSeeker wrapping a io.ReadSeekCloser.
func NewCachedReadSeeker(r io.ReadSeekCloser) *CachedReadSeeker {
	return &CachedReadSeeker{src: r, size: -1}
}
// Read reads up to len(p) bytes into p. It first serves from cache
// if possible, then reads any remaining bytes from the underlying source,
// appending them to the cache.
func (c *CachedReadSeeker) Read(p []byte) (n int, err error) {
// Check if any part of the request can be served from cache
if c.pos < int64(len(c.cache)) {
// Calculate how much we can read from cache
available := int64(len(c.cache)) - c.pos
toRead := int64(len(p))
if available >= toRead {
// Can serve entirely from cache
n = copy(p, c.cache[c.pos:c.pos+toRead])
c.pos += int64(n)
return n, nil
}
		// Read what we can from the cache; fewer than len(p) bytes remain
		// here, so fall through to the source for the rest
		n = copy(p, c.cache[c.pos:])
		c.pos += int64(n)
		// Read the rest from the source
m, err := c.readFromSrc(p[n:])
n += m
return n, err
}
// Nothing in cache, read from source
return c.readFromSrc(p)
}
// readFromSrc reads from the underlying source at the current position,
// appends those bytes to cache, and updates the current position.
func (c *CachedReadSeeker) readFromSrc(p []byte) (n int, err error) {
// Seek to the current position in the source
if _, err = c.src.Seek(c.pos, io.SeekStart); err != nil {
return 0, err
}
// Read the requested data
n, err = c.src.Read(p)
	if n > 0 {
		// Only append to the cache when the read is contiguous with it
		// (c.pos == len(c.cache)); appending after a forward seek past the
		// cached range would leave a hole of bogus bytes
		if c.pos == int64(len(c.cache)) {
			c.cache = append(c.cache, p[:n]...)
		}
		c.pos += int64(n)
	}
return n, err
}
// Seek sets the read position for subsequent Read calls. Seeking within the
// cached range simply updates the position. Seeking beyond will position
// Read to fetch new data from the underlying source (and cache it).
func (c *CachedReadSeeker) Seek(offset int64, whence int) (int64, error) {
var target int64
switch whence {
case io.SeekStart:
target = offset
case io.SeekCurrent:
target = c.pos + offset
	case io.SeekEnd:
		// Determine the end by seeking the underlying source once, then
		// remember the size so future SeekEnd calls skip the round trip.
		// Note: the cache itself must not be padded to the end, since that
		// would serve zero bytes for ranges that were never actually read.
		if c.size < 0 {
			end, err := c.src.Seek(0, io.SeekEnd)
			if err != nil {
				return 0, err
			}
			c.size = end
		}
		target = c.size + offset
default:
return 0, fmt.Errorf("invalid whence: %d", whence)
}
if target < 0 {
return 0, fmt.Errorf("negative position: %d", target)
}
c.pos = target
return c.pos, nil
}

View File

@@ -0,0 +1,306 @@
package util
import (
"bytes"
"errors"
"io"
"testing"
"time"
)
// mockSlowReader simulates a slow reader (like network or disk) by adding artificial delay
type mockSlowReader struct {
data []byte
pos int64
delay time.Duration
readCnt int // count of actual reads from source
}
func newMockSlowReader(data []byte, delay time.Duration) *mockSlowReader {
return &mockSlowReader{
data: data,
delay: delay,
}
}
func (m *mockSlowReader) Read(p []byte) (n int, err error) {
if m.pos >= int64(len(m.data)) {
return 0, io.EOF
}
// Simulate latency
time.Sleep(m.delay)
m.readCnt++ // track actual reads from source
n = copy(p, m.data[m.pos:])
m.pos += int64(n)
return n, nil
}
func (m *mockSlowReader) Seek(offset int64, whence int) (int64, error) {
var abs int64
switch whence {
case io.SeekStart:
abs = offset
case io.SeekCurrent:
abs = m.pos + offset
case io.SeekEnd:
abs = int64(len(m.data)) + offset
default:
return 0, errors.New("invalid whence")
}
if abs < 0 {
return 0, errors.New("negative position")
}
m.pos = abs
return abs, nil
}
func (m *mockSlowReader) Close() error {
return nil
}
func TestCachedReadSeeker_CachingBehavior(t *testing.T) {
data := []byte("Hello, this is test data for streaming!")
delay := 10 * time.Millisecond
mock := newMockSlowReader(data, delay)
cached := NewCachedReadSeeker(mock)
// First read - should hit the source
buf1 := make([]byte, 5)
n, err := cached.Read(buf1)
if err != nil || n != 5 || string(buf1) != "Hello" {
t.Errorf("First read failed: got %q, want %q", buf1, "Hello")
}
// Seek back to start - should not hit source
_, err = cached.Seek(0, io.SeekStart)
if err != nil {
t.Errorf("Seek failed: %v", err)
}
// Second read of same data - should be from cache
readCntBefore := mock.readCnt
buf2 := make([]byte, 5)
n, err = cached.Read(buf2)
if err != nil || n != 5 || string(buf2) != "Hello" {
t.Errorf("Second read failed: got %q, want %q", buf2, "Hello")
}
if mock.readCnt != readCntBefore {
t.Error("Second read hit source when it should have used cache")
}
}
func TestCachedReadSeeker_Performance(t *testing.T) {
data := bytes.Repeat([]byte("abcdefghijklmnopqrstuvwxyz"), 1000) // ~26KB of data
delay := 10 * time.Millisecond
t.Run("Without Cache", func(t *testing.T) {
mock := newMockSlowReader(data, delay)
start := time.Now()
// Read entire data
if _, err := io.ReadAll(mock); err != nil {
t.Fatal(err)
}
// Seek back and read again
mock.Seek(0, io.SeekStart)
if _, err := io.ReadAll(mock); err != nil {
t.Fatal(err)
}
uncachedDuration := time.Since(start)
t.Logf("Without cache duration: %v", uncachedDuration)
})
t.Run("With Cache", func(t *testing.T) {
mock := newMockSlowReader(data, delay)
cached := NewCachedReadSeeker(mock)
start := time.Now()
// Read entire data
if _, err := io.ReadAll(cached); err != nil {
t.Fatal(err)
}
// Seek back and read again
cached.Seek(0, io.SeekStart)
if _, err := io.ReadAll(cached); err != nil {
t.Fatal(err)
}
cachedDuration := time.Since(start)
t.Logf("With cache duration: %v", cachedDuration)
})
}
func TestCachedReadSeeker_SeekBehavior(t *testing.T) {
data := []byte("0123456789")
mock := newMockSlowReader(data, 0)
cached := NewCachedReadSeeker(mock)
tests := []struct {
name string
offset int64
whence int
wantPos int64
wantRead string
readBufSize int
}{
{"SeekStart", 3, io.SeekStart, 3, "3456", 4},
{"SeekCurrent", 2, io.SeekCurrent, 9, "9", 4},
{"SeekEnd", -5, io.SeekEnd, 5, "56789", 5},
{"SeekStartZero", 0, io.SeekStart, 0, "0123", 4},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pos, err := cached.Seek(tt.offset, tt.whence)
if err != nil {
t.Errorf("Seek failed: %v", err)
return
}
if pos != tt.wantPos {
t.Errorf("Seek position = %d, want %d", pos, tt.wantPos)
}
buf := make([]byte, tt.readBufSize)
n, err := cached.Read(buf)
if err != nil && err != io.EOF {
t.Errorf("Read failed: %v", err)
return
}
got := string(buf[:n])
if got != tt.wantRead {
t.Errorf("Read after seek = %q, want %q", got, tt.wantRead)
}
})
}
}
func TestCachedReadSeeker_LargeReads(t *testing.T) {
// Test with larger data to simulate real streaming scenarios
data := bytes.Repeat([]byte("abcdefghijklmnopqrstuvwxyz"), 1000) // ~26KB
mock := newMockSlowReader(data, 0)
cached := NewCachedReadSeeker(mock)
// Read in chunks
chunkSize := 1024
buf := make([]byte, chunkSize)
var totalRead int
for {
n, err := cached.Read(buf)
totalRead += n
if err == io.EOF {
break
}
if err != nil {
t.Fatalf("Read error: %v", err)
}
}
if totalRead != len(data) {
t.Errorf("Total read = %d, want %d", totalRead, len(data))
}
// Verify cache by seeking back and reading again
cached.Seek(0, io.SeekStart)
readCntBefore := mock.readCnt
totalRead = 0
for {
n, err := cached.Read(buf)
totalRead += n
if err == io.EOF {
break
}
if err != nil {
t.Fatalf("Second read error: %v", err)
}
}
if mock.readCnt != readCntBefore {
t.Error("Second read hit source when it should have used cache")
}
}
func TestCachedReadSeeker_ChunkedReadsAndSeeks(t *testing.T) {
// Create ~1MB of test data
data := bytes.Repeat([]byte("abcdefghijklmnopqrstuvwxyz0123456789"), 30_000)
	delay := 300 * time.Millisecond // artificial delay per read to simulate network/disk latency
// Define read patterns to simulate real-world streaming
type readOp struct {
seekOffset int64
seekWhence int
readSize int
desc string
}
// Simulate typical streaming behavior with repeated reads
ops := []readOp{
	{0, io.SeekStart, 10 * 1024 * 1024, "initial header"}, // Over-read from the start; ReadFull stops at EOF
	{500_000, io.SeekStart, 15 * 1024 * 1024, "middle preview"}, // Seek near the middle, read to EOF
{0, io.SeekStart, len(data), "full read after random seeks"}, // Read entire file
{0, io.SeekStart, len(data), "re-read entire file"}, // Re-read entire file (should be cached)
}
var uncachedDuration, cachedDuration time.Duration
var uncachedReads, cachedReads int
runTest := func(name string, useCache bool) {
t.Run(name, func(t *testing.T) {
mock := newMockSlowReader(data, delay)
var reader io.ReadSeekCloser = mock
if useCache {
reader = NewCachedReadSeeker(mock)
}
start := time.Now()
var totalRead int64
for i, op := range ops {
pos, err := reader.Seek(op.seekOffset, op.seekWhence)
if err != nil {
t.Fatalf("op %d (%s) - seek failed: %v", i, op.desc, err)
}
buf := make([]byte, op.readSize)
n, err := io.ReadFull(reader, buf)
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
t.Fatalf("op %d (%s) - read failed: %v", i, op.desc, err)
}
totalRead += int64(n)
t.Logf("op %d (%s) - seek to %d, read %d bytes", i, op.desc, pos, n)
}
duration := time.Since(start)
t.Logf("Total bytes read: %d", totalRead)
t.Logf("Total time: %v", duration)
t.Logf("Read count from source: %d", mock.readCnt)
if useCache {
cachedDuration = duration
cachedReads = mock.readCnt
} else {
uncachedDuration = duration
uncachedReads = mock.readCnt
}
})
}
// Run both tests
runTest("Without Cache", false)
runTest("With Cache", true)
// Report performance comparison
t.Logf("\nPerformance comparison:")
t.Logf("Uncached: %v (%d reads from source)", uncachedDuration, uncachedReads)
t.Logf("Cached: %v (%d reads from source)", cachedDuration, cachedReads)
t.Logf("Speed improvement: %.2fx", float64(uncachedDuration)/float64(cachedDuration))
t.Logf("Read reduction: %.2fx", float64(uncachedReads)/float64(cachedReads))
}

View File

@@ -0,0 +1,22 @@
//go:build !windows
package util
import (
"context"
"os/exec"
)
func NewCmd(arg string, args ...string) *exec.Cmd {
	return exec.Command(arg, args...)
}
func NewCmdCtx(ctx context.Context, arg string, args ...string) *exec.Cmd {
	return exec.CommandContext(ctx, arg, args...)
}

View File

@@ -0,0 +1,37 @@
//go:build windows
package util
import (
"context"
"os/exec"
"syscall"
)
// NewCmd creates a new exec.Cmd object with the given arguments.
// Since the app is built as a GUI application on Windows, we hide the console
// windows launched when running commands.
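//
// A minimal usage sketch (hypothetical command):
//
//	cmd := NewCmd("ffmpeg", "-version")
//	out, _ := cmd.Output() // runs without flashing a console window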
func NewCmd(arg string, args ...string) *exec.Cmd {
//cmdPrompt := "C:\\Windows\\system32\\cmd.exe"
//cmdArgs := append([]string{"/c", arg}, args...)
//cmd := exec.Command(cmdPrompt, cmdArgs...)
//cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
cmd := exec.Command(arg, args...)
	cmd.SysProcAttr = &syscall.SysProcAttr{
		CreationFlags: 0x08000000, // CREATE_NO_WINDOW
		//HideWindow: true,
	}
return cmd
}
func NewCmdCtx(ctx context.Context, arg string, args ...string) *exec.Cmd {
//cmdPrompt := "C:\\Windows\\system32\\cmd.exe"
//cmdArgs := append([]string{"/c", arg}, args...)
//cmd := exec.CommandContext(ctx, cmdPrompt, cmdArgs...)
//cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
cmd := exec.CommandContext(ctx, arg, args...)
	cmd.SysProcAttr = &syscall.SysProcAttr{
		CreationFlags: 0x08000000, // CREATE_NO_WINDOW
		//HideWindow: true,
	}
return cmd
}

View File

@@ -0,0 +1,157 @@
package comparison
import (
"regexp"
"strconv"
"strings"
)
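// ValueContainsSeason reports whether val appears to contain a season marker,
// e.g. "season 2" or "2nd Season". Values containing the CJK ordinal marker
// '第' or special-episode keywords (see ValueContainsSpecial) return false.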
func ValueContainsSeason(val string) bool {
val = strings.ToLower(val)
if strings.IndexRune(val, '第') != -1 {
return false
}
if ValueContainsSpecial(val) {
return false
}
if strings.Contains(val, "season") {
return true
}
re := regexp.MustCompile(`\d(st|nd|rd|th) [Ss].*`)
if re.MatchString(val) {
return true
}
return false
}
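// ExtractSeasonNumber returns the season number parsed from val, or -1 if no
// season number is found.
//
// Example (hypothetical inputs):
//	ExtractSeasonNumber("JJK season 2")            // 2
//	ExtractSeasonNumber("Spy x Family 2nd Season") // 2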
func ExtractSeasonNumber(val string) int {
val = strings.ToLower(val)
// Check for the word "season" followed by a number
re := regexp.MustCompile(`season (\d+)`)
matches := re.FindStringSubmatch(val)
if len(matches) > 1 {
season, err := strconv.Atoi(matches[1])
if err == nil {
return season
}
}
// Check for a number followed by "st", "nd", "rd", or "th", followed by "s" or "S"
re = regexp.MustCompile(`(\d+)(st|nd|rd|th) [sS]`)
matches = re.FindStringSubmatch(val)
if len(matches) > 1 {
season, err := strconv.Atoi(matches[1])
if err == nil {
return season
}
}
// No season number found
return -1
}
// ExtractResolutionInt extracts the resolution from a string and returns it as an integer.
// This is used for comparing resolutions.
// If the resolution is not found, it returns 0.
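//
// Example (hypothetical inputs):
//	ExtractResolutionInt("[Group] Show - 01 (1080p).mkv") // 1080
//	ExtractResolutionInt("Show.2160p.WEB")                // 2160
//	ExtractResolutionInt("576p")                          // 576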
func ExtractResolutionInt(val string) int {
val = strings.ToLower(val)
	if strings.Contains(val, "4k") { // val is already lowercase
return 2160
}
if strings.Contains(val, "2160") {
return 2160
}
if strings.Contains(val, "1080") {
return 1080
}
if strings.Contains(val, "720") {
return 720
}
if strings.Contains(val, "540") {
return 540
}
if strings.Contains(val, "480") {
return 480
}
	// Match a bare "<digits>p" value (e.g. "576p") and capture the digits,
	// not the "p" suffix
	re := regexp.MustCompile(`^(\d{3,4})p$`)
	matches := re.FindStringSubmatch(val)
	if len(matches) > 1 {
		res, err := strconv.Atoi(matches[1])
if err != nil {
return 0
}
return res
}
return 0
}
func ValueContainsSpecial(val string) bool {
regexes := []*regexp.Regexp{
regexp.MustCompile(`(?i)(^|(?P<show>.*?)[ _.\-(]+)(SP|OAV|OVA|OAD|ONA) ?(?P<ep>\d{1,2})(-(?P<ep2>[0-9]{1,3}))? ?(?P<title>.*)$`),
regexp.MustCompile(`(?i)[-._( ](OVA|ONA)[-._) ]`),
regexp.MustCompile(`(?i)[-._ ](S|SP)(?P<season>(0|00))([Ee]\d)`),
regexp.MustCompile(`[({\[]?(OVA|ONA|OAV|OAD|SP|SPECIAL)[])}]?`),
}
for _, regex := range regexes {
if regex.MatchString(val) {
return true
}
}
return false
}
func ValueContainsIgnoredKeywords(val string) bool {
regexes := []*regexp.Regexp{
regexp.MustCompile(`(?i)^\s?[({\[]?\s?(EXTRAS?|OVAS?|OTHERS?|SPECIALS|MOVIES|SEASONS|NC)\s?[])}]?\s?$`),
}
for _, regex := range regexes {
if regex.MatchString(val) {
return true
}
}
return false
}
func ValueContainsBatchKeywords(val string) bool {
regexes := []*regexp.Regexp{
regexp.MustCompile(`(?i)[({\[]?\s?(EXTRAS|OVAS|OTHERS|SPECIALS|MOVIES|SEASONS|BATCH|COMPLETE|COMPLETE SERIES)\s?[])}]?\s?`),
}
for _, regex := range regexes {
if regex.MatchString(val) {
return true
}
}
return false
}
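// ValueContainsNC reports whether val appears to contain "no-credit" material
// keywords such as NCOP, NCED, OP/ED, trailers, promos, or commercials.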
func ValueContainsNC(val string) bool {
regexes := []*regexp.Regexp{
regexp.MustCompile(`(?i)(^|(?P<show>.*?)[ _.\-(]+)\b(OP|NCOP|OPED)\b ?(?P<ep>\d{1,2}[a-z]?)? ?([ _.\-)]+(?P<title>.*))?`),
regexp.MustCompile(`(?i)(^|(?P<show>.*?)[ _.\-(]+)\b(ED|NCED)\b ?(?P<ep>\d{1,2}[a-z]?)? ?([ _.\-)]+(?P<title>.*))?`),
regexp.MustCompile(`(?i)(^|(?P<show>.*?)[ _.\-(]+)\b(TRAILER|PROMO|PV)\b ?(?P<ep>\d{1,2}) ?([ _.\-)]+(?P<title>.*))?`),
regexp.MustCompile(`(?i)(^|(?P<show>.*?)[ _.\-(]+)\b(OTHERS?)\b(?P<ep>\d{1,2}) ?[ _.\-)]+(?P<title>.*)`),
regexp.MustCompile(`(?i)(^|(?P<show>.*?)[ _.\-(]+)\b(CM|COMMERCIAL|AD)\b ?(?P<ep>\d{1,2}) ?([ _.\-)]+(?P<title>.*))?`),
regexp.MustCompile(`(?i)(^|(?P<show>.*?)[ _.\-(]+)\b(CREDITLESS|NCOP|NCED|OP|ED)\b ?(?P<ep>\d{1,2}[a-z]?)? ?([ _.\-)]+(?P<title>.*))?`),
}
for _, regex := range regexes {
if regex.MatchString(val) {
return true
}
}
return false
}

View File

@@ -0,0 +1,344 @@
package comparison
import (
"testing"
)
func TestValueContainsSeason(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{
name: "Contains 'season' in lowercase",
input: "JJK season 2",
expected: true,
},
{
name: "Contains 'season' in uppercase",
input: "JJK SEASON 2",
expected: true,
},
{
name: "Contains '2nd S' in lowercase",
input: "Spy x Family 2nd Season",
expected: true,
},
{
name: "Contains '2nd S' in uppercase",
input: "Spy x Family 2ND SEASON",
expected: true,
},
{
name: "Does not contain 'season' or '1st S'",
input: "This is a test",
expected: false,
},
{
name: "Contains special characters",
input: "JJK season 2 (OVA)",
expected: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := ValueContainsSeason(test.input)
if result != test.expected {
t.Errorf("ValueContainsSeason() with args %v, expected %v, but got %v.", test.input, test.expected, result)
}
})
}
}
func TestExtractSeasonNumber(t *testing.T) {
tests := []struct {
name string
input string
expected int
}{
{
name: "Contains 'season' followed by a number",
input: "JJK season 2",
expected: 2,
},
{
name: "Contains a number followed by 'st', 'nd', 'rd', or 'th', followed by 's' or 'S'",
input: "Spy x Family 2nd S",
expected: 2,
},
{
name: "Does not contain 'season' or '1st S'",
input: "This is a test",
expected: -1,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := ExtractSeasonNumber(test.input)
if result != test.expected {
t.Errorf("ExtractSeasonNumber() with args %v, expected %v, but got %v.", test.input, test.expected, result)
}
})
}
}
func TestExtractResolutionInt(t *testing.T) {
tests := []struct {
name string
input string
expected int
}{
{
name: "Contains '4K' in uppercase",
input: "4K",
expected: 2160,
},
{
name: "Contains '4k' in lowercase",
input: "4k",
expected: 2160,
},
{
name: "Contains '2160'",
input: "2160",
expected: 2160,
},
{
name: "Contains '1080'",
input: "1080",
expected: 1080,
},
{
name: "Contains '720'",
input: "720",
expected: 720,
},
{
name: "Contains '480'",
input: "480",
expected: 480,
},
{
name: "Does not contain a resolution",
input: "This is a test",
expected: 0,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := ExtractResolutionInt(test.input)
if result != test.expected {
t.Errorf("ExtractResolutionInt() with args %v, expected %v, but got %v.", test.input, test.expected, result)
}
})
}
}
func TestValueContainsSpecial(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{
name: "Contains 'OVA' in uppercase",
input: "JJK OVA",
expected: true,
},
{
name: "Contains 'ova' in lowercase",
input: "JJK ova",
expected: false,
},
{
name: "Does not contain special keywords",
input: "This is a test",
expected: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := ValueContainsSpecial(test.input)
if result != test.expected {
t.Errorf("ValueContainsSpecial() with args %v, expected %v, but got %v.", test.input, test.expected, result)
}
})
}
}
func TestValueContainsIgnoredKeywords(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{
name: "Contains 'EXTRAS' in uppercase",
input: "EXTRAS",
expected: true,
},
{
name: "Contains 'extras' in lowercase",
input: "extras",
expected: true,
},
{
name: "Does not contain ignored keywords",
input: "This is a test",
expected: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := ValueContainsIgnoredKeywords(test.input)
if result != test.expected {
t.Errorf("ValueContainsIgnoredKeywords() with args %v, expected %v, but got %v.", test.input, test.expected, result)
}
})
}
}
func TestValueContainsBatchKeywords(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{
name: "Contains 'BATCH' in uppercase",
input: "BATCH",
expected: true,
},
{
name: "Contains 'batch' in lowercase",
input: "batch",
expected: true,
},
{
name: "Does not contain batch keywords",
input: "This is a test",
expected: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := ValueContainsBatchKeywords(test.input)
if result != test.expected {
t.Errorf("ValueContainsBatchKeywords() with args %v, expected %v, but got %v.", test.input, test.expected, result)
}
})
}
}
func TestValueContainsNC(t *testing.T) {
tests := []struct {
input string
expected bool
}{
{
input: "NCOP",
expected: true,
},
{
input: "ncop",
expected: true,
},
{
input: "One Piece - 1000 - NCOP",
expected: true,
},
{
input: "One Piece ED 2",
expected: true,
},
{
input: "This is a test",
expected: false,
		},
{
input: "Himouto.Umaru.chan.S01E02.1080p.BluRay.Opus2.0.x265-smol",
expected: false,
},
{
input: "Himouto.Umaru.chan.S01E02.1080p.BluRay.x265-smol",
expected: false,
},
{
input: "One Piece - 1000 - Operation something something",
expected: false,
},
}
for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
result := ValueContainsNC(test.input)
if result != test.expected {
t.Errorf("ValueContainsNC() with args %v, expected %v, but got %v.", test.input, test.expected, result)
}
})
}
}
//func TestLikelyNC(t *testing.T) {
// tests := []struct {
// name string
// input string
// expected bool
// }{
// {
// name: "Does not contain NC keywords 1",
// input: "Himouto.Umaru.chan.S01E02.1080p.BluRay.Opus2.0.x265-smol",
// expected: false,
// },
// {
// name: "Does not contain NC keywords 2",
// input: "Himouto.Umaru.chan.S01E02.1080p.BluRay.x265-smol",
// expected: false,
// },
// {
// name: "Contains NC keywords 1",
// input: "Himouto.Umaru.chan.S00E02.1080p.BluRay.x265-smol",
// expected: true,
// },
// {
// name: "Contains NC keywords 2",
// input: "Himouto.Umaru.chan.OP02.1080p.BluRay.x265-smol",
// expected: true,
// },
// }
//
// for _, test := range tests {
// t.Run(test.name, func(t *testing.T) {
// metadata := habari.Parse(test.input)
// var episode string
// var season string
//
// if len(metadata.SeasonNumber) > 0 {
// if len(metadata.SeasonNumber) == 1 {
// season = metadata.SeasonNumber[0]
// }
// }
//
// if len(metadata.EpisodeNumber) > 0 {
// if len(metadata.EpisodeNumber) == 1 {
// episode = metadata.EpisodeNumber[0]
// }
// }
//
// result := LikelyNC(test.input, season, episode)
// if result != test.expected {
// t.Errorf("ValueContainsNC() with args %v, expected %v, but got %v.", test.input, test.expected, result)
// }
// })
// }
//}

View File

@@ -0,0 +1,230 @@
// Package comparison contains helpers for the comparison and filtering of media titles.
package comparison
import (
	"math"

	"github.com/adrg/strutil/metrics"
)
// LevenshteinResult is a struct that holds a string and its Levenshtein distance compared to another string.
type LevenshteinResult struct {
OriginalValue *string
Value *string
Distance int
}
// CompareWithLevenshtein compares a string to a slice of strings and returns a slice of LevenshteinResult containing the Levenshtein distance for each string.
func CompareWithLevenshtein(v *string, vals []*string) []*LevenshteinResult {
return CompareWithLevenshteinCleanFunc(v, vals, func(val string) string {
return val
})
}
func CompareWithLevenshteinCleanFunc(v *string, vals []*string, cleanFunc func(val string) string) []*LevenshteinResult {
lev := metrics.NewLevenshtein()
lev.CaseSensitive = false
//lev.DeleteCost = 1
	// Allocate with zero length so append does not leave nil entries at the front
	res := make([]*LevenshteinResult, 0, len(vals))
for _, val := range vals {
res = append(res, &LevenshteinResult{
OriginalValue: v,
Value: val,
Distance: lev.Distance(cleanFunc(*v), cleanFunc(*val)),
})
}
return res
}
// FindBestMatchWithLevenshtein returns the best match from a slice of strings as a reference to a LevenshteinResult.
// It also returns a boolean indicating whether the best match was found.
func FindBestMatchWithLevenshtein(v *string, vals []*string) (*LevenshteinResult, bool) {
res := CompareWithLevenshtein(v, vals)
if len(res) == 0 {
return nil, false
}
var bestResult *LevenshteinResult
for _, result := range res {
if bestResult == nil || result.Distance < bestResult.Distance {
bestResult = result
}
}
return bestResult, true
}
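// A minimal usage sketch, mirroring the tests (lo.ToSlicePtr from samber/lo
// converts a slice of strings to a slice of pointers):
//
//	title := "jujutsu kaisen 2"
//	candidates := lo.ToSlicePtr([]string{"JJK", "Jujutsu Kaisen 2"})
//	if res, ok := FindBestMatchWithLevenshtein(&title, candidates); ok {
//		fmt.Println(*res.Value, res.Distance) // "Jujutsu Kaisen 2" 0
//	}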
//----------------------------------------------------------------------------------------------------------------------
// JaroWinklerResult is a struct that holds a string and its Jaro-Winkler similarity rating compared to another string.
type JaroWinklerResult struct {
OriginalValue *string
Value *string
Rating float64
}
// CompareWithJaroWinkler compares a string to a slice of strings and returns a slice of JaroWinklerResult containing the Jaro-Winkler rating for each string.
func CompareWithJaroWinkler(v *string, vals []*string) []*JaroWinklerResult {
jw := metrics.NewJaroWinkler()
jw.CaseSensitive = false
	res := make([]*JaroWinklerResult, 0, len(vals))
for _, val := range vals {
res = append(res, &JaroWinklerResult{
OriginalValue: v,
Value: val,
Rating: jw.Compare(*v, *val),
})
}
return res
}
// FindBestMatchWithJaroWinkler returns the best match from a slice of strings as a reference to a JaroWinklerResult.
// It also returns a boolean indicating whether the best match was found.
func FindBestMatchWithJaroWinkler(v *string, vals []*string) (*JaroWinklerResult, bool) {
res := CompareWithJaroWinkler(v, vals)
if len(res) == 0 {
return nil, false
}
var bestResult *JaroWinklerResult
for _, result := range res {
if bestResult == nil || result.Rating > bestResult.Rating {
bestResult = result
}
}
return bestResult, true
}
//----------------------------------------------------------------------------------------------------------------------
// JaccardResult is a struct that holds a string and its Jaccard similarity rating compared to another string.
type JaccardResult struct {
OriginalValue *string
Value *string
Rating float64
}
// CompareWithJaccard compares a string to a slice of strings and returns a slice of JaccardResult containing the Jaccard rating for each string.
func CompareWithJaccard(v *string, vals []*string) []*JaccardResult {
jw := metrics.NewJaccard()
jw.CaseSensitive = false
jw.NgramSize = 1
	res := make([]*JaccardResult, 0, len(vals))
for _, val := range vals {
res = append(res, &JaccardResult{
OriginalValue: v,
Value: val,
Rating: jw.Compare(*v, *val),
})
}
return res
}
// FindBestMatchWithJaccard returns the best match from a slice of strings as a reference to a JaccardResult.
// It also returns a boolean indicating whether the best match was found.
func FindBestMatchWithJaccard(v *string, vals []*string) (*JaccardResult, bool) {
res := CompareWithJaccard(v, vals)
if len(res) == 0 {
return nil, false
}
var bestResult *JaccardResult
for _, result := range res {
if bestResult == nil || result.Rating > bestResult.Rating {
bestResult = result
}
}
return bestResult, true
}
//----------------------------------------------------------------------------------------------------------------------
type SorensenDiceResult struct {
OriginalValue *string
Value *string
Rating float64
}
func CompareWithSorensenDice(v *string, vals []*string) []*SorensenDiceResult {
dice := metrics.NewSorensenDice()
dice.CaseSensitive = false
	res := make([]*SorensenDiceResult, 0, len(vals))
for _, val := range vals {
res = append(res, &SorensenDiceResult{
OriginalValue: v,
Value: val,
Rating: dice.Compare(*v, *val),
})
}
return res
}
func FindBestMatchWithSorensenDice(v *string, vals []*string) (*SorensenDiceResult, bool) {
res := CompareWithSorensenDice(v, vals)
if len(res) == 0 {
return nil, false
}
var bestResult *SorensenDiceResult
for _, result := range res {
if bestResult == nil || result.Rating > bestResult.Rating {
bestResult = result
}
}
return bestResult, true
}
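// EliminateLeastSimilarValue removes the value whose summed Sorensen-Dice
// similarity to all other values is lowest. Slices with fewer than three
// elements are returned unchanged.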
func EliminateLeastSimilarValue(arr []string) []string {
if len(arr) < 3 {
return arr
}
sd := metrics.NewSorensenDice()
sd.CaseSensitive = false
leastSimilarIndex := -1
	leastSimilarScore := math.MaxFloat64 // a sum of n-1 similarities can exceed 2.0 for n >= 4
for i := 0; i < len(arr); i++ {
totalSimilarity := 0.0
for j := 0; j < len(arr); j++ {
if i != j {
score := sd.Compare(arr[i], arr[j])
totalSimilarity += score
}
}
if totalSimilarity < leastSimilarScore {
leastSimilarScore = totalSimilarity
leastSimilarIndex = i
}
}
if leastSimilarIndex != -1 {
arr = append(arr[:leastSimilarIndex], arr[leastSimilarIndex+1:]...)
}
return arr
}

View File

@@ -0,0 +1,114 @@
package comparison
import (
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"testing"
)
func TestFindBestMatchWithLevenshtein(t *testing.T) {
tests := []struct {
title string
comparisonTitles []string
expectedResult string
expectedDistance int
}{
{
title: "jujutsu kaisen 2",
comparisonTitles: []string{"JJK", "Jujutsu Kaisen", "Jujutsu Kaisen 2"},
expectedResult: "Jujutsu Kaisen 2",
expectedDistance: 0,
},
}
for _, test := range tests {
t.Run(test.title, func(t *testing.T) {
res, ok := FindBestMatchWithLevenshtein(&test.title, lo.ToSlicePtr(test.comparisonTitles))
if assert.True(t, ok) {
assert.Equal(t, test.expectedResult, *res.Value, "expected result does not match")
assert.Equal(t, test.expectedDistance, res.Distance, "expected distance does not match")
t.Logf("value: %s, distance: %d", *res.Value, res.Distance)
}
})
}
}
func TestFindBestMatchWithDice(t *testing.T) {
tests := []struct {
title string
comparisonTitles []string
expectedResult string
expectedRating float64
}{
{
title: "jujutsu kaisen 2",
comparisonTitles: []string{"JJK", "Jujutsu Kaisen", "Jujutsu Kaisen 2"},
expectedResult: "Jujutsu Kaisen 2",
expectedRating: 1,
},
}
for _, test := range tests {
t.Run(test.title, func(t *testing.T) {
res, ok := FindBestMatchWithSorensenDice(&test.title, lo.ToSlicePtr(test.comparisonTitles))
if assert.True(t, ok, "expected result, got nil") {
assert.Equal(t, test.expectedResult, *res.Value, "expected result does not match")
assert.Equal(t, test.expectedRating, res.Rating, "expected rating does not match")
t.Logf("value: %s, rating: %f", *res.Value, res.Rating)
}
})
}
}
func TestEliminateLeastSimilarValue(t *testing.T) {
tests := []struct {
title string
comparisonTitles []string
expectedEliminated string
}{
{
title: "jujutsu kaisen 2",
comparisonTitles: []string{"JJK", "Jujutsu Kaisen", "Jujutsu Kaisen 2"},
expectedEliminated: "JJK",
},
{
title: "One Piece - Film Z",
comparisonTitles: []string{"One Piece - Film Z", "One Piece Film Z", "One Piece Gold"},
expectedEliminated: "One Piece Gold",
},
{
title: "One Piece - Film Z",
comparisonTitles: []string{"One Piece - Film Z", "One Piece Film Z", "One Piece Z"},
expectedEliminated: "One Piece Z",
},
{
title: "Mononogatari",
comparisonTitles: []string{"Mononogatari", "Mononogatari Cour 2", "Nekomonogatari"},
expectedEliminated: "Nekomonogatari",
},
}
for _, test := range tests {
t.Run(test.title, func(t *testing.T) {
res := EliminateLeastSimilarValue(test.comparisonTitles)
for _, n := range res {
if n == test.expectedEliminated {
t.Fatalf("expected \"%s\" to be eliminated from %v", n, res)
}
}
})
}
}

View File

@@ -0,0 +1,148 @@
package crashlog
import (
"bytes"
"context"
"fmt"
"github.com/rs/zerolog"
"github.com/samber/mo"
"io"
"os"
"path/filepath"
"seanime/internal/util"
"sync"
"time"
)
// CrashLogger continuously records logs from specific areas of the program and
// writes them to a file when something unexpected happens.
type CrashLogger struct {
//logger *zerolog.Logger
//logBuffer *bytes.Buffer
//mu sync.Mutex
logDir mo.Option[string]
}
type CrashLoggerArea struct {
name string
logger *zerolog.Logger
logBuffer *bytes.Buffer
mu sync.Mutex
ctx context.Context
cancelFunc context.CancelFunc
}
var GlobalCrashLogger = NewCrashLogger()
// NewCrashLogger creates a new CrashLogger instance.
func NewCrashLogger() *CrashLogger {
//var logBuffer bytes.Buffer
//
//fileOutput := zerolog.ConsoleWriter{
// Out: &logBuffer,
// TimeFormat: time.DateTime,
// FormatMessage: util.ZerologFormatMessageSimple,
// FormatLevel: util.ZerologFormatLevelSimple,
// NoColor: true,
//}
//
//multi := zerolog.MultiLevelWriter(fileOutput)
//logger := zerolog.New(multi).With().Timestamp().Logger()
return &CrashLogger{
//logger: &logger,
//logBuffer: &logBuffer,
//mu: sync.Mutex{},
logDir: mo.None[string](),
}
}
func (c *CrashLogger) SetLogDir(dir string) {
c.logDir = mo.Some(dir)
}
// InitArea creates a new CrashLoggerArea instance.
// This instance can be used to log crashes in a specific area.
func (c *CrashLogger) InitArea(area string) *CrashLoggerArea {
var logBuffer bytes.Buffer
fileOutput := zerolog.ConsoleWriter{
Out: &logBuffer,
TimeFormat: time.DateTime,
FormatLevel: util.ZerologFormatLevelSimple,
NoColor: true,
}
multi := zerolog.MultiLevelWriter(fileOutput)
logger := zerolog.New(multi).With().Timestamp().Logger()
//ctx, cancelFunc := context.WithCancel(context.Background())
return &CrashLoggerArea{
logger: &logger,
name: area,
logBuffer: &logBuffer,
mu: sync.Mutex{},
//ctx: ctx,
//cancelFunc: cancelFunc,
}
}
// Stdout returns the CrashLoggerArea's log buffer so that it can be used as a writer.
//
// Example:
// crashLogger := crashlog.GlobalCrashLogger.InitArea("ffmpeg")
// defer crashLogger.Close()
//
// cmd.Stdout = crashLogger.Stdout()
func (a *CrashLoggerArea) Stdout() io.Writer {
return a.logBuffer
}
func (a *CrashLoggerArea) LogError(msg string) {
a.logger.Error().Msg(msg)
}
func (a *CrashLoggerArea) LogErrorf(format string, args ...interface{}) {
a.logger.Error().Msgf(format, args...)
}
func (a *CrashLoggerArea) LogInfof(format string, args ...interface{}) {
a.logger.Info().Msgf(format, args...)
}
// Close should always be called, typically using defer, when a new area is created:
//
// logArea := crashlog.GlobalCrashLogger.InitArea("ffmpeg")
// defer logArea.Close()
func (a *CrashLoggerArea) Close() {
a.logBuffer.Reset()
//a.cancelFunc()
}
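// WriteAreaLogToFile flushes the area's buffered logs to a timestamped crash
// log file in the configured log directory, then resets the buffer. It is a
// no-op until SetLogDir has been called.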
func (c *CrashLogger) WriteAreaLogToFile(area *CrashLoggerArea) {
logDir, found := c.logDir.Get()
if !found {
return
}
// e.g. crash-ffmpeg-2021-09-01_15-04-05.log
logFilePath := filepath.Join(logDir, fmt.Sprintf("crash-%s-%s.log", area.name, time.Now().Format("2006-01-02_15-04-05")))
// Create file
logFile, err := os.OpenFile(logFilePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0664)
if err != nil {
fmt.Printf("Failed to open log file: %s\n", logFilePath)
return
}
defer logFile.Close()
area.mu.Lock()
defer area.mu.Unlock()
if _, err := area.logBuffer.WriteTo(logFile); err != nil {
fmt.Printf("Failed to write crash log buffer to file for %s\n", area.name)
}
area.logBuffer.Reset()
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,11 @@
package util
import (
"fmt"
"time"
)
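// TimestampToDateStr converts a Unix timestamp in seconds to the default
// string representation of the corresponding time.Time.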
func TimestampToDateStr(timestamp int64) string {
tm := time.Unix(timestamp, 0)
return fmt.Sprintf("%v", tm)
}

View File

@@ -0,0 +1,483 @@
package filecache
import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/goccy/go-json"
"github.com/samber/lo"
)
// CacheStore represents a single-process, file-based, key/value cache store.
type CacheStore struct {
filePath string
mu sync.Mutex
data map[string]*cacheItem
}
// Bucket represents a cache bucket with a name and TTL.
type Bucket struct {
name string
ttl time.Duration
}
type PermanentBucket struct {
name string
}
func NewBucket(name string, ttl time.Duration) Bucket {
return Bucket{name: name, ttl: ttl}
}
func (b *Bucket) Name() string {
return b.name
}
func NewPermanentBucket(name string) PermanentBucket {
return PermanentBucket{name: name}
}
func (b *PermanentBucket) Name() string {
return b.name
}
// Cacher represents a single-process, file-based, key/value cache.
type Cacher struct {
dir string
stores map[string]*CacheStore
mu sync.Mutex
}
type cacheItem struct {
Value interface{} `json:"value"`
Expiration *time.Time `json:"expiration,omitempty"`
}
// NewCacher creates a new instance of Cacher.
func NewCacher(dir string) (*Cacher, error) {
// Check if the directory exists
_, err := os.Stat(dir)
if err != nil {
if os.IsNotExist(err) {
if err := os.MkdirAll(dir, os.ModePerm); err != nil {
return nil, err
}
} else {
return nil, err
}
}
return &Cacher{
stores: make(map[string]*CacheStore),
dir: dir,
}, nil
}
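// A minimal usage sketch (hypothetical directory and values):
//
//	cacher, _ := NewCacher("/tmp/seanime-cache")
//	bucket := NewBucket("example", 30*time.Minute)
//	_ = cacher.Set(bucket, "key", "value")
//	var out string
//	found, _ := cacher.Get(bucket, "key", &out) // found is true until the TTL elapses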
// Close closes all the cache stores.
func (c *Cacher) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
for _, store := range c.stores {
if err := store.saveToFile(); err != nil {
return err
}
}
return nil
}
// getStore returns a cache store for the given bucket name and TTL.
func (c *Cacher) getStore(name string) (*CacheStore, error) {
c.mu.Lock()
defer c.mu.Unlock()
store, ok := c.stores[name]
if !ok {
store = &CacheStore{
filePath: filepath.Join(c.dir, name+".cache"),
data: make(map[string]*cacheItem),
}
if err := store.loadFromFile(); err != nil {
return nil, err
}
c.stores[name] = store
}
return store, nil
}
// Set sets the value for the given key in the given bucket.
func (c *Cacher) Set(bucket Bucket, key string, value interface{}) error {
store, err := c.getStore(bucket.name)
if err != nil {
return err
}
store.mu.Lock()
defer store.mu.Unlock()
store.data[key] = &cacheItem{Value: value, Expiration: lo.ToPtr(time.Now().Add(bucket.ttl))}
return store.saveToFile()
}
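// Range iterates over the non-expired items in the bucket, decoding each value
// into T and calling f until it returns false. Expired items encountered
// during iteration are deleted.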
func Range[T any](c *Cacher, bucket Bucket, f func(key string, value T) bool) error {
store, err := c.getStore(bucket.name)
if err != nil {
return err
}
store.mu.Lock()
defer store.mu.Unlock()
for key, item := range store.data {
if item.Expiration != nil && time.Now().After(*item.Expiration) {
delete(store.data, key)
} else {
itemVal, err := json.Marshal(item.Value)
if err != nil {
return err
}
var out T
err = json.Unmarshal(itemVal, &out)
if err != nil {
return err
}
if !f(key, out) {
break
}
}
}
return store.saveToFile()
}
// Get retrieves the value for the given key from the given bucket.
func (c *Cacher) Get(bucket Bucket, key string, out interface{}) (bool, error) {
store, err := c.getStore(bucket.name)
if err != nil {
return false, err
}
store.mu.Lock()
defer store.mu.Unlock()
item, ok := store.data[key]
if !ok {
return false, nil
}
if item.Expiration != nil && time.Now().After(*item.Expiration) {
delete(store.data, key)
_ = store.saveToFile() // Ignore errors here
return false, nil
}
data, err := json.Marshal(item.Value)
if err != nil {
return false, err
}
return true, json.Unmarshal(data, out)
}
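// GetAll returns all non-expired values in the bucket, decoded into T and
// keyed as stored.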
func GetAll[T any](c *Cacher, bucket Bucket) (map[string]T, error) {
	data := make(map[string]T)
	// Range loads the store, prunes expired items, and persists any changes,
	// so no extra locking or save is needed here
	err := Range(c, bucket, func(key string, value T) bool {
		data[key] = value
		return true
	})
	if err != nil {
		return nil, err
	}
	return data, nil
}
// Delete deletes the value for the given key from the given bucket.
func (c *Cacher) Delete(bucket Bucket, key string) error {
store, err := c.getStore(bucket.name)
if err != nil {
return err
}
store.mu.Lock()
defer store.mu.Unlock()
delete(store.data, key)
return store.saveToFile()
}
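// DeleteIf deletes every item in the bucket for which cond returns true.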
func DeleteIf[T any](c *Cacher, bucket Bucket, cond func(key string, value T) bool) error {
store, err := c.getStore(bucket.name)
if err != nil {
return err
}
store.mu.Lock()
defer store.mu.Unlock()
for key, item := range store.data {
itemVal, err := json.Marshal(item.Value)
if err != nil {
return err
}
var out T
err = json.Unmarshal(itemVal, &out)
if err != nil {
return err
}
if cond(key, out) {
delete(store.data, key)
}
}
return store.saveToFile()
}
// Empty empties the given bucket.
func (c *Cacher) Empty(bucket Bucket) error {
store, err := c.getStore(bucket.name)
if err != nil {
return err
}
store.mu.Lock()
defer store.mu.Unlock()
store.data = make(map[string]*cacheItem)
return store.saveToFile()
}
// Remove removes the given bucket.
func (c *Cacher) Remove(bucketName string) error {
c.mu.Lock()
defer c.mu.Unlock()
if _, ok := c.stores[bucketName]; ok {
delete(c.stores, bucketName)
}
_ = os.Remove(filepath.Join(c.dir, bucketName+".cache"))
return nil
}
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// SetPerm sets the value for the given key in the permanent bucket (no expiration).
func (c *Cacher) SetPerm(bucket PermanentBucket, key string, value interface{}) error {
store, err := c.getStore(bucket.name)
if err != nil {
return err
}
store.mu.Lock()
defer store.mu.Unlock()
store.data[key] = &cacheItem{Value: value, Expiration: nil} // No expiration
return store.saveToFile()
}
// GetPerm retrieves the value for the given key from the permanent bucket (ignores expiration).
func (c *Cacher) GetPerm(bucket PermanentBucket, key string, out interface{}) (bool, error) {
store, err := c.getStore(bucket.name)
if err != nil {
return false, err
}
store.mu.Lock()
defer store.mu.Unlock()
item, ok := store.data[key]
if !ok {
return false, nil
}
data, err := json.Marshal(item.Value)
if err != nil {
return false, err
}
return true, json.Unmarshal(data, out)
}
// DeletePerm deletes the value for the given key from the permanent bucket.
func (c *Cacher) DeletePerm(bucket PermanentBucket, key string) error {
store, err := c.getStore(bucket.name)
if err != nil {
return err
}
store.mu.Lock()
defer store.mu.Unlock()
delete(store.data, key)
return store.saveToFile()
}
// EmptyPerm empties the permanent bucket.
func (c *Cacher) EmptyPerm(bucket PermanentBucket) error {
store, err := c.getStore(bucket.name)
if err != nil {
return err
}
store.mu.Lock()
defer store.mu.Unlock()
store.data = make(map[string]*cacheItem)
return store.saveToFile()
}
// RemovePerm calls Remove.
func (c *Cacher) RemovePerm(bucketName string) error {
return c.Remove(bucketName)
}
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
func (cs *CacheStore) loadFromFile() error {
file, err := os.Open(cs.filePath)
if err != nil {
if os.IsNotExist(err) {
return nil // File does not exist, so nothing to load
}
return fmt.Errorf("filecache: failed to open cache file: %w", err)
}
defer file.Close()
if err := json.NewDecoder(file).Decode(&cs.data); err != nil {
// If decode fails (empty or corrupted file), initialize with empty data
cs.data = make(map[string]*cacheItem)
return nil
}
return nil
}
func (cs *CacheStore) saveToFile() error {
file, err := os.Create(cs.filePath)
if err != nil {
return fmt.Errorf("filecache: failed to create cache file: %w", err)
}
defer file.Close()
if err := json.NewEncoder(file).Encode(cs.data); err != nil {
return fmt.Errorf("filecache: failed to encode cache data: %w", err)
}
return nil
}
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// RemoveAllBy removes all files in the cache directory that match the given filter.
func (c *Cacher) RemoveAllBy(filter func(filename string) bool) error {
c.mu.Lock()
defer c.mu.Unlock()
err := filepath.Walk(c.dir, func(_ string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() {
if !strings.HasSuffix(info.Name(), ".cache") {
return nil
}
if filter(info.Name()) {
if err := os.Remove(filepath.Join(c.dir, info.Name())); err != nil {
return fmt.Errorf("filecache: failed to remove file: %w", err)
}
}
}
return nil
})
c.stores = make(map[string]*CacheStore)
return err
}
// ClearMediastreamVideoFiles clears all mediastream video file caches.
func (c *Cacher) ClearMediastreamVideoFiles() error {
c.mu.Lock()
// Remove the contents of the directory
files, err := os.ReadDir(filepath.Join(c.dir, "videofiles"))
if err != nil {
c.mu.Unlock()
return nil
}
for _, file := range files {
_ = os.RemoveAll(filepath.Join(c.dir, "videofiles", file.Name()))
}
c.mu.Unlock()
err = c.RemoveAllBy(func(filename string) bool {
return strings.HasPrefix(filename, "mediastream")
})
c.mu.Lock()
c.stores = make(map[string]*CacheStore)
c.mu.Unlock()
return err
}
// TrimMediastreamVideoFiles removes all cached mediastream video files if their number exceeds 10.
func (c *Cacher) TrimMediastreamVideoFiles() error {
c.mu.Lock()
defer c.mu.Unlock()
// Remove the contents of the "videofiles" cache directory
files, err := os.ReadDir(filepath.Join(c.dir, "videofiles"))
if err != nil {
return nil
}
// If the number of files exceeds 10, remove all files
if len(files) > 10 {
for _, file := range files {
_ = os.RemoveAll(filepath.Join(c.dir, "videofiles", file.Name()))
}
}
c.stores = make(map[string]*CacheStore)
return err
}
func (c *Cacher) GetMediastreamVideoFilesTotalSize() (int64, error) {
c.mu.Lock()
defer c.mu.Unlock()
_, err := os.Stat(filepath.Join(c.dir, "videofiles"))
if err != nil {
if os.IsNotExist(err) {
return 0, nil
}
return 0, err
}
var totalSize int64
err = filepath.Walk(filepath.Join(c.dir, "videofiles"), func(_ string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() {
totalSize += info.Size()
}
return nil
})
if err != nil {
return 0, fmt.Errorf("filecache: failed to walk the cache directory: %w", err)
}
return totalSize, nil
}
// GetTotalSize returns the total size, in bytes, of all files in the cache directory.
func (c *Cacher) GetTotalSize() (int64, error) {
c.mu.Lock()
defer c.mu.Unlock()
var totalSize int64
err := filepath.Walk(c.dir, func(_ string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() {
totalSize += info.Size()
}
return nil
})
if err != nil {
return 0, fmt.Errorf("filecache: failed to walk the cache directory: %w", err)
}
return totalSize, nil
}

View File

@@ -0,0 +1,162 @@
package filecache
import (
"github.com/davecgh/go-spew/spew"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"path/filepath"
"seanime/internal/test_utils"
"sync"
"testing"
"time"
)
func TestCacherFunctions(t *testing.T) {
test_utils.InitTestProvider(t)
tempDir := t.TempDir()
t.Log(tempDir)
cacher, err := NewCacher(filepath.Join(tempDir, "cache"))
require.NoError(t, err)
bucket := Bucket{
name: "test",
ttl: 10 * time.Second,
}
keys := []string{"key1", "key2", "key3"}
type valStruct = struct {
Name string
}
values := []*valStruct{
{
Name: "value1",
},
{
Name: "value2",
},
{
Name: "value3",
},
}
for i, key := range keys {
err = cacher.Set(bucket, key, values[i])
if err != nil {
t.Fatalf("Failed to set the value: %v", err)
}
}
allVals, err := GetAll[*valStruct](cacher, bucket)
if err != nil {
t.Fatalf("Failed to get all values: %v", err)
}
if len(allVals) != len(keys) {
t.Fatalf("Failed to get all values: expected %d, got %d", len(keys), len(allVals))
}
spew.Dump(allVals)
}
func TestCacherSetAndGet(t *testing.T) {
test_utils.InitTestProvider(t)
tempDir := t.TempDir()
t.Log(tempDir)
	cacher, err := NewCacher(filepath.Join(tempDir, "cache"))
	require.NoError(t, err)
bucket := Bucket{
name: "test",
ttl: 4 * time.Second,
}
key := "key"
value := struct {
Name string
}{
Name: "value",
}
// Add "key" -> value to the bucket, with a TTL of 4 seconds
err = cacher.Set(bucket, key, value)
if err != nil {
t.Fatalf("Failed to set the value: %v", err)
}
var out struct {
Name string
}
// Get the value of "key" from the bucket, it shouldn't be expired
found, err := cacher.Get(bucket, key, &out)
if err != nil {
t.Errorf("Failed to get the value: %v", err)
}
if !found || !assert.Equal(t, value, out) {
t.Errorf("Failed to get the correct value. Expected %v, got %v", value, out)
}
spew.Dump(out)
time.Sleep(3 * time.Second)
// Get the value of "key" from the bucket again, it shouldn't be expired
found, err = cacher.Get(bucket, key, &out)
if !found {
t.Errorf("Failed to get the value")
}
if !found || out != value {
t.Errorf("Failed to get the correct value. Expected %v, got %v", value, out)
}
spew.Dump(out)
// Spin up a goroutine to set "key2" -> value2 to the bucket, with a TTL of 1 second
// cacher should be thread-safe
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
key2 := "key2"
value2 := struct {
Name string
}{
Name: "value2",
}
var out2 struct {
Name string
}
err = cacher.Set(bucket, key2, value2)
if err != nil {
t.Errorf("Failed to set the value: %v", err)
}
found, err = cacher.Get(bucket, key2, &out2)
if err != nil {
t.Errorf("Failed to get the value: %v", err)
}
if !found || !assert.Equal(t, value2, out2) {
t.Errorf("Failed to get the correct value. Expected %v, got %v", value2, out2)
}
_ = cacher.Delete(bucket, key2)
spew.Dump(out2)
}()
time.Sleep(2 * time.Second)
// Get the value of "key" from the bucket, it should be expired
found, _ = cacher.Get(bucket, key, &out)
if found {
t.Errorf("Failed to delete the value")
spew.Dump(out)
}
wg.Wait()
}

View File

@@ -0,0 +1,419 @@
package util
import (
"archive/zip"
"fmt"
"io"
"os"
"path/filepath"
"runtime"
"strings"
"github.com/nwaples/rardecode/v2"
)
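// DirSize returns the total size, in bytes, of all files under path.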
func DirSize(path string) (uint64, error) {
var size int64
err := filepath.Walk(path, func(_ string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() {
size += info.Size()
}
return err
})
return uint64(size), err
}
func IsValidMediaFile(path string) bool {
return !strings.HasPrefix(path, "._")
}
func IsValidVideoExtension(ext string) bool {
validExtensions := map[string]struct{}{
".mp4": {}, ".avi": {}, ".mkv": {}, ".mov": {}, ".flv": {}, ".wmv": {}, ".webm": {},
".mpeg": {}, ".mpg": {}, ".m4v": {}, ".3gp": {}, ".3g2": {}, ".ogg": {}, ".ogv": {},
".vob": {}, ".mts": {}, ".m2ts": {}, ".ts": {}, ".f4v": {}, ".ogm": {}, ".rm": {},
".rmvb": {}, ".drc": {}, ".yuv": {}, ".asf": {}, ".amv": {}, ".m2v": {}, ".mpe": {},
".mpv": {}, ".mp2": {}, ".svi": {}, ".mxf": {}, ".roq": {}, ".nsv": {}, ".f4p": {},
".f4a": {}, ".f4b": {},
}
ext = strings.ToLower(ext)
_, exists := validExtensions[ext]
return exists
}
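// IsSubdirectory reports whether child is located inside parent (but is not
// parent itself).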
func IsSubdirectory(parent, child string) bool {
rel, err := filepath.Rel(parent, child)
if err != nil {
return false
}
return rel != "." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator))
}
func IsSubdirectoryOfAny(dirs []string, child string) bool {
for _, dir := range dirs {
if IsSubdirectory(dir, child) {
return true
}
}
return false
}
func IsSameDir(dir1, dir2 string) bool {
if runtime.GOOS == "windows" {
dir1 = strings.ToLower(dir1)
dir2 = strings.ToLower(dir2)
}
absDir1, err := filepath.Abs(dir1)
if err != nil {
return false
}
absDir2, err := filepath.Abs(dir2)
if err != nil {
return false
}
return absDir1 == absDir2
}
func IsFileUnderDir(filePath, dir string) bool {
// Get the absolute path of the file
absFilePath, err := filepath.Abs(filePath)
if err != nil {
return false
}
// Get the absolute path of the directory
absDir, err := filepath.Abs(dir)
if err != nil {
return false
}
if runtime.GOOS == "windows" {
absFilePath = strings.ToLower(absFilePath)
absDir = strings.ToLower(absDir)
}
// Check if the file path starts with the directory path
return strings.HasPrefix(absFilePath, absDir+string(os.PathSeparator))
}
// UnzipFile unzips a file to the destination.
//
// Example:
// // If "file.zip" contains `folder > file.text`
// UnzipFile("file.zip", "/path/to/dest") // -> "/path/to/dest/folder/file.txt"
// // If "file.zip" contains `file.txt`
// UnzipFile("file.zip", "/path/to/dest") // -> "/path/to/dest/file.txt"
func UnzipFile(src, dest string) error {
r, err := zip.OpenReader(src)
if err != nil {
return fmt.Errorf("failed to open zip file: %w", err)
}
defer r.Close()
// Create a temporary folder to extract the files
extractedDir, err := os.MkdirTemp(filepath.Dir(dest), ".extracted-")
if err != nil {
return fmt.Errorf("failed to create temp folder: %w", err)
}
defer os.RemoveAll(extractedDir)
// Iterate through the files in the archive
for _, f := range r.File {
// Get the full path of the file in the destination
fpath := filepath.Join(extractedDir, f.Name)
// If the file is a directory, create it in the destination
if f.FileInfo().IsDir() {
_ = os.MkdirAll(fpath, os.ModePerm)
continue
}
// Make sure the parent directory exists (will not return an error if it already exists)
if err := os.MkdirAll(filepath.Dir(fpath), os.ModePerm); err != nil {
return fmt.Errorf("failed to create parent directory: %w", err)
}
// Open the file in the destination
outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
if err != nil {
return fmt.Errorf("failed to open file: %w", err)
}
// Open the file in the archive
rc, err := f.Open()
if err != nil {
_ = outFile.Close()
return fmt.Errorf("failed to open file in archive: %w", err)
}
// Copy the file from the archive to the destination
_, err = io.Copy(outFile, rc)
_ = outFile.Close()
_ = rc.Close()
if err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}
}
// Ensure the destination directory exists
if err := os.MkdirAll(dest, os.ModePerm); err != nil {
return fmt.Errorf("failed to create destination directory: %w", err)
}
// Move the contents of the extracted directory to the destination
entries, err := os.ReadDir(extractedDir)
if err != nil {
return fmt.Errorf("failed to read extracted directory: %w", err)
}
for _, entry := range entries {
srcPath := filepath.Join(extractedDir, entry.Name())
destPath := filepath.Join(dest, entry.Name())
// Remove existing file/directory at destination if it exists
_ = os.RemoveAll(destPath)
// Move the file/directory to the destination
if err := os.Rename(srcPath, destPath); err != nil {
return fmt.Errorf("failed to move extracted item %s: %w", entry.Name(), err)
}
}
return nil
}
// UnrarFile extracts a rar archive to the destination.
func UnrarFile(src, dest string) error {
r, err := rardecode.OpenReader(src)
if err != nil {
return fmt.Errorf("failed to open rar file: %w", err)
}
defer r.Close()
// Create a temporary folder to extract the files
extractedDir, err := os.MkdirTemp(filepath.Dir(dest), ".extracted-")
if err != nil {
return fmt.Errorf("failed to create temp folder: %w", err)
}
defer os.RemoveAll(extractedDir)
// Iterate through the files in the archive
for {
header, err := r.Next()
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("failed to get next file in archive: %w", err)
}
// Get the full path of the file in the destination
fpath := filepath.Join(extractedDir, header.Name)
// If the file is a directory, create it in the destination
if header.IsDir {
_ = os.MkdirAll(fpath, os.ModePerm)
continue
}
// Make sure the parent directory exists (will not return an error if it already exists)
if err := os.MkdirAll(filepath.Dir(fpath), os.ModePerm); err != nil {
return fmt.Errorf("failed to create parent directory: %w", err)
}
// Open the file in the destination
outFile, err := os.Create(fpath)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
// Copy the file from the archive to the destination
_, err = io.Copy(outFile, r)
outFile.Close()
if err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}
}
// Ensure the destination directory exists
if err := os.MkdirAll(dest, os.ModePerm); err != nil {
return fmt.Errorf("failed to create destination directory: %w", err)
}
// Move the contents of the extracted directory to the destination
entries, err := os.ReadDir(extractedDir)
if err != nil {
return fmt.Errorf("failed to read extracted directory: %w", err)
}
for _, entry := range entries {
srcPath := filepath.Join(extractedDir, entry.Name())
destPath := filepath.Join(dest, entry.Name())
// Remove existing file/directory at destination if it exists
_ = os.RemoveAll(destPath)
// Move the file/directory to the destination
if err := os.Rename(srcPath, destPath); err != nil {
return fmt.Errorf("failed to move extracted item %s: %w", entry.Name(), err)
}
}
return nil
}
// MoveToDestination moves a folder or file to the destination
//
// Example:
// MoveToDestination("/path/to/src/folder", "/path/to/dest") // -> "/path/to/dest/folder"
func MoveToDestination(src, dest string) error {
// Ensure the destination folder exists
if _, err := os.Stat(dest); os.IsNotExist(err) {
err := os.MkdirAll(dest, os.ModePerm)
if err != nil {
return fmt.Errorf("failed to create destination folder: %v", err)
}
}
destFolder := filepath.Join(dest, filepath.Base(src))
// Move the folder by renaming it
err := os.Rename(src, destFolder)
if err != nil {
return fmt.Errorf("failed to move folder: %v", err)
}
return nil
}
// UnwrapAndMove moves the last subfolder containing the files to the destination.
// If there is a single file, it will move that file only.
//
// Example:
//
// Case 1:
// src/
// - Anime/
// - Ep1.mkv
// - Ep2.mkv
// UnwrapAndMove("/path/to/src", "/path/to/dest") // -> "/path/to/dest/Anime"
//
// Case 2:
// src/
// - {HASH}/
// - Anime/
// - Ep1.mkv
// - Ep2.mkv
// UnwrapAndMove("/path/to/src", "/path/to/dest") // -> "/path/to/dest/Anime"
//
// Case 3:
// src/
// - {HASH}/
// - Anime/
// - Ep1.mkv
// UnwrapAndMove("/path/to/src", "/path/to/dest") // -> "/path/to/dest/Ep1.mkv"
//
// Case 4:
// src/
// - {HASH}/
// - Anime/
// - Anime 1/
// - Ep1.mkv
// - Ep2.mkv
// - Anime 2/
// - Ep1.mkv
// - Ep2.mkv
// UnwrapAndMove("/path/to/src", "/path/to/dest") // -> "/path/to/dest/Anime"
func UnwrapAndMove(src, dest string) error {
// Ensure the source and destination directories exist
if _, err := os.Stat(src); os.IsNotExist(err) {
return fmt.Errorf("source directory does not exist: %s", src)
}
_ = os.MkdirAll(dest, os.ModePerm)
srcEntries, err := os.ReadDir(src)
if err != nil {
return err
}
// If the source folder contains multiple files or folders, move its contents to the destination
if len(srcEntries) > 1 {
for _, srcEntry := range srcEntries {
err := MoveToDestination(filepath.Join(src, srcEntry.Name()), dest)
if err != nil {
return err
}
}
return nil
}
folderMap := make(map[string]int)
err = FindFolderChildCount(src, folderMap)
if err != nil {
return err
}
var folderToMove string
for folder, count := range folderMap {
// Prefer the shallowest folder (shortest path) that holds more than one entry
if count > 1 && (folderToMove == "" || len(folder) < len(folderToMove)) {
folderToMove = folder
}
}
// It's a single file, move that file only
if folderToMove == "" {
fp := GetDeepestFile(src)
if fp == "" {
return fmt.Errorf("no files found in the source directory")
}
return MoveToDestination(fp, dest)
}
// Move the folder containing multiple files or folders
err = MoveToDestination(folderToMove, dest)
if err != nil {
return err
}
return nil
}
// FindFolderChildCount recursively counts the direct children of each folder under src, recording the counts in folderMap.
func FindFolderChildCount(src string, folderMap map[string]int) error {
srcEntries, err := os.ReadDir(src)
if err != nil {
return err
}
for _, srcEntry := range srcEntries {
folderMap[src]++
if srcEntry.IsDir() {
err = FindFolderChildCount(filepath.Join(src, srcEntry.Name()), folderMap)
if err != nil {
return err
}
}
}
return nil
}
// GetDeepestFile follows the first directory branch of src downward and returns the path of the first file encountered, or "" if none exists.
func GetDeepestFile(src string) (fp string) {
srcEntries, err := os.ReadDir(src)
if err != nil {
return ""
}
for _, srcEntry := range srcEntries {
if srcEntry.IsDir() {
return GetDeepestFile(filepath.Join(src, srcEntry.Name()))
}
return filepath.Join(src, srcEntry.Name())
}
return ""
}
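// exampleUnwrapAndMove is a minimal usage sketch (illustrative, not part of the
// original file): after a download finishes extracting into a hash-named wrapper
// folder, move the innermost meaningful content into the library. Paths are placeholders.
func exampleUnwrapAndMove() error {
// e.g. /downloads/extracted/{HASH}/Anime/Ep1.mkv ends up at /library/Anime/Ep1.mkv
return UnwrapAndMove("/downloads/extracted", "/library")
}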

View File

@@ -0,0 +1,91 @@
package util
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestValidVideoExtension(t *testing.T) {
tests := []struct {
ext string
expected bool
}{
{ext: ".mp4", expected: true},
{ext: ".avi", expected: true},
{ext: ".mkv", expected: true},
{ext: ".mov", expected: true},
{ext: ".unknown", expected: false},
{ext: ".MP4", expected: true},
{ext: ".AVI", expected: true},
{ext: "", expected: false},
}
for _, test := range tests {
t.Run(test.ext, func(t *testing.T) {
result := IsValidVideoExtension(test.ext)
require.Equal(t, test.expected, result)
})
}
}
func TestSubdirectory(t *testing.T) {
tests := []struct {
parent string
child string
expected bool
}{
{parent: "C:\\parent", child: "C:\\parent\\child", expected: true},
{parent: "C:\\parent", child: "C:\\parent\\child.txt", expected: true},
{parent: "C:\\parent", child: "C:/PARENT/child.txt", expected: true},
{parent: "C:\\parent", child: "C:\\parent\\..\\child", expected: false},
{parent: "C:\\parent", child: "C:\\parent", expected: false},
}
for _, test := range tests {
t.Run(test.child, func(t *testing.T) {
result := IsSubdirectory(test.parent, test.child)
require.Equal(t, test.expected, result)
})
}
}
func TestIsFileUnderDir(t *testing.T) {
tests := []struct {
parent string
child string
expected bool
}{
{parent: "C:\\parent", child: "C:\\parent\\child", expected: true},
{parent: "C:\\parent", child: "C:\\parent\\child.txt", expected: true},
{parent: "C:\\parent", child: "C:/PARENT/child.txt", expected: true},
{parent: "C:\\parent", child: "C:\\parent\\..\\child", expected: false},
{parent: "C:\\parent", child: "C:\\parent", expected: false},
}
for _, test := range tests {
t.Run(test.child, func(t *testing.T) {
result := IsFileUnderDir(test.parent, test.child)
require.Equal(t, test.expected, result)
})
}
}
func TestSameDir(t *testing.T) {
tests := []struct {
dir1 string
dir2 string
expected bool
}{
{dir1: "C:\\dir", dir2: "C:\\dir", expected: true},
{dir1: "C:\\dir", dir2: "C:\\DIR", expected: true},
{dir1: "C:\\dir1", dir2: "C:\\dir2", expected: false},
}
for _, test := range tests {
t.Run(test.dir2, func(t *testing.T) {
result := IsSameDir(test.dir1, test.dir2)
require.Equal(t, test.expected, result)
})
}
}

View File

@@ -0,0 +1,41 @@
package goja_util
import (
"fmt"
"time"
"github.com/dop251/goja"
)
// BindAwait binds the $await function to the Goja runtime.
// Hooks don't wait for promises to resolve, so $await is used to wrap a promise and wait for it to resolve.
func BindAwait(vm *goja.Runtime) {
vm.Set("$await", func(promise goja.Value) (goja.Value, error) {
if promise, ok := promise.Export().(*goja.Promise); ok {
doneCh := make(chan struct{})
// Wait for the promise to resolve
go func() {
for promise.State() == goja.PromiseStatePending {
time.Sleep(10 * time.Millisecond)
}
close(doneCh)
}()
<-doneCh
// If the promise is rejected, return the error
if promise.State() == goja.PromiseStateRejected {
err := promise.Result()
return nil, fmt.Errorf("promise rejected: %v", err)
}
// If the promise is fulfilled, return the result
res := promise.Result()
return res, nil
}
// If the promise is not a Goja promise, return the value as is
return promise, nil
})
}
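// exampleBindAwaitUsage is a minimal usage sketch (illustrative, not part of the
// original file): it binds $await and unwraps an already-resolved promise.
func exampleBindAwaitUsage() {
vm := goja.New()
BindAwait(vm)
v, err := vm.RunString(`$await(Promise.resolve(42))`)
if err != nil {
fmt.Println("script error:", err)
return
}
fmt.Println(v.Export()) // 42
}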

View File

@@ -0,0 +1,228 @@
package goja_util
import (
"encoding/json"
"fmt"
"reflect"
"github.com/dop251/goja"
)
func BindMutable(vm *goja.Runtime) {
vm.Set("$mutable", vm.ToValue(func(call goja.FunctionCall) goja.Value {
if len(call.Arguments) == 0 || goja.IsUndefined(call.Arguments[0]) || goja.IsNull(call.Arguments[0]) {
return vm.NewObject()
}
// Convert the input to a map first
jsonBytes, err := json.Marshal(call.Arguments[0].Export())
if err != nil {
panic(vm.NewTypeError("Failed to marshal input: %v", err))
}
var objMap map[string]interface{}
if err := json.Unmarshal(jsonBytes, &objMap); err != nil {
panic(vm.NewTypeError("Failed to unmarshal input: %v", err))
}
// Create a new object with getters and setters
obj := vm.NewObject()
for key, val := range objMap {
// Capture current key and value
k, v := key, val
if mapVal, ok := v.(map[string]interface{}); ok {
// For nested objects, create a new mutable object
nestedObj := vm.NewObject()
// Add get method
nestedObj.Set("get", vm.ToValue(func() interface{} {
return mapVal
}))
// Add set method
nestedObj.Set("set", vm.ToValue(func(call goja.FunctionCall) goja.Value {
if len(call.Arguments) > 0 {
newVal := call.Arguments[0].Export()
if newMap, ok := newVal.(map[string]interface{}); ok {
mapVal = newMap
objMap[k] = newMap
}
}
return goja.Undefined()
}))
// Add direct property access
for mk, mv := range mapVal {
// Capture map key and value
mapKey := mk
mapValue := mv
nestedObj.DefineAccessorProperty(mapKey, vm.ToValue(func() interface{} {
return mapValue
}), vm.ToValue(func(call goja.FunctionCall) goja.Value {
if len(call.Arguments) > 0 {
mapVal[mapKey] = call.Arguments[0].Export()
}
return goja.Undefined()
}), goja.FLAG_FALSE, goja.FLAG_TRUE)
}
obj.Set(k, nestedObj)
} else if arrVal, ok := v.([]interface{}); ok {
// For arrays, create a proxy object that allows index access
arrObj := vm.NewObject()
for i, av := range arrVal {
idx := i
val := av
arrObj.DefineAccessorProperty(fmt.Sprintf("%d", idx), vm.ToValue(func() interface{} {
return val
}), vm.ToValue(func(call goja.FunctionCall) goja.Value {
if len(call.Arguments) > 0 {
arrVal[idx] = call.Arguments[0].Export()
objMap[k] = arrVal
}
return goja.Undefined()
}), goja.FLAG_FALSE, goja.FLAG_TRUE)
}
arrObj.Set("length", len(arrVal))
// Add explicit get/set methods for arrays
arrObj.Set("get", vm.ToValue(func() interface{} {
return arrVal
}))
arrObj.Set("set", vm.ToValue(func(call goja.FunctionCall) goja.Value {
if len(call.Arguments) > 0 {
newVal := call.Arguments[0].Export()
if newArr, ok := newVal.([]interface{}); ok {
arrVal = newArr
objMap[k] = newArr
arrObj.Set("length", len(newArr))
}
}
return goja.Undefined()
}))
obj.Set(k, arrObj)
} else {
// For primitive values, create simple getter/setter
obj.DefineAccessorProperty(k, vm.ToValue(func() interface{} {
return objMap[k]
}), vm.ToValue(func(call goja.FunctionCall) goja.Value {
if len(call.Arguments) > 0 {
objMap[k] = call.Arguments[0].Export()
}
return goja.Undefined()
}), goja.FLAG_FALSE, goja.FLAG_TRUE)
}
}
// Add a toJSON method that creates a fresh copy
obj.Set("toJSON", vm.ToValue(func() interface{} {
// Convert to JSON and back to create a fresh copy with no shared references
jsonBytes, err := json.Marshal(objMap)
if err != nil {
panic(vm.NewTypeError("Failed to marshal to JSON: %v", err))
}
var freshCopy interface{}
if err := json.Unmarshal(jsonBytes, &freshCopy); err != nil {
panic(vm.NewTypeError("Failed to unmarshal from JSON: %v", err))
}
return freshCopy
}))
// Add a replace method to completely replace a Go struct's contents.
// Usage in JS: mutableAnime.replace(e.anime)
obj.Set("replace", vm.ToValue(func(call goja.FunctionCall) goja.Value {
if len(call.Arguments) < 1 {
panic(vm.NewTypeError("replace requires one argument: target"))
}
// Use the current internal state.
jsonBytes, err := json.Marshal(objMap)
if err != nil {
panic(vm.NewTypeError("Failed to marshal state: %v", err))
}
// Get the reflect.Value of the target pointer
target := call.Arguments[0].Export()
targetVal := reflect.ValueOf(target)
if targetVal.Kind() != reflect.Ptr {
// panic(vm.NewTypeError("Target must be a pointer"))
return goja.Undefined()
}
// Create a new instance of the target type and unmarshal into it
newVal := reflect.New(targetVal.Elem().Type())
if err := json.Unmarshal(jsonBytes, newVal.Interface()); err != nil {
panic(vm.NewTypeError("Failed to unmarshal into target: %v", err))
}
// Replace the contents of the target with the new value
targetVal.Elem().Set(newVal.Elem())
return goja.Undefined()
}))
return obj
}))
// Add replace function to completely replace a Go struct's contents
vm.Set("$replace", vm.ToValue(func(call goja.FunctionCall) goja.Value {
if len(call.Arguments) < 2 {
panic(vm.NewTypeError("replace requires two arguments: target and source"))
}
target := call.Arguments[0].Export()
source := call.Arguments[1].Export()
// Marshal source to JSON
sourceJSON, err := json.Marshal(source)
if err != nil {
panic(vm.NewTypeError("Failed to marshal source: %v", err))
}
// Get the reflect.Value of the target pointer
targetVal := reflect.ValueOf(target)
if targetVal.Kind() != reflect.Ptr {
// panic(vm.NewTypeError("Target must be a pointer"))
// TODO: Handle non-pointer targets
return goja.Undefined()
}
// Create a new instance of the target type
newVal := reflect.New(targetVal.Elem().Type())
// Unmarshal JSON into the new instance
if err := json.Unmarshal(sourceJSON, newVal.Interface()); err != nil {
panic(vm.NewTypeError("Failed to unmarshal into target: %v", err))
}
// Replace the contents of the target with the new value
targetVal.Elem().Set(newVal.Elem())
return goja.Undefined()
}))
vm.Set("$clone", vm.ToValue(func(call goja.FunctionCall) goja.Value {
if len(call.Arguments) == 0 {
return goja.Undefined()
}
// First convert to JSON to strip all pointers and references
jsonBytes, err := json.Marshal(call.Arguments[0].Export())
if err != nil {
panic(vm.NewTypeError("Failed to marshal input: %v", err))
}
// Then unmarshal into a fresh interface{} to get a completely new object
var newObj interface{}
if err := json.Unmarshal(jsonBytes, &newObj); err != nil {
panic(vm.NewTypeError("Failed to unmarshal input: %v", err))
}
// Convert back to a goja value
return vm.ToValue(newObj)
}))
}
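// exampleBindMutableUsage is a minimal usage sketch (illustrative, not part of
// the original file): a script mutates a $mutable wrapper, then writes the new
// state back into a Go struct with replace().
func exampleBindMutableUsage() {
type anime struct {
Title string `json:"title"`
}
vm := goja.New()
BindMutable(vm)
target := &anime{Title: "old"}
vm.Set("anime", target)
_, err := vm.RunString(`
const m = $mutable({ title: "old" });
m.title = "new";
m.replace(anime); // unmarshals the mutated state into the Go pointer
`)
if err != nil {
fmt.Println("script error:", err)
return
}
fmt.Println(target.Title) // new
}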

View File

@@ -0,0 +1,234 @@
package goja_util
import (
"context"
"fmt"
"runtime/debug"
"sync"
"time"
"github.com/samber/mo"
)
// Job represents a task to be executed in the VM
type Job struct {
fn func() error
resultCh chan error
async bool // Flag to indicate if the job is async (doesn't need to wait for result)
}
// Scheduler serializes concurrently submitted VM operations onto a single goroutine.
// Any goroutine that needs to execute a VM operation must schedule it, because the UI VM isn't thread-safe.
type Scheduler struct {
jobQueue chan *Job
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
// Track the currently executing job to detect nested scheduling
currentJob *Job
currentJobLock sync.Mutex
onException mo.Option[func(err error)]
}
func NewScheduler() *Scheduler {
ctx, cancel := context.WithCancel(context.Background())
s := &Scheduler{
jobQueue: make(chan *Job, 9999),
ctx: ctx,
onException: mo.None[func(err error)](),
cancel: cancel,
}
s.start()
return s
}
func (s *Scheduler) SetOnException(onException func(err error)) {
s.onException = mo.Some(onException)
}
func (s *Scheduler) start() {
s.wg.Add(1)
go func() {
defer s.wg.Done()
for {
select {
case <-s.ctx.Done():
return
case job := <-s.jobQueue:
// Set the current job before execution
s.currentJobLock.Lock()
s.currentJob = job
s.currentJobLock.Unlock()
err := job.fn()
// Clear the current job after execution
s.currentJobLock.Lock()
s.currentJob = nil
s.currentJobLock.Unlock()
// Only send result if the job is not async
if !job.async {
job.resultCh <- err
}
if err != nil {
if onException, ok := s.onException.Get(); ok {
onException(err)
}
}
}
}
}()
}
func (s *Scheduler) Stop() {
if s.cancel != nil {
s.cancel()
}
//s.wg.Wait()
}
// Schedule adds a job to the queue and waits for its completion
func (s *Scheduler) Schedule(fn func() error) error {
resultCh := make(chan error, 1)
job := &Job{
fn: func() (err error) {
defer func() {
if r := recover(); r != nil {
// Turn the panic into the job's return value; sending on resultCh here
// would collide with the scheduler loop's own send on the same channel.
err = fmt.Errorf("panic: %v", r)
}
}()
return fn()
},
resultCh: resultCh,
async: false,
}
// Check if we're already in a job execution context
s.currentJobLock.Lock()
isNestedCall := s.currentJob != nil && !s.currentJob.async
s.currentJobLock.Unlock()
// If this is a nested call from a synchronous job, we need to be careful
// We can't execute directly because the VM isn't thread-safe
// Instead, we'll queue it and use a separate goroutine to wait for the result
if isNestedCall {
// Queue the job
select {
case <-s.ctx.Done():
return fmt.Errorf("scheduler stopped")
case s.jobQueue <- job:
// Create a separate goroutine to wait for the result
// This prevents deadlock while still ensuring the job runs in the scheduler
resultCh2 := make(chan error, 1)
go func() {
resultCh2 <- <-resultCh
}()
return <-resultCh2
}
}
// Otherwise, queue the job normally
select {
case <-s.ctx.Done():
return fmt.Errorf("scheduler stopped")
case s.jobQueue <- job:
return <-resultCh
}
}
// ScheduleAsync adds a job to the queue without waiting for completion
// This is useful for fire-and-forget operations or when a job needs to schedule another job
func (s *Scheduler) ScheduleAsync(fn func() error) {
job := &Job{
fn: func() error {
defer func() {
if r := recover(); r != nil {
// Get stack trace for better identification
stack := debug.Stack()
jobInfo := fmt.Sprintf("async job panic: %v\nStack: %s", r, stack)
if onException, ok := s.onException.Get(); ok {
onException(fmt.Errorf("panic in async job: %v\n%s", r, jobInfo))
}
}
}()
return fn()
},
resultCh: nil, // No result channel needed
async: true,
}
// Queue the job without blocking
select {
case <-s.ctx.Done():
// Scheduler is stopped, just ignore
return
case s.jobQueue <- job:
// Job queued successfully
// fmt.Printf("job queued successfully, length: %d\n", len(s.jobQueue))
return
default:
// Queue is full, log an error
if onException, ok := s.onException.Get(); ok {
onException(fmt.Errorf("async job queue is full"))
}
}
}
// ScheduleWithTimeout schedules a job with a timeout
func (s *Scheduler) ScheduleWithTimeout(fn func() error, timeout time.Duration) error {
resultCh := make(chan error, 1)
job := &Job{
fn: func() (err error) {
defer func() {
if r := recover(); r != nil {
// As in Schedule, surface the panic via the return value so the
// scheduler loop performs the only send on resultCh; a direct send
// here could leave the loop blocked forever if the caller timed out.
err = fmt.Errorf("panic: %v", r)
}
}()
return fn()
},
resultCh: resultCh,
async: false,
}
// Check if we're already in a job execution context
s.currentJobLock.Lock()
isNestedCall := s.currentJob != nil && !s.currentJob.async
s.currentJobLock.Unlock()
// If this is a nested call from a synchronous job, handle it specially
if isNestedCall {
// Queue the job
select {
case <-s.ctx.Done():
return fmt.Errorf("scheduler stopped")
case s.jobQueue <- job:
// Create a separate goroutine to wait for the result with timeout
resultCh2 := make(chan error, 1)
go func() {
select {
case err := <-resultCh:
resultCh2 <- err
case <-time.After(timeout):
resultCh2 <- fmt.Errorf("operation timed out")
}
}()
return <-resultCh2
}
}
select {
case <-s.ctx.Done():
return fmt.Errorf("scheduler stopped")
case s.jobQueue <- job:
select {
case err := <-resultCh:
return err
case <-time.After(timeout):
return fmt.Errorf("operation timed out")
}
}
}
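// exampleSchedulerUsage is a minimal usage sketch (illustrative, not part of the
// original file): one synchronous job that blocks until it has run on the
// scheduler goroutine, and one fire-and-forget job.
func exampleSchedulerUsage() {
s := NewScheduler()
defer s.Stop()
s.SetOnException(func(err error) { fmt.Println("vm error:", err) })
err := s.Schedule(func() error {
// ...call into the VM here; this runs on the scheduler goroutine...
return nil
})
if err != nil {
fmt.Println("job failed:", err)
}
s.ScheduleAsync(func() error { return nil }) // queued, returns immediately
}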

View File

@@ -0,0 +1,24 @@
//go:build !windows

package util
import (
"os"
"path/filepath"
"strings"
)
func HideFile(path string) (string, error) {
filename := filepath.Base(path)
if strings.HasPrefix(filename, ".") {
return path, nil
}
newPath := filepath.Join(filepath.Dir(path), "."+filename)
err := os.Rename(path, newPath)
if err != nil {
return "", err
}
return newPath, nil
}

View File

@@ -0,0 +1,21 @@
//go:build windows

package util
import (
"syscall"
)
func HideFile(path string) (string, error) {
defer HandlePanicInModuleThen("HideFile", func() {})
p, err := syscall.UTF16PtrFromString(path)
if err != nil {
return "", err
}
err = syscall.SetFileAttributes(p, syscall.FILE_ATTRIBUTE_HIDDEN)
if err != nil {
return "", err
}
return path, nil
}

View File

@@ -0,0 +1,176 @@
package util
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"strings"
"time"
)
type TokenClaims struct {
Endpoint string `json:"endpoint"` // The endpoint this token is valid for
IssuedAt int64 `json:"iat"`
ExpiresAt int64 `json:"exp"`
}
type HMACAuth struct {
secret []byte
ttl time.Duration
}
// base64URLEncode encodes to base64url without padding (to match frontend)
func base64URLEncode(data []byte) string {
return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=")
}
// base64URLDecode decodes from base64url with or without padding
func base64URLDecode(data string) ([]byte, error) {
// Add padding if needed
if m := len(data) % 4; m != 0 {
data += strings.Repeat("=", 4-m)
}
return base64.URLEncoding.DecodeString(data)
}
// NewHMACAuth creates a new HMAC authentication instance
func NewHMACAuth(secret string, ttl time.Duration) *HMACAuth {
return &HMACAuth{
secret: []byte(secret),
ttl: ttl,
}
}
// GenerateToken generates an HMAC-signed token for the given endpoint
func (h *HMACAuth) GenerateToken(endpoint string) (string, error) {
now := time.Now().Unix()
claims := TokenClaims{
Endpoint: endpoint,
IssuedAt: now,
ExpiresAt: now + int64(h.ttl.Seconds()),
}
// Serialize claims to JSON
claimsJSON, err := json.Marshal(claims)
if err != nil {
return "", fmt.Errorf("failed to marshal claims: %w", err)
}
// Encode claims as base64url without padding
claimsB64 := base64URLEncode(claimsJSON)
// Generate HMAC signature
mac := hmac.New(sha256.New, h.secret)
mac.Write([]byte(claimsB64))
signature := base64URLEncode(mac.Sum(nil))
// Return token in format: claims.signature
return claimsB64 + "." + signature, nil
}
// ValidateToken validates an HMAC token and returns the claims if valid
func (h *HMACAuth) ValidateToken(token string, endpoint string) (*TokenClaims, error) {
// Split token into claims and signature
parts := splitToken(token)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid token format - expected 2 parts, got %d", len(parts))
}
claimsB64, signature := parts[0], parts[1]
// Verify signature
mac := hmac.New(sha256.New, h.secret)
mac.Write([]byte(claimsB64))
expectedSignature := base64URLEncode(mac.Sum(nil))
if !hmac.Equal([]byte(signature), []byte(expectedSignature)) {
return nil, fmt.Errorf("invalid token signature, the password hashes may not match")
}
// Decode claims
claimsJSON, err := base64URLDecode(claimsB64)
if err != nil {
return nil, fmt.Errorf("failed to decode claims: %w", err)
}
var claims TokenClaims
if err := json.Unmarshal(claimsJSON, &claims); err != nil {
return nil, fmt.Errorf("failed to unmarshal claims: %w", err)
}
// Validate expiration
now := time.Now().Unix()
if claims.ExpiresAt < now {
return nil, fmt.Errorf("token expired - expires at %d, current time %d", claims.ExpiresAt, now)
}
// Validate endpoint (optional, can be wildcard *)
if endpoint != "" && claims.Endpoint != "*" && claims.Endpoint != endpoint {
return nil, fmt.Errorf("token not valid for endpoint %s - claim endpoint: %s", endpoint, claims.Endpoint)
}
return &claims, nil
}
// GenerateQueryParam generates a query parameter string with the HMAC token
func (h *HMACAuth) GenerateQueryParam(endpoint string, symbol string) (string, error) {
token, err := h.GenerateToken(endpoint)
if err != nil {
return "", err
}
if symbol == "" {
symbol = "?"
}
return fmt.Sprintf("%stoken=%s", symbol, token), nil
}
// ValidateQueryParam extracts and validates token from query parameter
func (h *HMACAuth) ValidateQueryParam(tokenParam string, endpoint string) (*TokenClaims, error) {
if tokenParam == "" {
return nil, fmt.Errorf("no token provided")
}
return h.ValidateToken(tokenParam, endpoint)
}
// splitToken splits a token string at the last dot separator (claims.signature)
func splitToken(token string) []string {
if i := strings.LastIndexByte(token, '.'); i >= 0 {
return []string{token[:i], token[i+1:]}
}
return []string{token}
}
func (h *HMACAuth) GetTokenExpiry(token string) (time.Time, error) {
parts := splitToken(token)
if len(parts) != 2 {
return time.Time{}, fmt.Errorf("invalid token format")
}
claimsJSON, err := base64URLDecode(parts[0])
if err != nil {
return time.Time{}, fmt.Errorf("failed to decode claims: %w", err)
}
var claims TokenClaims
if err := json.Unmarshal(claimsJSON, &claims); err != nil {
return time.Time{}, fmt.Errorf("failed to unmarshal claims: %w", err)
}
return time.Unix(claims.ExpiresAt, 0), nil
}
func (h *HMACAuth) IsTokenExpired(token string) bool {
expiry, err := h.GetTokenExpiry(token)
if err != nil {
return true
}
return time.Now().After(expiry)
}
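// exampleHMACAuthUsage is a minimal usage sketch (illustrative, not part of the
// original file): the secret and endpoint values are placeholders.
func exampleHMACAuthUsage() {
auth := NewHMACAuth("change-me-secret", 30*time.Minute)
token, err := auth.GenerateToken("/api/v1/stream")
if err != nil {
fmt.Println("generate:", err)
return
}
claims, err := auth.ValidateToken(token, "/api/v1/stream")
if err != nil {
fmt.Println("validate:", err)
return
}
fmt.Println("valid until", time.Unix(claims.ExpiresAt, 0))
}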

View File

@@ -0,0 +1,388 @@
package httputil
import (
"context"
"errors"
"io"
"os"
"sync"
"time"
"github.com/rs/zerolog"
)
type piece struct {
start int64
end int64
}
// FileStream saves an HTTP file being streamed to disk.
// It allows multiple readers to read the file concurrently.
// It works by being fed the stream from the HTTP response body. It will simultaneously write to disk and to the HTTP writer.
type FileStream struct {
length int64
file *os.File
closed bool
mu sync.Mutex
pieces map[int64]*piece
readers []FileStreamReader
readersMu sync.Mutex
ctx context.Context
cancel context.CancelFunc
logger *zerolog.Logger
}
type FileStreamReader interface {
io.ReadSeekCloser
}
// NewFileStream creates a new FileStream instance with a temporary file
func NewFileStream(ctx context.Context, logger *zerolog.Logger, contentLength int64) (*FileStream, error) {
file, err := os.CreateTemp("", "filestream_*.tmp")
if err != nil {
return nil, err
}
// Pre-allocate the file to the expected content length
if contentLength > 0 {
if err := file.Truncate(contentLength); err != nil {
_ = file.Close()
_ = os.Remove(file.Name())
return nil, err
}
}
ctx, cancel := context.WithCancel(ctx)
return &FileStream{
file: file,
ctx: ctx,
cancel: cancel,
logger: logger,
pieces: make(map[int64]*piece),
length: contentLength,
}, nil
}
// WriteAndFlush writes the stream to the file at the given offset and flushes it to the HTTP writer
func (fs *FileStream) WriteAndFlush(src io.Reader, dst io.Writer, offset int64) error {
fs.mu.Lock()
if fs.closed {
fs.mu.Unlock()
return io.ErrClosedPipe
}
fs.mu.Unlock()
buffer := make([]byte, 32*1024) // 32KB buffer
currentOffset := offset
for {
select {
case <-fs.ctx.Done():
return fs.ctx.Err()
default:
}
n, readErr := src.Read(buffer)
if n > 0 {
// Write to file
fs.mu.Lock()
if !fs.closed {
if _, err := fs.file.WriteAt(buffer[:n], currentOffset); err != nil {
fs.mu.Unlock()
return err
}
// Update pieces map
pieceEnd := currentOffset + int64(n) - 1
fs.updatePieces(currentOffset, pieceEnd)
}
fs.mu.Unlock()
// Write to HTTP response
if _, err := dst.Write(buffer[:n]); err != nil {
return err
}
// Flush if possible
if flusher, ok := dst.(interface{ Flush() }); ok {
flusher.Flush()
}
currentOffset += int64(n)
}
if readErr != nil {
if readErr == io.EOF {
break
}
return readErr
}
}
// Sync file to ensure data is written
fs.mu.Lock()
if !fs.closed {
_ = fs.file.Sync()
}
fs.mu.Unlock()
return nil
}
// updatePieces merges the new piece with existing pieces
func (fs *FileStream) updatePieces(start, end int64) {
newPiece := &piece{start: start, end: end}
// Find overlapping or adjacent pieces
var toMerge []*piece
var toDelete []int64
for key, p := range fs.pieces {
if p.start <= end+1 && p.end >= start-1 {
toMerge = append(toMerge, p)
toDelete = append(toDelete, key)
}
}
// Merge all overlapping pieces
for _, p := range toMerge {
if p.start < newPiece.start {
newPiece.start = p.start
}
if p.end > newPiece.end {
newPiece.end = p.end
}
}
// Delete old pieces
for _, key := range toDelete {
delete(fs.pieces, key)
}
// Add the merged piece
fs.pieces[newPiece.start] = newPiece
}
// isRangeAvailable checks if a given range is completely downloaded
func (fs *FileStream) isRangeAvailable(start, end int64) bool {
for _, p := range fs.pieces {
if p.start <= start && p.end >= end {
return true
}
}
return false
}
// NewReader creates a new FileStreamReader for concurrent reading
func (fs *FileStream) NewReader() (FileStreamReader, error) {
fs.mu.Lock()
defer fs.mu.Unlock()
if fs.closed {
return nil, io.ErrClosedPipe
}
reader := &fileStreamReader{
fs: fs,
file: fs.file,
offset: 0,
}
fs.readersMu.Lock()
fs.readers = append(fs.readers, reader)
fs.readersMu.Unlock()
return reader, nil
}
// Close closes the FileStream and cleans up resources
func (fs *FileStream) Close() error {
fs.mu.Lock()
defer fs.mu.Unlock()
if fs.closed {
return nil
}
fs.closed = true
fs.cancel()
// Close all readers
fs.readersMu.Lock()
for _, reader := range fs.readers {
go reader.Close()
}
fs.readers = nil
fs.readersMu.Unlock()
// Remove the temp file and close
fileName := fs.file.Name()
_ = fs.file.Close()
_ = os.Remove(fileName)
return nil
}
// Length returns the current length of the stream
func (fs *FileStream) Length() int64 {
return fs.length
}
// fileStreamReader implements FileStreamReader interface
type fileStreamReader struct {
fs *FileStream
file *os.File
offset int64
closed bool
mu sync.Mutex
}
// Read reads data from the file stream, blocking if data is not yet available
func (r *fileStreamReader) Read(p []byte) (int, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.closed {
return 0, io.ErrClosedPipe
}
readSize := int64(len(p))
readEnd := r.offset + readSize - 1
if readEnd >= r.fs.length {
readEnd = r.fs.length - 1
readSize = r.fs.length - r.offset
if readSize <= 0 {
return 0, io.EOF
}
}
for {
select {
case <-r.fs.ctx.Done():
return 0, r.fs.ctx.Err()
default:
}
r.fs.mu.Lock()
streamClosed := r.fs.closed
// Check if the range we want to read is available
available := r.fs.isRangeAvailable(r.offset, readEnd)
// If not fully available, check what we can read
var actualReadSize int64 = readSize
if !available {
// Find the largest available chunk starting from our offset
var maxRead int64 = 0
for _, piece := range r.fs.pieces {
if piece.start <= r.offset && piece.end >= r.offset {
chunkEnd := piece.end
if chunkEnd >= readEnd {
maxRead = readSize
} else {
maxRead = chunkEnd - r.offset + 1
}
break
}
}
actualReadSize = maxRead
}
r.fs.mu.Unlock()
// If we have some data to read, or if stream is closed, attempt the read
if available || actualReadSize > 0 || streamClosed {
var n int
var err error
if actualReadSize > 0 {
n, err = r.file.ReadAt(p[:actualReadSize], r.offset)
} else if streamClosed {
// If stream is closed and no data available, try reading anyway to get proper EOF
n, err = r.file.ReadAt(p[:readSize], r.offset)
}
if n > 0 {
r.offset += int64(n)
}
// If we read less than requested and stream is closed, return EOF
if n < len(p) && streamClosed && r.offset >= r.fs.length {
if err == nil {
err = io.EOF
}
}
// If no data was read and stream is closed, return EOF
if n == 0 && streamClosed {
return 0, io.EOF
}
// Return what we got, even if it's 0 bytes (this prevents hanging)
return n, err
}
// Wait a bit before checking again
r.mu.Unlock()
select {
case <-r.fs.ctx.Done():
r.mu.Lock()
return 0, r.fs.ctx.Err()
case <-time.After(10 * time.Millisecond):
r.mu.Lock()
}
}
}
// Seek sets the offset for the next Read
func (r *fileStreamReader) Seek(offset int64, whence int) (int64, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.closed {
return 0, io.ErrClosedPipe
}
switch whence {
case io.SeekStart:
r.offset = offset
case io.SeekCurrent:
r.offset += offset
case io.SeekEnd:
r.fs.mu.Lock()
r.offset = r.fs.length + offset
r.fs.mu.Unlock()
default:
return 0, errors.New("invalid whence")
}
if r.offset < 0 {
r.offset = 0
}
return r.offset, nil
}
// Close closes the reader
func (r *fileStreamReader) Close() error {
r.mu.Lock()
defer r.mu.Unlock()
if r.closed {
return nil
}
r.closed = true
r.fs.readersMu.Lock()
for i, reader := range r.fs.readers {
if reader == r {
r.fs.readers = append(r.fs.readers[:i], r.fs.readers[i+1:]...)
break
}
}
r.fs.readersMu.Unlock()
return nil
}
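// exampleFileStreamUsage is a minimal usage sketch (illustrative, not part of
// the original file): body is the upstream HTTP response body and w the client
// response writer; a second reader consumes the same bytes concurrently.
func exampleFileStreamUsage(ctx context.Context, logger *zerolog.Logger, body io.Reader, w io.Writer, contentLength int64) error {
fs, err := NewFileStream(ctx, logger, contentLength)
if err != nil {
return err
}
defer fs.Close()
go func() {
r, err := fs.NewReader()
if err != nil {
return
}
defer r.Close()
_, _ = io.Copy(io.Discard, r) // e.g. serve a second client from the cache
}()
// Tee the upstream body to the temp file and to the primary client.
return fs.WriteAndFlush(body, w, 0)
}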

View File

@@ -0,0 +1,106 @@
package httputil
import (
"errors"
"fmt"
"net/textproto"
"strconv"
"strings"
)
// Range specifies the byte range to be sent to the client.
type Range struct {
Start int64
Length int64
}
// ContentRange returns Content-Range header value.
func (r Range) ContentRange(size int64) string {
return fmt.Sprintf("bytes %d-%d/%d", r.Start, r.Start+r.Length-1, size)
}
var (
// ErrNoOverlap is returned by ParseRange if first-byte-pos of
// all of the byte-range-spec values is greater than the content size.
ErrNoOverlap = errors.New("invalid range: failed to overlap")
// ErrInvalid is returned by ParseRange on invalid input.
ErrInvalid = errors.New("invalid range")
)
// ParseRange parses a Range header string as per RFC 7233.
// ErrNoOverlap is returned if none of the ranges overlap.
// ErrInvalid is returned if s is an invalid range header.
func ParseRange(s string, size int64) ([]Range, error) {
if s == "" {
return nil, nil // header not present
}
const b = "bytes="
if !strings.HasPrefix(s, b) {
return nil, ErrInvalid
}
var ranges []Range
noOverlap := false
for _, ra := range strings.Split(s[len(b):], ",") {
ra = textproto.TrimString(ra)
if ra == "" {
continue
}
i := strings.Index(ra, "-")
if i < 0 {
return nil, ErrInvalid
}
start, end := textproto.TrimString(ra[:i]), textproto.TrimString(ra[i+1:])
var r Range
if start == "" {
// If no start is specified, end specifies the
// range start relative to the end of the file,
// and we are dealing with <suffix-length>
// which has to be a non-negative integer as per
// RFC 7233 Section 2.1 "Byte-Ranges".
if end == "" || end[0] == '-' {
return nil, ErrInvalid
}
i, err := strconv.ParseInt(end, 10, 64)
if i < 0 || err != nil {
return nil, ErrInvalid
}
if i > size {
i = size
}
r.Start = size - i
r.Length = size - r.Start
} else {
i, err := strconv.ParseInt(start, 10, 64)
if err != nil || i < 0 {
return nil, ErrInvalid
}
if i >= size {
// If the range begins after the size of the content,
// then it does not overlap.
noOverlap = true
continue
}
r.Start = i
if end == "" {
// If no end is specified, range extends to end of the file.
r.Length = size - r.Start
} else {
i, err := strconv.ParseInt(end, 10, 64)
if err != nil || r.Start > i {
return nil, ErrInvalid
}
if i >= size {
i = size - 1
}
r.Length = i - r.Start + 1
}
}
ranges = append(ranges, r)
}
if noOverlap && len(ranges) == 0 {
// The specified ranges did not overlap with the content.
return nil, ErrNoOverlap
}
return ranges, nil
}
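// exampleParseRangeUsage is a minimal usage sketch (illustrative, not part of
// the original file) for a 10 KB resource.
func exampleParseRangeUsage() {
ranges, err := ParseRange("bytes=0-499,-500", 10000)
if err != nil {
fmt.Println(err)
return
}
for _, r := range ranges {
// Prints "bytes 0-499/10000", then "bytes 9500-9999/10000"
fmt.Println(r.ContentRange(10000))
}
}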

View File

@@ -0,0 +1,289 @@
package httputil
// Original source: https://github.com/jfbus/httprs/tree/master
/*
Package httprs provides a ReadSeeker for http.Response.Body.
Usage :
resp, err := http.Get(url)
rs := httprs.NewHttpReadSeeker(resp)
defer rs.Close()
io.ReadFull(rs, buf) // reads the first bytes from the response body
rs.Seek(1024, io.SeekStart) // moves the position, but does no range request
io.ReadFull(rs, buf) // does a range request and reads from the response body
If you want to use a specific http.Client for additional range requests :
rs := httprs.NewHttpReadSeeker(resp).WithClient(client)
*/
import (
"fmt"
"io"
"net/http"
"seanime/internal/util/limiter"
"strconv"
"strings"
"sync"
)
// HttpReadSeeker implements io.ReadSeeker for HTTP responses
// It allows seeking within an HTTP response by using HTTP Range requests
type HttpReadSeeker struct {
url string // The URL of the resource
client *http.Client // HTTP client to use for requests
resp *http.Response // Current response
offset int64 // Current offset in the resource
size int64 // Size of the resource, -1 if unknown
readBuf []byte // Buffer for reading
readOffset int // Current offset in readBuf
mu sync.Mutex // Mutex for thread safety
rateLimiter *limiter.Limiter
}
// NewHttpReadSeeker creates a new HttpReadSeeker from an http.Response
func NewHttpReadSeeker(resp *http.Response) *HttpReadSeeker {
url := ""
if resp.Request != nil {
url = resp.Request.URL.String()
}
size := int64(-1)
if resp.ContentLength > 0 {
size = resp.ContentLength
}
return &HttpReadSeeker{
url: url,
client: http.DefaultClient,
resp: resp,
offset: 0,
size: size,
readBuf: nil,
readOffset: 0,
}
}
func NewHttpReadSeekerFromURL(url string) (*HttpReadSeeker, error) {
resp, err := http.Get(url)
if err != nil {
return nil, fmt.Errorf("httprs: failed to get URL %s: %w", url, err)
}
return NewHttpReadSeeker(resp), nil
}
// Read implements io.Reader
func (hrs *HttpReadSeeker) Read(p []byte) (n int, err error) {
hrs.mu.Lock()
defer hrs.mu.Unlock()
// If we have buffered data, read from it first
if hrs.readBuf != nil && hrs.readOffset < len(hrs.readBuf) {
n = copy(p, hrs.readBuf[hrs.readOffset:])
hrs.readOffset += n
hrs.offset += int64(n)
// Clear buffer if we've read it all
if hrs.readOffset >= len(hrs.readBuf) {
hrs.readBuf = nil
hrs.readOffset = 0
}
return n, nil
}
// If we don't have a response or it's been closed, get a new one
if hrs.resp == nil {
if err := hrs.makeRangeRequest(); err != nil {
return 0, err
}
}
// Read from the response body
n, err = hrs.resp.Body.Read(p)
hrs.offset += int64(n)
return n, err
}
// Seek implements io.Seeker
func (hrs *HttpReadSeeker) Seek(offset int64, whence int) (int64, error) {
hrs.mu.Lock()
defer hrs.mu.Unlock()
var newOffset int64
switch whence {
case io.SeekStart:
newOffset = offset
case io.SeekCurrent:
newOffset = hrs.offset + offset
case io.SeekEnd:
if hrs.size < 0 {
// If we don't know the size, we need to determine it
if err := hrs.determineSize(); err != nil {
return hrs.offset, err
}
}
newOffset = hrs.size + offset
default:
return hrs.offset, fmt.Errorf("httprs: invalid whence %d", whence)
}
if newOffset < 0 {
return hrs.offset, fmt.Errorf("httprs: negative position")
}
// If we're just moving the offset without reading, we can skip the request
// We'll make a new request when Read is called
if hrs.resp != nil {
hrs.resp.Body.Close()
hrs.resp = nil
}
hrs.offset = newOffset
hrs.readBuf = nil
hrs.readOffset = 0
return hrs.offset, nil
}
// Close closes the underlying response body
func (hrs *HttpReadSeeker) Close() error {
hrs.mu.Lock()
defer hrs.mu.Unlock()
if hrs.resp != nil {
err := hrs.resp.Body.Close()
hrs.resp = nil
return err
}
return nil
}
// makeRangeRequest makes a new HTTP request with the Range header
func (hrs *HttpReadSeeker) makeRangeRequest() error {
req, err := http.NewRequest("GET", hrs.url, nil)
if err != nil {
return err
}
// Set Range header from current offset
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", hrs.offset))
// Make the request
resp, err := hrs.client.Do(req)
if err != nil {
return err
}
// Check if the server supports range requests
if resp.StatusCode != http.StatusPartialContent && hrs.offset > 0 {
resp.Body.Close()
return fmt.Errorf("httprs: server does not support range requests")
}
// Update our response and offset
if hrs.resp != nil {
hrs.resp.Body.Close()
}
hrs.resp = resp
// Update the size if we get it from Content-Range
if contentRange := resp.Header.Get("Content-Range"); contentRange != "" {
// Format: bytes <start>-<end>/<size>
parts := strings.Split(contentRange, "/")
if len(parts) > 1 && parts[1] != "*" {
if size, err := strconv.ParseInt(parts[1], 10, 64); err == nil {
hrs.size = size
}
}
} else if resp.ContentLength > 0 {
// If we don't have a Content-Range header but we do have Content-Length,
// then the size is the current offset plus the content length
hrs.size = hrs.offset + resp.ContentLength
}
return nil
}
// determineSize makes a HEAD request to determine the size of the resource
func (hrs *HttpReadSeeker) determineSize() error {
req, err := http.NewRequest("HEAD", hrs.url, nil)
if err != nil {
return err
}
resp, err := hrs.client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.ContentLength > 0 {
hrs.size = resp.ContentLength
} else {
// If we still don't know the size, return an error
return fmt.Errorf("httprs: unable to determine resource size")
}
return nil
}
// ReadAt implements io.ReaderAt
func (hrs *HttpReadSeeker) ReadAt(p []byte, off int64) (n int, err error) {
// Save current offset
currentOffset := hrs.offset
// Seek to the requested offset
if _, err := hrs.Seek(off, io.SeekStart); err != nil {
return 0, err
}
// Read the data
n, err = hrs.Read(p)
// Restore the original offset
if _, seekErr := hrs.Seek(currentOffset, io.SeekStart); seekErr != nil {
// If we can't restore the offset, return that error instead
if err == nil {
err = seekErr
}
}
return n, err
}
// Size returns the size of the resource, or -1 if unknown
func (hrs *HttpReadSeeker) Size() int64 {
hrs.mu.Lock()
defer hrs.mu.Unlock()
if hrs.size < 0 {
// Try to determine the size
_ = hrs.determineSize()
}
return hrs.size
}
// WithClient returns a new HttpReadSeeker with the specified client
func (hrs *HttpReadSeeker) WithClient(client *http.Client) *HttpReadSeeker {
hrs.mu.Lock()
defer hrs.mu.Unlock()
hrs.client = client
return hrs
}
func (hrs *HttpReadSeeker) WithRateLimiter(rl *limiter.Limiter) *HttpReadSeeker {
hrs.mu.Lock()
defer hrs.mu.Unlock()
hrs.rateLimiter = rl
return hrs
}
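// exampleHttpReadSeekerUsage is a minimal usage sketch (illustrative, not part
// of the original file): the URL is a placeholder.
func exampleHttpReadSeekerUsage() error {
hrs, err := NewHttpReadSeekerFromURL("https://example.com/video.mkv")
if err != nil {
return err
}
defer hrs.Close()
// Jump 1 MiB in; the range request only happens on the next Read.
if _, err := hrs.Seek(1<<20, io.SeekStart); err != nil {
return err
}
buf := make([]byte, 64*1024)
_, err = io.ReadFull(hrs, buf)
return err
}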

View File

@@ -0,0 +1,389 @@
package image_downloader
import (
"bytes"
"fmt"
"image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"io"
"net/http"
"os"
"path/filepath"
"seanime/internal/util"
"seanime/internal/util/limiter"
"slices"
"sync"
"time"
"github.com/goccy/go-json"
"github.com/google/uuid"
"github.com/rs/zerolog"
_ "golang.org/x/image/bmp"
_ "golang.org/x/image/tiff"
_ "golang.org/x/image/webp"
)
const (
RegistryFilename = "registry.json"
)
type (
ImageDownloader struct {
downloadDir string
registry Registry
cancelChannel chan struct{}
logger *zerolog.Logger
actionMu sync.Mutex
registryMu sync.Mutex
}
Registry struct {
content *RegistryContent
logger *zerolog.Logger
downloadDir string
registryPath string
mu sync.Mutex
}
RegistryContent struct {
UrlToId map[string]string `json:"url_to_id"`
IdToUrl map[string]string `json:"id_to_url"`
IdToExt map[string]string `json:"id_to_ext"`
}
)
func NewImageDownloader(downloadDir string, logger *zerolog.Logger) *ImageDownloader {
_ = os.MkdirAll(downloadDir, os.ModePerm)
return &ImageDownloader{
downloadDir: downloadDir,
logger: logger,
registry: Registry{
logger: logger,
registryPath: filepath.Join(downloadDir, RegistryFilename),
downloadDir: downloadDir,
content: &RegistryContent{},
},
cancelChannel: make(chan struct{}),
}
}
// DownloadImages downloads multiple images concurrently.
func (id *ImageDownloader) DownloadImages(urls []string) (err error) {
id.cancelChannel = make(chan struct{})
if err = id.registry.setup(); err != nil {
return
}
rateLimiter := limiter.NewLimiter(1*time.Second, 10)
var wg sync.WaitGroup
for _, url := range urls {
wg.Add(1)
go func(url string) {
defer wg.Done()
select {
case <-id.cancelChannel:
id.logger.Warn().Msg("image downloader: Download process canceled")
return
default:
rateLimiter.Wait()
id.downloadImage(url)
}
}(url)
}
wg.Wait()
if err = id.registry.save(urls); err != nil {
return
}
return
}
func (id *ImageDownloader) DeleteDownloads() {
id.actionMu.Lock()
defer id.actionMu.Unlock()
id.registryMu.Lock()
defer id.registryMu.Unlock()
_ = os.RemoveAll(id.downloadDir)
id.registry.content = &RegistryContent{}
}
// CancelDownload cancels the download process.
func (id *ImageDownloader) CancelDownload() {
close(id.cancelChannel)
}
func (id *ImageDownloader) GetImageFilenameByUrl(url string) (filename string, ok bool) {
id.actionMu.Lock()
defer id.actionMu.Unlock()
id.registryMu.Lock()
defer id.registryMu.Unlock()
if err := id.registry.setup(); err != nil {
return
}
var imgID string
imgID, ok = id.registry.content.UrlToId[url]
if !ok {
return
}
filename = imgID + "." + id.registry.content.IdToExt[imgID]
return
}
// GetImageFilenamesByUrls returns a map of URLs to image filenames.
//
// e.g., {"url1": "filename1.png", "url2": "filename2.jpg"}
func (id *ImageDownloader) GetImageFilenamesByUrls(urls []string) (ret map[string]string, err error) {
id.actionMu.Lock()
defer id.actionMu.Unlock()
id.registryMu.Lock()
defer id.registryMu.Unlock()
ret = make(map[string]string)
if err = id.registry.setup(); err != nil {
return nil, err
}
for _, url := range urls {
imgID, ok := id.registry.content.UrlToId[url]
if !ok {
continue
}
ret[url] = imgID + "." + id.registry.content.IdToExt[imgID]
}
return
}
func (id *ImageDownloader) DeleteImagesByUrls(urls []string) (err error) {
id.actionMu.Lock()
defer id.actionMu.Unlock()
id.registryMu.Lock()
defer id.registryMu.Unlock()
if err = id.registry.setup(); err != nil {
return
}
for _, url := range urls {
imgID, ok := id.registry.content.UrlToId[url]
if !ok {
continue
}
err = os.Remove(filepath.Join(id.downloadDir, imgID+"."+id.registry.content.IdToExt[imgID]))
if err != nil {
continue
}
delete(id.registry.content.UrlToId, url)
delete(id.registry.content.IdToUrl, imgID)
delete(id.registry.content.IdToExt, imgID)
}
return
}
// downloadImage downloads an image from a URL.
func (id *ImageDownloader) downloadImage(url string) {
defer util.HandlePanicInModuleThen("util/image_downloader/downloadImage", func() {
})
if url == "" {
id.logger.Warn().Msg("image downloader: Empty URL provided, skipping download")
return
}
// Check if the image has already been downloaded
id.registryMu.Lock()
if _, ok := id.registry.content.UrlToId[url]; ok {
id.registryMu.Unlock()
id.logger.Debug().Msgf("image downloader: Image from URL %s has already been downloaded", url)
return
}
id.registryMu.Unlock()
// Download image from URL
id.logger.Info().Msgf("image downloader: Downloading image from URL: %s", url)
imgID := uuid.NewString()
// Download the image
resp, err := http.Get(url)
if err != nil {
id.logger.Error().Err(err).Msgf("image downloader: Failed to download image from URL %s", url)
return
}
defer resp.Body.Close()
buf, err := io.ReadAll(resp.Body)
if err != nil {
id.logger.Error().Err(err).Msgf("image downloader: Failed to read image data from URL %s", url)
return
}
// Get the image format
_, format, err := image.DecodeConfig(bytes.NewReader(buf))
if err != nil {
id.logger.Error().Err(err).Msgf("image downloader: Failed to decode image format from URL %s", url)
return
}
// Create the file
filePath := filepath.Join(id.downloadDir, imgID+"."+format)
file, err := os.Create(filePath)
if err != nil {
id.logger.Error().Err(err).Msgf("image downloader: Failed to create file for image %s", imgID)
return
}
defer file.Close()
// Copy the image data to the file
_, err = io.Copy(file, bytes.NewReader(buf))
if err != nil {
id.logger.Error().Err(err).Msgf("image downloader: Failed to write image data to file for image from %s", url)
return
}
// Update registry
id.registryMu.Lock()
id.registry.addUrl(imgID, url, format)
id.registryMu.Unlock()
}
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
func (r *Registry) setup() (err error) {
r.mu.Lock()
defer r.mu.Unlock()
defer util.HandlePanicInModuleThen("util/image_downloader/setup", func() {
err = fmt.Errorf("image downloader: Failed to setup registry")
})
if r.content.IdToUrl != nil && r.content.UrlToId != nil {
return nil
}
r.content.UrlToId = make(map[string]string)
r.content.IdToUrl = make(map[string]string)
r.content.IdToExt = make(map[string]string)
// Check if the registry exists
_ = os.MkdirAll(filepath.Dir(r.registryPath), os.ModePerm)
_, err = os.Stat(r.registryPath)
if os.IsNotExist(err) {
// Create the registry file
err = os.WriteFile(r.registryPath, []byte("{}"), os.ModePerm)
if err != nil {
return err
}
}
// Read the registry file
file, err := os.Open(r.registryPath)
if err != nil {
return err
}
defer file.Close()
// Decode the registry file content
r.logger.Debug().Msg("image downloader: Reading registry content")
if err = json.NewDecoder(file).Decode(&r.content); err != nil {
return err
}
if r.content == nil {
r.content = &RegistryContent{
UrlToId: make(map[string]string),
IdToUrl: make(map[string]string),
IdToExt: make(map[string]string),
}
}
return nil
}
// save verifies and saves the registry content.
func (r *Registry) save(urls []string) (err error) {
r.mu.Lock()
defer r.mu.Unlock()
defer util.HandlePanicInModuleThen("util/image_downloader/save", func() {
err = fmt.Errorf("image downloader: Failed to save registry content")
})
// Verify all images have been downloaded
allDownloaded := true
for _, url := range urls {
if url == "" {
continue
}
if _, ok := r.content.UrlToId[url]; !ok {
allDownloaded = false
break
}
}
if !allDownloaded {
// Clean up downloaded images
go func() {
r.logger.Error().Msg("image downloader: Not all images have been downloaded, aborting")
// Read the directory
files, err := os.ReadDir(r.downloadDir)
if err != nil {
r.logger.Error().Err(err).Msg("image downloader: Failed to abort")
return
}
// Delete all files that have been downloaded (are in the registry)
for _, file := range files {
fileNameWithoutExt := file.Name()[:len(file.Name())-len(filepath.Ext(file.Name()))]
if url, ok := r.content.IdToUrl[fileNameWithoutExt]; ok && slices.Contains(urls, url) {
err = os.Remove(filepath.Join(r.downloadDir, file.Name()))
if err != nil {
r.logger.Error().Err(err).Msgf("image downloader: Failed to delete file %s", file.Name())
}
}
}
}()
return fmt.Errorf("image downloader: Not all images have been downloaded, operation aborted")
}
data, err := json.Marshal(r.content)
if err != nil {
r.logger.Error().Err(err).Msg("image downloader: Failed to marshal registry content")
return err
}
// Overwrite the registry file
err = os.WriteFile(r.registryPath, data, 0644)
if err != nil {
r.logger.Error().Err(err).Msg("image downloader: Failed to write registry content")
return err
}
return nil
}
func (r *Registry) addUrl(imgID, url, format string) {
r.mu.Lock()
defer r.mu.Unlock()
r.content.UrlToId[url] = imgID
r.content.IdToUrl[imgID] = url
r.content.IdToExt[imgID] = format
}
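// exampleImageDownloaderUsage is a minimal usage sketch (illustrative, not part
// of the original file): the cache directory and URL are placeholders.
func exampleImageDownloaderUsage(logger *zerolog.Logger) {
id := NewImageDownloader("/tmp/seanime-images", logger)
urls := []string{"https://example.com/cover.jpg"}
if err := id.DownloadImages(urls); err != nil {
logger.Error().Err(err).Msg("image downloader: Download failed")
return
}
if filename, ok := id.GetImageFilenameByUrl(urls[0]); ok {
logger.Info().Msgf("image downloader: Cached as %s", filename)
}
}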

View File

@@ -0,0 +1,75 @@
package image_downloader
import (
"fmt"
"github.com/stretchr/testify/require"
"seanime/internal/util"
"testing"
"time"
)
func TestImageDownloader_DownloadImages(t *testing.T) {
tests := []struct {
name string
urls []string
downloadDir string
expectedNum int
cancelAfter int
}{
{
name: "test1",
urls: []string{
"https://s4.anilist.co/file/anilistcdn/media/anime/banner/153518-7uRvV7SLqmHV.jpg",
"https://s4.anilist.co/file/anilistcdn/media/anime/banner/153518-7uRvV7SLqmHV.jpg",
"https://s4.anilist.co/file/anilistcdn/media/anime/cover/medium/bx153518-LEK6pAXtI03D.jpg",
},
downloadDir: t.TempDir(),
expectedNum: 2,
cancelAfter: 0,
},
//{
// name: "test1",
// urls: []string{"https://s4.anilist.co/file/anilistcdn/media/anime/banner/153518-7uRvV7SLqmHVn.jpg"},
// downloadDir: t.TempDir(),
// cancelAfter: 0,
//},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
id := NewImageDownloader(tt.downloadDir, util.NewLogger())
if tt.cancelAfter > 0 {
go func() {
time.Sleep(time.Duration(tt.cancelAfter) * time.Second)
close(id.cancelChannel)
}()
}
t.Log(tt.downloadDir)
if err := id.DownloadImages(tt.urls); err != nil {
t.Errorf("ImageDownloader.DownloadImages() error = %v", err)
}
downloadedImages := make(map[string]string)
for _, url := range tt.urls {
imgPath, ok := id.GetImageFilenameByUrl(url)
if !ok {
t.Errorf("ImageDownloader.GetImageFilenameByUrl() error")
} else {
downloadedImages[imgPath] = imgPath
t.Logf("ImageDownloader.GetImageFilenameByUrl() = %v", imgPath)
}
}
require.Len(t, downloadedImages, tt.expectedNum)
})
}
time.Sleep(1 * time.Second)
}

View File

@@ -0,0 +1,110 @@
package util
import (
"errors"
"io"
)
// Common errors that might occur during operations
var (
ErrInvalidOffset = errors.New("invalid offset: negative or beyond limit")
ErrInvalidWhence = errors.New("invalid whence value")
ErrReadLimit = errors.New("read would exceed limit")
)
// LimitedReadSeeker wraps an io.ReadSeeker and limits the number of bytes
// that can be read from it.
type LimitedReadSeeker struct {
rs io.ReadSeeker // The underlying ReadSeeker
offset int64 // Current read position relative to start
limit int64 // Maximum number of bytes that can be read
basePos int64 // Original position in the underlying ReadSeeker
}
// NewLimitedReadSeeker creates a new LimitedReadSeeker from the provided
// io.ReadSeeker, starting at the current position and with the given limit.
// The limit parameter specifies the maximum number of bytes that can be
// read from the underlying ReadSeeker.
func NewLimitedReadSeeker(rs io.ReadSeeker, limit int64) (*LimitedReadSeeker, error) {
if limit < 0 {
return nil, errors.New("negative limit")
}
// Get the current position
pos, err := rs.Seek(0, io.SeekCurrent)
if err != nil {
return nil, err
}
return &LimitedReadSeeker{
rs: rs,
offset: 0,
limit: limit,
basePos: pos,
}, nil
}
// Read implements the io.Reader interface.
func (lrs *LimitedReadSeeker) Read(p []byte) (n int, err error) {
if lrs.offset >= lrs.limit {
return 0, io.EOF
}
// Calculate how many bytes we can read
maxToRead := lrs.limit - lrs.offset
if int64(len(p)) > maxToRead {
p = p[:maxToRead]
}
n, err = lrs.rs.Read(p)
lrs.offset += int64(n)
return
}
// Seek implements the io.Seeker interface.
func (lrs *LimitedReadSeeker) Seek(offset int64, whence int) (int64, error) {
var absoluteOffset int64
// Calculate the absolute offset based on whence
switch whence {
case io.SeekStart:
absoluteOffset = offset
case io.SeekCurrent:
absoluteOffset = lrs.offset + offset
case io.SeekEnd:
absoluteOffset = lrs.limit + offset
default:
return 0, ErrInvalidWhence
}
// Check if the offset is valid
if absoluteOffset < 0 || absoluteOffset > lrs.limit {
return 0, ErrInvalidOffset
}
// Seek in the underlying ReadSeeker
_, err := lrs.rs.Seek(lrs.basePos+absoluteOffset, io.SeekStart)
if err != nil {
return 0, err
}
// Update our offset
lrs.offset = absoluteOffset
return absoluteOffset, nil
}
// Size returns the limit of this LimitedReadSeeker.
func (lrs *LimitedReadSeeker) Size() int64 {
return lrs.limit
}
// Remaining returns the number of bytes that can still be read.
func (lrs *LimitedReadSeeker) Remaining() int64 {
return lrs.limit - lrs.offset
}
// Reset resets the read position to the beginning of the limited section.
func (lrs *LimitedReadSeeker) Reset() error {
_, err := lrs.Seek(0, io.SeekStart)
return err
}
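// exampleLimitedReadSeekerUsage is a minimal usage sketch (illustrative, not
// part of the original file): it exposes a 100-byte window of rs starting at
// the current position.
func exampleLimitedReadSeekerUsage(rs io.ReadSeeker) error {
lrs, err := NewLimitedReadSeeker(rs, 100)
if err != nil {
return err
}
buf := make([]byte, 64)
n, err := lrs.Read(buf) // reads at most 64 of the 100 allowed bytes
if err != nil {
return err
}
_ = n
_ = lrs.Remaining() // bytes left in the window
return lrs.Reset()  // rewind to the start of the window
}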

View File

@@ -0,0 +1,53 @@
package limiter
import (
"sync"
"time"
)
// https://stackoverflow.com/a/72452542
func NewAnilistLimiter() *Limiter {
//return NewLimiter(15*time.Second, 18)
return NewLimiter(6*time.Second, 8)
}
//----------------------------------------------------------------------------------------------------------------------
type Limiter struct {
tick time.Duration
count uint
entries []time.Time
index uint
mu sync.Mutex
}
func NewLimiter(tick time.Duration, count uint) *Limiter {
l := Limiter{
tick: tick,
count: count,
index: 0,
}
l.entries = make([]time.Time, count)
before := time.Now().Add(-2 * tick)
for i := range l.entries {
l.entries[i] = before
}
return &l
}
func (l *Limiter) Wait() {
l.mu.Lock()
defer l.mu.Unlock()
last := &l.entries[l.index]
next := last.Add(l.tick)
now := time.Now()
if now.Before(next) {
time.Sleep(next.Sub(now))
}
*last = time.Now()
l.index = l.index + 1
if l.index == l.count {
l.index = 0
}
}
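// exampleLimiterUsage is a minimal usage sketch (illustrative, not part of the
// original file): at most 8 calls pass per rolling 6-second window, matching
// NewAnilistLimiter.
func exampleLimiterUsage(requests []func()) {
rl := NewAnilistLimiter()
for _, fn := range requests {
rl.Wait() // blocks until a slot in the window frees up
fn()
}
}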

View File

@@ -0,0 +1,178 @@
package util
import (
"bytes"
"fmt"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
"github.com/rs/zerolog/log"
"github.com/rs/zerolog"
)
const (
colorBlack = iota + 30
colorRed
colorGreen
colorYellow
colorBlue
colorMagenta
colorCyan
colorWhite
colorBold = 1
colorDarkGray = 90
unknownLevel = "???"
)
// Stores logs from all loggers. Used to write logs to a file when WriteGlobalLogBufferToFile is called.
// It is reset after writing to a file.
var logBuffer bytes.Buffer
var logBufferMutex = &sync.Mutex{}
func NewLogger() *zerolog.Logger {
timeFormat := fmt.Sprintf("%s", time.DateTime)
fieldsOrder := []string{"method", "status", "error", "uri", "latency_human"}
fieldsExclude := []string{"host", "latency", "referer", "remote_ip", "user_agent", "bytes_in", "bytes_out", "file"}
// Set up logger
consoleOutput := zerolog.ConsoleWriter{
Out: os.Stdout,
TimeFormat: timeFormat,
FormatLevel: ZerologFormatLevelPretty,
FormatMessage: ZerologFormatMessagePretty,
FieldsExclude: fieldsExclude,
FieldsOrder: fieldsOrder,
}
fileOutput := zerolog.ConsoleWriter{
Out: &logBuffer,
TimeFormat: timeFormat,
FormatMessage: ZerologFormatMessageSimple,
FormatLevel: ZerologFormatLevelSimple,
NoColor: true, // Needed to prevent color codes from being written to the file
FieldsExclude: fieldsExclude,
FieldsOrder: fieldsOrder,
}
multi := zerolog.MultiLevelWriter(consoleOutput, fileOutput)
logger := zerolog.New(multi).With().Timestamp().Logger()
return &logger
}
func WriteGlobalLogBufferToFile(file *os.File) {
defer HandlePanicInModuleThen("util/WriteGlobalLogBufferToFile", func() {})
if file == nil {
return
}
logBufferMutex.Lock()
defer logBufferMutex.Unlock()
if _, err := logBuffer.WriteTo(file); err != nil {
fmt.Print("Failed to write log buffer to file")
}
logBuffer.Reset()
}
func SetupLoggerSignalHandling(file *os.File) {
if file == nil {
return
}
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
sig := <-sigChan
log.Trace().Msgf("Received signal: %s", sig)
// Flush log buffer to the log file when the app exits
WriteGlobalLogBufferToFile(file)
_ = file.Close()
os.Exit(0)
}()
}
func ZerologFormatMessagePretty(i interface{}) string {
if msg, ok := i.(string); ok {
if bytes.ContainsRune([]byte(msg), ':') {
parts := strings.SplitN(msg, ":", 2)
if len(parts) > 1 {
return colorizeb(parts[0], colorCyan) + colorizeb(" >", colorDarkGray) + parts[1]
}
}
return msg
}
return ""
}
func ZerologFormatMessageSimple(i interface{}) string {
if msg, ok := i.(string); ok {
if bytes.ContainsRune([]byte(msg), ':') {
parts := strings.SplitN(msg, ":", 2)
if len(parts) > 1 {
return parts[0] + " >" + parts[1]
}
}
return msg
}
return ""
}
func ZerologFormatLevelPretty(i interface{}) string {
if ll, ok := i.(string); ok {
s := strings.ToLower(ll)
switch s {
case "debug":
s = "DBG" + colorizeb(" -", colorDarkGray)
case "info":
s = fmt.Sprint(colorizeb("INF", colorBold)) + colorizeb(" -", colorDarkGray)
case "warn":
s = colorizeb("WRN", colorYellow) + colorizeb(" -", colorDarkGray)
case "trace":
s = colorizeb("TRC", colorDarkGray) + colorizeb(" -", colorDarkGray)
case "error":
s = colorizeb("ERR", colorRed) + colorizeb(" -", colorDarkGray)
case "fatal":
s = colorizeb("FTL", colorRed) + colorizeb(" -", colorDarkGray)
case "panic":
s = colorizeb("PNC", colorRed) + colorizeb(" -", colorDarkGray)
}
return s
}
return ""
}
func ZerologFormatLevelSimple(i interface{}) string {
if ll, ok := i.(string); ok {
s := strings.ToLower(ll)
switch s {
case "debug":
s = "|DBG|"
case "info":
s = "|INF|"
case "warn":
s = "|WRN|"
case "trace":
s = "|TRC|"
case "error":
s = "|ERR|"
case "fatal":
s = "|FTL|"
case "panic":
s = "|PNC|"
}
return s
}
return ""
}
func colorizeb(s interface{}, c int) string {
return fmt.Sprintf("\x1b[%dm%v\x1b[0m", c, s)
}
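// exampleLoggerUsage is a minimal usage sketch (illustrative, not part of the
// original file): the log file path is a placeholder.
func exampleLoggerUsage() {
logger := NewLogger()
logger.Info().Msg("module: Started") // "module" is colorized via the "prefix: message" convention
f, err := os.Create("/tmp/seanime.log")
if err != nil {
return
}
SetupLoggerSignalHandling(f) // flush the shared buffer on SIGINT/SIGTERM
WriteGlobalLogBufferToFile(f)
}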

View File

@@ -0,0 +1,94 @@
package util
import (
"sync"
)
// https://sreramk.medium.com/go-inside-sync-map-how-does-sync-map-work-internally-97e87b8e6bf
// RWMutexMap is an implementation of mapInterface using a sync.RWMutex.
type RWMutexMap struct {
mu sync.RWMutex
dirty map[interface{}]interface{}
}
func (m *RWMutexMap) Load(key interface{}) (value interface{}, ok bool) {
m.mu.RLock()
value, ok = m.dirty[key]
m.mu.RUnlock()
return
}
func (m *RWMutexMap) Store(key, value interface{}) {
m.mu.Lock()
if m.dirty == nil {
m.dirty = make(map[interface{}]interface{})
}
m.dirty[key] = value
m.mu.Unlock()
}
func (m *RWMutexMap) LoadOrStore(key, value interface{}) (actual interface{}, loaded bool) {
m.mu.Lock()
actual, loaded = m.dirty[key]
if !loaded {
actual = value
if m.dirty == nil {
m.dirty = make(map[interface{}]interface{})
}
m.dirty[key] = value
}
m.mu.Unlock()
return actual, loaded
}
func (m *RWMutexMap) LoadAndDelete(key interface{}) (value interface{}, loaded bool) {
m.mu.Lock()
value, loaded = m.dirty[key]
if !loaded {
m.mu.Unlock()
return nil, false
}
delete(m.dirty, key)
m.mu.Unlock()
return value, loaded
}
func (m *RWMutexMap) Delete(key interface{}) {
m.mu.Lock()
delete(m.dirty, key)
m.mu.Unlock()
}
func (m *RWMutexMap) Range(f func(key, value interface{}) (shouldContinue bool)) {
m.mu.RLock()
keys := make([]interface{}, 0, len(m.dirty))
for k := range m.dirty {
keys = append(keys, k)
}
m.mu.RUnlock()
for _, k := range keys {
v, ok := m.Load(k)
if !ok {
continue
}
if !f(k, v) {
break
}
}
}
// MapInterface is the interface RWMutexMap implements.
type MapInterface interface {
Load(interface{}) (interface{}, bool)
Store(key, value interface{})
LoadOrStore(key, value interface{}) (actual interface{}, loaded bool)
LoadAndDelete(key interface{}) (value interface{}, loaded bool)
Delete(interface{})
Range(func(key, value interface{}) (shouldContinue bool))
}
func NewRWMutexMap() MapInterface {
return &RWMutexMap{}
}
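// Illustrative usage sketch (not part of the original commit); the key and
// value here are hypothetical:
//
//	m := NewRWMutexMap()
//	m.Store("mediaId", 21)
//	if v, ok := m.Load("mediaId"); ok {
//		_ = v.(int) // values come back as interface{} and must be asserted
//	}
//	m.Range(func(k, v interface{}) bool {
//		return true // return false to stop early
//	})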

View File

@@ -0,0 +1,30 @@
package util
import (
"fmt"
"os/exec"
"runtime"
"strings"
)
func GetMemAddrStr(v interface{}) string {
return fmt.Sprintf("%p", v)
}
func ProgramIsRunning(name string) bool {
var cmd *exec.Cmd
switch runtime.GOOS {
case "windows":
cmd = NewCmd("tasklist")
case "linux":
cmd = NewCmd("pgrep", name)
case "darwin":
cmd = NewCmd("pgrep", name)
default:
return false
}
output, _ := cmd.Output()
return strings.Contains(string(output), name)
}

View File

@@ -0,0 +1,93 @@
package util
import (
"math"
"strconv"
"strings"
)
func StringToInt(str string) (int, bool) {
dotIndex := strings.IndexByte(str, '.')
if dotIndex != -1 {
str = str[:dotIndex]
}
i, err := strconv.Atoi(str)
if err != nil {
return 0, false
}
return i, true
}
// StringToIntMust converts str to an int, truncating any decimal part.
// Despite its name it does not panic; it returns 0 on failure.
func StringToIntMust(str string) int {
dotIndex := strings.IndexByte(str, '.')
if dotIndex != -1 {
str = str[:dotIndex]
}
i, err := strconv.Atoi(str)
if err != nil {
return 0
}
return i
}
func IntegerToRoman(number int) string {
maxRomanNumber := 3999
if number > maxRomanNumber {
return strconv.Itoa(number)
}
conversions := []struct {
value int
digit string
}{
{1000, "M"},
{900, "CM"},
{500, "D"},
{400, "CD"},
{100, "C"},
{90, "XC"},
{50, "L"},
{40, "XL"},
{10, "X"},
{9, "IX"},
{5, "V"},
{4, "IV"},
{1, "I"},
}
var roman strings.Builder
for _, conversion := range conversions {
for number >= conversion.value {
roman.WriteString(conversion.digit)
number -= conversion.value
}
}
return roman.String()
}
// toOrdinal returns the ordinal suffix ("st", "nd", "rd", "th") for an integer.
func toOrdinal(number int) string {
absNumber := int(math.Abs(float64(number)))
i := absNumber % 100
if i == 11 || i == 12 || i == 13 {
return "th"
}
switch absNumber % 10 {
case 1:
return "st"
case 2:
return "nd"
case 3:
return "rd"
default:
return "th"
}
}
// IntegerToOrdinal formats the number by appending its ordinal suffix (e.g. 1 -> "1st").
func IntegerToOrdinal(number int) string {
return strconv.Itoa(number) + toOrdinal(number)
}
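// Expected outputs, as an illustrative sketch (not from the source):
//
//	IntegerToRoman(1994)  // "MCMXCIV"
//	IntegerToRoman(4000)  // "4000" (above 3999 it falls back to plain digits)
//	IntegerToOrdinal(2)   // "2nd"
//	IntegerToOrdinal(11)  // "11th" (11-13 are special-cased)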

View File

@@ -0,0 +1,59 @@
package util
import (
"fmt"
"os"
"path/filepath"
"runtime"
)
func DownloadDir() (string, error) {
return userDir("Downloads")
}
func DesktopDir() (string, error) {
return userDir("Desktop")
}
func DocumentsDir() (string, error) {
return userDir("Documents")
}
// userDir returns the path to the specified user directory (Desktop, Documents or Downloads).
func userDir(dirType string) (string, error) {
var dir string
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
switch runtime.GOOS {
case "windows":
dir = filepath.Join(home, dirType)
case "darwin":
dir = filepath.Join(home, dirType)
case "linux":
// Linux: Use $XDG_DESKTOP_DIR / $XDG_DOCUMENTS_DIR / $XDG_DOWNLOAD_DIR if set, otherwise default
envVar := ""
if dirType == "Desktop" {
envVar = os.Getenv("XDG_DESKTOP_DIR")
} else if dirType == "Documents" {
envVar = os.Getenv("XDG_DOCUMENTS_DIR")
} else if dirType == "Downloads" {
envVar = os.Getenv("XDG_DOWNLOAD_DIR")
}
if envVar != "" {
dir = envVar
} else {
dir = filepath.Join(home, dirType)
}
default:
return "", fmt.Errorf("unsupported platform: %s", runtime.GOOS)
}
return dir, nil
}
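// Hypothetical usage sketch (not part of the original commit):
//
//	dir, err := DownloadDir()
//	if err != nil {
//		// unsupported platform or no resolvable home directory
//	}
//	_ = dir // e.g. /home/user/Downloads; honors $XDG_DOWNLOAD_DIR on Linux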

View File

@@ -0,0 +1,73 @@
package util
import (
"errors"
"github.com/rs/zerolog/log"
"runtime/debug"
"sync"
)
var printLock = sync.Mutex{}
func printRuntimeError(r any, module string) string {
printLock.Lock()
defer printLock.Unlock()
debugStr := string(debug.Stack())
logger := NewLogger()
log.Error().Msgf("go: PANIC RECOVERY")
if module != "" {
log.Error().Msgf("go: Runtime error in \"%s\"", module)
}
log.Error().Msgf("go: A runtime error occurred, please send the logs to the developer\n")
log.Printf("go: ========================================= Stack Trace =========================================\n")
logger.Error().Msgf("%+v\n\n%+v", r, debugStr)
log.Printf("go: ===================================== End of Stack Trace ======================================\n")
return debugStr
}
func HandlePanicWithError(err *error) {
if r := recover(); r != nil {
*err = errors.New("fatal error occurred, please report this issue")
printRuntimeError(r, "")
}
}
func HandlePanicInModuleWithError(module string, err *error) {
if r := recover(); r != nil {
*err = errors.New("fatal error occurred, please report this issue")
printRuntimeError(r, module)
}
}
func HandlePanicThen(f func()) {
if r := recover(); r != nil {
f()
printRuntimeError(r, "")
}
}
func HandlePanicInModuleThen(module string, f func()) {
if r := recover(); r != nil {
f()
printRuntimeError(r, module)
}
}
func HandlePanicInModuleThenS(module string, f func(stackTrace string)) {
if r := recover(); r != nil {
str := printRuntimeError(r, module)
f(str)
}
}
func Recover() {
if r := recover(); r != nil {
printRuntimeError(r, "")
}
}
func RecoverInModule(module string) {
if r := recover(); r != nil {
printRuntimeError(r, module)
}
}
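// Minimal usage sketch (illustrative; doWork is a hypothetical caller):
//
//	func doWork() (err error) {
//		defer HandlePanicWithError(&err)
//		// ... code that may panic; on recovery err is replaced
//		// with a generic error and the stack trace is logged
//		return nil
//	}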

View File

@@ -0,0 +1,55 @@
package util
import "testing"
func TestHandlePanicInModuleThen(t *testing.T) {
type testStruct struct {
mediaId int
}
testDangerousWork := func(obj *testStruct, work func()) {
defer HandlePanicInModuleThen("util/panic_test", func() {
obj.mediaId = 0
})
work()
}
var testCases = []struct {
name string
obj testStruct
work func()
expectedMediaId int
}{
{
name: "Test 1",
obj: testStruct{mediaId: 1},
work: func() {
panic("Test 1")
},
expectedMediaId: 0,
},
{
name: "Test 2",
obj: testStruct{mediaId: 2},
work: func() {
// Do nothing
},
expectedMediaId: 2,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
testDangerousWork(&tc.obj, tc.work)
if tc.obj.mediaId != tc.expectedMediaId {
t.Errorf("Expected mediaId to be %d, got %d", tc.expectedMediaId, tc.obj.mediaId)
}
})
}
}

View File

@@ -0,0 +1,96 @@
package parallel
import (
"github.com/samber/lo"
"seanime/internal/util/limiter"
"sync"
)
// EachTask iterates over elements of collection and invokes the task function for each element.
// `task` is called in parallel.
func EachTask[T any](collection []T, task func(item T, index int)) {
var wg sync.WaitGroup
for i, item := range collection {
wg.Add(1)
go func(_item T, _i int) {
defer wg.Done()
task(_item, _i)
}(item, i)
}
wg.Wait()
}
// EachTaskL is the same as EachTask, but takes a pointer to limiter.Limiter.
func EachTaskL[T any](collection []T, rl *limiter.Limiter, task func(item T, index int)) {
var wg sync.WaitGroup
for i, item := range collection {
wg.Add(1)
go func(_item T, _i int) {
defer wg.Done()
rl.Wait()
task(_item, _i)
}(item, i)
}
wg.Wait()
}
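// Usage sketch for EachTask/EachTaskL (illustrative; process is hypothetical):
// EachTask blocks until every goroutine has returned, while EachTaskL
// additionally waits on the rate limiter before running each task.
//
//	parallel.EachTask([]int{1, 2, 3}, func(id int, idx int) {
//		process(id)
//	})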
type SettledResults[T comparable, R any] struct {
Collection []T
Fulfilled map[T]R
Results []R
Rejected map[T]error
}
// NewSettledResults returns a pointer to a new SettledResults struct.
func NewSettledResults[T comparable, R any](c []T) *SettledResults[T, R] {
return &SettledResults[T, R]{
Collection: c,
Fulfilled: map[T]R{},
Rejected: map[T]error{},
}
}
// GetFulfilledResults returns a pointer to the slice of fulfilled results and a boolean indicating whether the slice is not nil.
func (sr *SettledResults[T, R]) GetFulfilledResults() (*[]R, bool) {
if sr.Results != nil {
return &sr.Results, true
}
return nil, false
}
// AllSettled executes the provided task function once, in parallel for each element in the slice passed to NewSettledResults.
// It returns a map of fulfilled results and a map of errors whose keys are the elements of the slice.
func (sr *SettledResults[T, R]) AllSettled(task func(item T, index int) (R, error)) (map[T]R, map[T]error) {
var wg sync.WaitGroup
var mu sync.Mutex
for i, item := range sr.Collection {
wg.Add(1)
go func(_item T, _i int) {
res, err := task(_item, _i)
mu.Lock()
if err != nil {
sr.Rejected[_item] = err
} else {
sr.Fulfilled[_item] = res
}
mu.Unlock()
wg.Done()
}(item, i)
}
wg.Wait()
sr.Results = lo.MapToSlice(sr.Fulfilled, func(key T, value R) R {
return value
})
return sr.Fulfilled, sr.Rejected
}

View File

@@ -0,0 +1,78 @@
package parallel
import (
"fmt"
"github.com/sourcegraph/conc/pool"
"github.com/sourcegraph/conc/stream"
"testing"
"time"
)
func fakeAPICall(id int) (int, error) {
//time.Sleep(time.Millisecond * time.Duration(100+rand.Intn(500)))
time.Sleep(time.Millisecond * 200)
return id, nil
}
func TestAllSettled(t *testing.T) {
ids := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}
sr := NewSettledResults[int, int](ids)
sr.AllSettled(func(item int, index int) (int, error) {
return fakeAPICall(item)
})
fulfilled, ok := sr.GetFulfilledResults()
if !ok {
t.Error("expected results, got error")
}
for _, v := range *fulfilled {
t.Log(v)
}
}
func TestConc(t *testing.T) {
ids := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}
fetch := func(ids []int) ([]int, error) {
p := pool.NewWithResults[int]().WithErrors()
for _, id := range ids {
id := id
p.Go(func() (int, error) {
return fakeAPICall(id)
})
}
return p.Wait()
}
res, _ := fetch(ids)
for _, v := range res {
t.Log(v)
}
}
func TestConcStream(t *testing.T) {
ids := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}
strm := stream.New()
for _, id := range ids {
id := id
strm.Go(func() stream.Callback {
res, err := fakeAPICall(id)
// This will print in the order the tasks were submitted
return func() {
fmt.Println(res, err)
}
})
}
strm.Wait()
}

View File

@@ -0,0 +1,59 @@
package util
import (
"regexp"
"strings"
)
func ExtractSeasonNumber(title string) (int, string) {
title = strings.ToLower(title)
rgx := regexp.MustCompile(`((?P<a>\d+)(st|nd|rd|th)?\s*(season))|((season)\s*(?P<b>\d+))`)
matches := rgx.FindStringSubmatch(title)
if len(matches) < 1 {
return 0, title
}
m := matches[rgx.SubexpIndex("a")]
if m == "" {
m = matches[rgx.SubexpIndex("b")]
}
if m == "" {
return 0, title
}
ret, ok := StringToInt(m)
if !ok {
return 0, title
}
cTitle := strings.TrimSpace(rgx.ReplaceAllString(title, ""))
return ret, cTitle
}
func ExtractPartNumber(title string) (int, string) {
title = strings.ToLower(title)
rgx := regexp.MustCompile(`((?P<a>\d+)(st|nd|rd|th)?\s*(cour|part))|((cour|part)\s*(?P<b>\d+))`)
matches := rgx.FindStringSubmatch(title)
if len(matches) < 1 {
return 0, title
}
m := matches[rgx.SubexpIndex("a")]
if m == "" {
m = matches[rgx.SubexpIndex("b")]
}
if m == "" {
return 0, title
}
ret, ok := StringToInt(m)
if !ok {
return 0, title
}
cTitle := strings.TrimSpace(rgx.ReplaceAllString(title, ""))
return ret, cTitle
}
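// Behavior sketch (illustrative, not from the source); note the returned
// title is lowercased with the matched fragment stripped:
//
//	ExtractSeasonNumber("Mononogatari 2nd Season") // 2, "mononogatari"
//	ExtractSeasonNumber("Season 3 of X")           // 3, "of x"
//	ExtractPartNumber("Title Part 2")              // 2, "title"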

View File

@@ -0,0 +1,25 @@
package util
import "sync"
type Pool[T any] struct {
sync.Pool
}
func (p *Pool[T]) Get() T {
return p.Pool.Get().(T)
}
func (p *Pool[T]) Put(x T) {
p.Pool.Put(x)
}
func NewPool[T any](newF func() T) *Pool[T] {
return &Pool[T]{
Pool: sync.Pool{
New: func() interface{} {
return newF()
},
},
}
}
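// Typed pool sketch (illustrative, not from the source):
//
//	bufPool := NewPool[*bytes.Buffer](func() *bytes.Buffer {
//		return new(bytes.Buffer)
//	})
//	buf := bufPool.Get()
//	buf.Reset()
//	// ... use buf ...
//	bufPool.Put(buf)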

View File

@@ -0,0 +1,70 @@
package util
import (
"encoding/json"
"io"
"net/http"
"seanime/internal/util"
"github.com/labstack/echo/v4"
)
type ImageProxy struct{}
func (ip *ImageProxy) GetImage(url string, headers map[string]string) ([]byte, error) {
client := &http.Client{}
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
for key, value := range headers {
req.Header.Add(key, value)
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
return body, nil
}
func (ip *ImageProxy) setHeaders(c echo.Context) {
// Set the response headers; echo's Context.Set only stores context values
h := c.Response().Header()
h.Set("Content-Type", "image/jpeg")
h.Set("Cache-Control", "public, max-age=31536000")
h.Set("Access-Control-Allow-Origin", "*")
h.Set("Access-Control-Allow-Methods", "GET")
h.Set("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, Accept")
h.Set("Access-Control-Allow-Credentials", "true")
}
func (ip *ImageProxy) ProxyImage(c echo.Context) (err error) {
defer util.HandlePanicInModuleWithError("util/ImageProxy", &err)
url := c.QueryParam("url")
headersJSON := c.QueryParam("headers")
if url == "" || headersJSON == "" {
return c.String(echo.ErrBadRequest.Code, "No URL provided")
}
headers := make(map[string]string)
if err := json.Unmarshal([]byte(headersJSON), &headers); err != nil {
return c.String(echo.ErrBadRequest.Code, "Error parsing headers JSON")
}
ip.setHeaders(c)
imageBuffer, err := ip.GetImage(url, headers)
if err != nil {
return c.String(echo.ErrInternalServerError.Code, "Error fetching image")
}
return c.Blob(http.StatusOK, c.Response().Header().Get("Content-Type"), imageBuffer)
}

View File

@@ -0,0 +1,284 @@
package util
import (
"bytes"
"io"
"net/http"
url2 "net/url"
"seanime/internal/util"
"strconv"
"strings"
"time"
"github.com/Eyevinn/hls-m3u8/m3u8"
"github.com/goccy/go-json"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
)
var proxyUA = util.GetRandomUserAgent()
var videoProxyClient = &http.Client{
Transport: &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
ForceAttemptHTTP2: false, // Fixes issues on Linux
},
Timeout: 60 * time.Second,
}
func VideoProxy(c echo.Context) (err error) {
defer util.HandlePanicInModuleWithError("util/VideoProxy", &err)
url := c.QueryParam("url")
headers := c.QueryParam("headers")
// Always use GET request internally, even for HEAD requests
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
log.Error().Err(err).Msg("proxy: Error creating request")
return echo.NewHTTPError(http.StatusInternalServerError)
}
var headerMap map[string]string
if headers != "" {
if err := json.Unmarshal([]byte(headers), &headerMap); err != nil {
log.Error().Err(err).Msg("proxy: Error unmarshalling headers")
return echo.NewHTTPError(http.StatusInternalServerError)
}
for key, value := range headerMap {
req.Header.Set(key, value)
}
}
req.Header.Set("User-Agent", proxyUA)
req.Header.Set("Accept", "*/*")
if rangeHeader := c.Request().Header.Get("Range"); rangeHeader != "" {
req.Header.Set("Range", rangeHeader)
}
resp, err := videoProxyClient.Do(req)
if err != nil {
log.Error().Err(err).Msg("proxy: Error sending request")
return echo.NewHTTPError(http.StatusInternalServerError)
}
defer resp.Body.Close()
// Copy response headers
for k, vs := range resp.Header {
for _, v := range vs {
if !strings.EqualFold(k, "Content-Length") { // Skip Content-Length header, fixes net::ERR_CONTENT_LENGTH_MISMATCH
c.Response().Header().Set(k, v)
}
}
}
// Set CORS headers
c.Response().Header().Set("Access-Control-Allow-Origin", "*")
c.Response().Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
c.Response().Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
// For HEAD requests, return only headers
if c.Request().Method == http.MethodHead {
return c.NoContent(http.StatusOK)
}
isHlsPlaylist := strings.HasSuffix(url, ".m3u8") || strings.Contains(resp.Header.Get("Content-Type"), "mpegurl")
if !isHlsPlaylist {
return c.Stream(resp.StatusCode, c.Response().Header().Get("Content-Type"), resp.Body)
}
// HLS Playlist
//log.Debug().Str("url", url).Msg("proxy: Processing HLS playlist")
bodyBytes, readErr := io.ReadAll(resp.Body)
if readErr != nil {
log.Error().Err(readErr).Str("url", url).Msg("proxy: Error reading HLS response body")
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to read HLS playlist")
}
buffer := bytes.NewBuffer(bodyBytes)
playlist, listType, decodeErr := m3u8.Decode(*buffer, true)
if decodeErr != nil {
// Playlist might be valid but not decodable by the library, or simply corrupted.
// Option 1: Proxy as-is (might be preferred if decoding fails unexpectedly)
log.Warn().Err(decodeErr).Str("url", url).Msg("proxy: Failed to decode M3U8 playlist, proxying raw content")
c.Response().Header().Set(echo.HeaderContentType, resp.Header.Get("Content-Type")) // Use original Content-Type
c.Response().Header().Set(echo.HeaderContentLength, strconv.Itoa(len(bodyBytes)))
c.Response().WriteHeader(resp.StatusCode)
_, writeErr := c.Response().Writer.Write(bodyBytes)
return writeErr
}
var modifiedPlaylistBytes []byte
needsRewrite := false // Flag to check if we actually need to rewrite
if listType == m3u8.MEDIA {
mediaPl := playlist.(*m3u8.MediaPlaylist)
baseURL, _ := url2.Parse(url) // Base URL for resolving relative paths
for _, segment := range mediaPl.Segments {
if segment != nil {
// Rewrite Segment URI
if !isAlreadyProxied(segment.URI) {
if segment.URI != "" {
if !strings.HasPrefix(segment.URI, "http") {
segment.URI = resolveURL(baseURL, segment.URI)
}
segment.URI = rewriteProxyURL(segment.URI, headerMap)
needsRewrite = true
}
}
// Rewrite encryption key URIs
for i, key := range segment.Keys {
if key.URI != "" {
if !isAlreadyProxied(key.URI) {
keyURI := key.URI
if !strings.HasPrefix(key.URI, "http") {
keyURI = resolveURL(baseURL, key.URI)
}
segment.Keys[i].URI = rewriteProxyURL(keyURI, headerMap)
needsRewrite = true
}
}
}
}
}
// Rewrite playlist-level encryption key URIs
for i, key := range mediaPl.Keys {
if key.URI != "" {
if !isAlreadyProxied(key.URI) {
keyURI := key.URI
if !strings.HasPrefix(key.URI, "http") {
keyURI = resolveURL(baseURL, key.URI)
}
mediaPl.Keys[i].URI = rewriteProxyURL(keyURI, headerMap)
needsRewrite = true
}
}
}
// Encode the modified media playlist
buffer := mediaPl.Encode()
modifiedPlaylistBytes = buffer.Bytes()
} else if listType == m3u8.MASTER {
// Rewrite URIs in Master playlists
masterPl := playlist.(*m3u8.MasterPlaylist)
baseURL, _ := url2.Parse(url) // Base URL for resolving relative paths
for _, variant := range masterPl.Variants {
if variant != nil && variant.URI != "" {
if !isAlreadyProxied(variant.URI) {
variantURI := variant.URI
if !strings.HasPrefix(variant.URI, "http") {
variantURI = resolveURL(baseURL, variant.URI)
}
variant.URI = rewriteProxyURL(variantURI, headerMap)
needsRewrite = true
}
}
// Handle alternative media groups (audio, subtitles, etc.) for each variant
if variant != nil {
for _, alternative := range variant.Alternatives {
if alternative != nil && alternative.URI != "" {
if !isAlreadyProxied(alternative.URI) {
alternativeURI := alternative.URI
if !strings.HasPrefix(alternative.URI, "http") {
alternativeURI = resolveURL(baseURL, alternative.URI)
}
alternative.URI = rewriteProxyURL(alternativeURI, headerMap)
needsRewrite = true
}
}
}
}
}
allAlternatives := masterPl.GetAllAlternatives()
for _, alternative := range allAlternatives {
if alternative != nil && alternative.URI != "" {
if !isAlreadyProxied(alternative.URI) {
alternativeURI := alternative.URI
if !strings.HasPrefix(alternative.URI, "http") {
alternativeURI = resolveURL(baseURL, alternative.URI)
}
alternative.URI = rewriteProxyURL(alternativeURI, headerMap)
needsRewrite = true
}
}
}
// Rewrite session key URIs
for i, sessionKey := range masterPl.SessionKeys {
if sessionKey.URI != "" {
if !isAlreadyProxied(sessionKey.URI) {
sessionKeyURI := sessionKey.URI
if !strings.HasPrefix(sessionKey.URI, "http") {
sessionKeyURI = resolveURL(baseURL, sessionKey.URI)
}
masterPl.SessionKeys[i].URI = rewriteProxyURL(sessionKeyURI, headerMap)
needsRewrite = true
}
}
}
// Encode the modified master playlist
buffer := masterPl.Encode()
modifiedPlaylistBytes = buffer.Bytes()
} else {
// Unknown type, pass through
modifiedPlaylistBytes = bodyBytes
}
// Set headers *after* potential modification
contentType := "application/vnd.apple.mpegurl"
c.Response().Header().Set(echo.HeaderContentType, contentType)
// Set Content-Length based on the *modified* playlist
c.Response().Header().Set(echo.HeaderContentLength, strconv.Itoa(len(modifiedPlaylistBytes)))
// Set Cache-Control headers appropriate for playlists (often no-cache for live)
if resp.Header.Get("Cache-Control") == "" {
c.Response().Header().Set("Cache-Control", "no-cache")
}
log.Debug().Bool("rewritten", needsRewrite).Str("url", url).Msg("proxy: Sending modified HLS playlist")
return c.Blob(resp.StatusCode, contentType, modifiedPlaylistBytes)
}
func resolveURL(base *url2.URL, relativeURI string) string {
if base == nil {
return relativeURI // Cannot resolve without a base
}
relativeURL, err := url2.Parse(relativeURI)
if err != nil {
return relativeURI // Invalid relative URI
}
return base.ResolveReference(relativeURL).String()
}
func rewriteProxyURL(targetMediaURL string, headerMap map[string]string) string {
proxyURL := "/api/v1/proxy?url=" + url2.QueryEscape(targetMediaURL)
if len(headerMap) > 0 {
headersStrB, err := json.Marshal(headerMap)
// Marshalling a map[string]string should not fail; on error the headers are simply omitted
if err == nil && len(headersStrB) > 2 { // len > 2 skips the empty map "{}"
proxyURL += "&headers=" + url2.QueryEscape(string(headersStrB))
}
}
return proxyURL
}
func isAlreadyProxied(url string) bool {
// Check if the URL contains the proxy pattern
return strings.Contains(url, "/api/v1/proxy?url=") || strings.Contains(url, url2.QueryEscape("/api/v1/proxy?url="))
}
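// Rewriting sketch (illustrative values): a segment URI becomes a relative
// proxy URL carrying the upstream headers as an escaped JSON query parameter:
//
//	rewriteProxyURL("https://cdn.example.com/seg1.ts", map[string]string{"Referer": "https://example.com"})
//	// => "/api/v1/proxy?url=https%3A%2F%2Fcdn.example.com%2Fseg1.ts&headers=%7B%22Referer%22...%7D"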

View File

@@ -0,0 +1,12 @@
package util
import "regexp"
// MatchesRegex checks if a string matches a regex pattern
func MatchesRegex(str, pattern string) (bool, error) {
re, err := regexp.Compile(pattern)
if err != nil {
return false, err
}
return re.MatchString(str), nil
}

View File

@@ -0,0 +1,199 @@
package result
import (
"container/list"
"sync"
"time"
)
// BoundedCache implements an LRU cache with a maximum capacity
type BoundedCache[K comparable, V any] struct {
mu sync.RWMutex
capacity int
items map[K]*list.Element
order *list.List
}
type boundedCacheItem[K comparable, V any] struct {
key K
value V
expiration time.Time
}
// NewBoundedCache creates a new bounded cache with the specified capacity
func NewBoundedCache[K comparable, V any](capacity int) *BoundedCache[K, V] {
return &BoundedCache[K, V]{
capacity: capacity,
items: make(map[K]*list.Element),
order: list.New(),
}
}
// Set adds or updates an item in the cache with a default TTL
func (c *BoundedCache[K, V]) Set(key K, value V) {
c.SetT(key, value, time.Hour) // Default TTL of 1 hour
}
// SetT adds or updates an item in the cache with a specific TTL
func (c *BoundedCache[K, V]) SetT(key K, value V, ttl time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
expiration := time.Now().Add(ttl)
item := &boundedCacheItem[K, V]{
key: key,
value: value,
expiration: expiration,
}
// If key already exists, update it and move to front
if elem, exists := c.items[key]; exists {
elem.Value = item
c.order.MoveToFront(elem)
return
}
// If at capacity, remove oldest item
if len(c.items) >= c.capacity {
c.evictOldest()
}
// Add new item to front
elem := c.order.PushFront(item)
c.items[key] = elem
// Set up expiration cleanup; re-check the expiration so an entry that was
// refreshed with a newer TTL is not evicted early by an older goroutine
go func() {
<-time.After(ttl)
c.mu.Lock()
defer c.mu.Unlock()
if elem, exists := c.items[key]; exists {
if time.Now().After(elem.Value.(*boundedCacheItem[K, V]).expiration) {
c.delete(key)
}
}
}()
}
// Get retrieves an item from the cache and marks it as recently used
func (c *BoundedCache[K, V]) Get(key K) (V, bool) {
c.mu.Lock()
defer c.mu.Unlock()
var zero V
elem, exists := c.items[key]
if !exists {
return zero, false
}
item := elem.Value.(*boundedCacheItem[K, V])
// Check if expired
if time.Now().After(item.expiration) {
c.delete(key)
return zero, false
}
// Move to front (mark as recently used)
c.order.MoveToFront(elem)
return item.value, true
}
// Has checks if a key exists in the cache without updating access time
func (c *BoundedCache[K, V]) Has(key K) bool {
c.mu.RLock()
defer c.mu.RUnlock()
elem, exists := c.items[key]
if !exists {
return false
}
item := elem.Value.(*boundedCacheItem[K, V])
if time.Now().After(item.expiration) {
return false
}
return true
}
// Delete removes an item from the cache
func (c *BoundedCache[K, V]) Delete(key K) {
c.mu.Lock()
defer c.mu.Unlock()
c.delete(key)
}
// delete removes an item from the cache (internal, assumes lock is held)
func (c *BoundedCache[K, V]) delete(key K) {
if elem, exists := c.items[key]; exists {
c.order.Remove(elem)
delete(c.items, key)
}
}
// Clear removes all items from the cache
func (c *BoundedCache[K, V]) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
c.items = make(map[K]*list.Element)
c.order.Init()
}
// GetOrSet retrieves an item or creates it if it doesn't exist
func (c *BoundedCache[K, V]) GetOrSet(key K, createFunc func() (V, error)) (V, error) {
// Try to get the value first
value, ok := c.Get(key)
if ok {
return value, nil
}
// Create new value
newValue, err := createFunc()
if err != nil {
return newValue, err
}
// Set the new value
c.Set(key, newValue)
return newValue, nil
}
// Size returns the current number of items in the cache
func (c *BoundedCache[K, V]) Size() int {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.items)
}
// Capacity returns the maximum capacity of the cache
func (c *BoundedCache[K, V]) Capacity() int {
return c.capacity
}
// evictOldest removes the least recently used item (assumes lock is held)
func (c *BoundedCache[K, V]) evictOldest() {
if c.order.Len() == 0 {
return
}
elem := c.order.Back()
if elem != nil {
item := elem.Value.(*boundedCacheItem[K, V])
c.delete(item.key)
}
}
// Range iterates over all items in the cache
func (c *BoundedCache[K, V]) Range(callback func(key K, value V) bool) {
c.mu.RLock()
defer c.mu.RUnlock()
for elem := c.order.Front(); elem != nil; elem = elem.Next() {
item := elem.Value.(*boundedCacheItem[K, V])
// Skip expired items
if time.Now().After(item.expiration) {
continue
}
if !callback(item.key, item.value) {
break
}
}
}
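// LRU behavior sketch (illustrative, not from the source):
//
//	c := NewBoundedCache[string, int](2)
//	c.Set("a", 1)
//	c.Set("b", 2)
//	c.Get("a")          // marks "a" as recently used
//	c.Set("c", 3)       // evicts "b", the least recently used entry
//	_, ok := c.Get("b") // ok == false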

View File

@@ -0,0 +1,101 @@
package result
import (
"seanime/internal/constants"
"seanime/internal/util"
"time"
)
type Cache[K interface{}, V any] struct {
store util.RWMutexMap
}
type cacheItem[K interface{}, V any] struct {
value V
expiration time.Time
}
func NewCache[K interface{}, V any]() *Cache[K, V] {
return &Cache[K, V]{}
}
func (c *Cache[K, V]) Set(key K, value V) {
ttl := constants.GcTime
c.store.Store(key, &cacheItem[K, V]{value, time.Now().Add(ttl)})
go func() {
<-time.After(ttl)
// Only evict if the entry has actually expired (it may have been refreshed)
if item, ok := c.store.Load(key); ok {
if time.Now().After(item.(*cacheItem[K, V]).expiration) {
c.Delete(key)
}
}
}()
}
func (c *Cache[K, V]) SetT(key K, value V, ttl time.Duration) {
c.store.Store(key, &cacheItem[K, V]{value, time.Now().Add(ttl)})
go func() {
<-time.After(ttl)
// Only evict if the entry has actually expired (it may have been refreshed)
if item, ok := c.store.Load(key); ok {
if time.Now().After(item.(*cacheItem[K, V]).expiration) {
c.Delete(key)
}
}
}()
}
func (c *Cache[K, V]) Get(key K) (V, bool) {
var zero V
item, ok := c.store.Load(key)
if !ok {
return zero, false
}
ci := item.(*cacheItem[K, V])
if time.Now().After(ci.expiration) {
c.Delete(key)
return zero, false
}
return ci.value, true
}
func (c *Cache[K, V]) Pop() (K, V, bool) {
var key K
var value V
var ok bool
c.store.Range(func(k, v interface{}) bool {
key = k.(K)
value = v.(*cacheItem[K, V]).value
ok = true
c.store.Delete(k)
return false
})
return key, value, ok
}
func (c *Cache[K, V]) Has(key K) bool {
_, ok := c.store.Load(key)
return ok
}
func (c *Cache[K, V]) GetOrSet(key K, createFunc func() (V, error)) (V, error) {
value, ok := c.Get(key)
if ok {
return value, nil
}
newValue, err := createFunc()
if err != nil {
return newValue, err
}
c.Set(key, newValue)
return newValue, nil
}
func (c *Cache[K, V]) Delete(key K) {
c.store.Delete(key)
}
func (c *Cache[K, V]) Clear() {
c.store.Range(func(key interface{}, value interface{}) bool {
c.store.Delete(key)
return true
})
}
func (c *Cache[K, V]) Range(callback func(key K, value V) bool) {
c.store.Range(func(key, value interface{}) bool {
ci := value.(*cacheItem[K, V])
return callback(key.(K), ci.value)
})
}
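// GetOrSet sketch (illustrative; mediaId and fetchMedia are hypothetical):
// the create function only runs on a cache miss.
//
//	media, err := cache.GetOrSet(mediaId, func() (*Media, error) {
//		return fetchMedia(mediaId)
//	})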

View File

@@ -0,0 +1,97 @@
package result
import (
"seanime/internal/util"
)
type Map[K interface{}, V any] struct {
store util.RWMutexMap
}
type mapItem[K interface{}, V any] struct {
value V
}
func NewResultMap[K interface{}, V any]() *Map[K, V] {
return &Map[K, V]{}
}
func (c *Map[K, V]) Set(key K, value V) {
c.store.Store(key, &mapItem[K, V]{value})
}
func (c *Map[K, V]) Get(key K) (V, bool) {
item, ok := c.store.Load(key)
if !ok {
var zero V
return zero, false
}
ci := item.(*mapItem[K, V])
return ci.value, true
}
func (c *Map[K, V]) Has(key K) bool {
_, ok := c.store.Load(key)
return ok
}
func (c *Map[K, V]) GetOrSet(key K, createFunc func() (V, error)) (V, error) {
value, ok := c.Get(key)
if ok {
return value, nil
}
newValue, err := createFunc()
if err != nil {
return newValue, err
}
c.Set(key, newValue)
return newValue, nil
}
func (c *Map[K, V]) Delete(key K) {
c.store.Delete(key)
}
func (c *Map[K, V]) Clear() {
c.store.Range(func(key interface{}, value interface{}) bool {
c.store.Delete(key)
return true
})
}
// ClearN clears the map and returns the number of items cleared
func (c *Map[K, V]) ClearN() int {
count := 0
c.store.Range(func(key interface{}, value interface{}) bool {
c.store.Delete(key)
count++
return true
})
return count
}
func (c *Map[K, V]) Range(callback func(key K, value V) bool) {
c.store.Range(func(key, value interface{}) bool {
ci := value.(*mapItem[K, V])
return callback(key.(K), ci.value)
})
}
func (c *Map[K, V]) Values() []V {
values := make([]V, 0)
c.store.Range(func(key, value interface{}) bool {
item := value.(*mapItem[K, V])
values = append(values, item.value)
return true
})
return values
}
func (c *Map[K, V]) Keys() []K {
keys := make([]K, 0)
c.store.Range(func(key, value interface{}) bool {
keys = append(keys, key.(K))
return true
})
return keys
}

View File

@@ -0,0 +1,126 @@
package util
import (
"crypto/tls"
"errors"
"net/http"
"time"
)
// Full credit to https://github.com/DaRealFreak/cloudflare-bp-go
// RetryConfig configures the retry behavior
type RetryConfig struct {
MaxRetries int
RetryDelay time.Duration
TimeoutOnly bool // Only retry on timeout errors
}
// cloudFlareRoundTripper is a custom round tripper that adds the validated request headers.
type cloudFlareRoundTripper struct {
inner http.RoundTripper
options Options
retry *RetryConfig
}
// Options the option to set custom headers
type Options struct {
AddMissingHeaders bool
Headers map[string]string
}
// AddCloudFlareByPass returns a round tripper adding the required headers for the CloudFlare checks
// and updates the TLS configuration of the passed inner transport.
func AddCloudFlareByPass(inner http.RoundTripper, options ...Options) http.RoundTripper {
if trans, ok := inner.(*http.Transport); ok {
trans.TLSClientConfig = getCloudFlareTLSConfiguration()
}
roundTripper := &cloudFlareRoundTripper{
inner: inner,
retry: &RetryConfig{
MaxRetries: 3,
RetryDelay: 2 * time.Second,
TimeoutOnly: true,
},
}
if len(options) > 0 {
roundTripper.options = options[0]
} else {
roundTripper.options = GetDefaultOptions()
}
return roundTripper
}
// RoundTrip adds the required request headers to pass CloudFlare checks.
func (ug *cloudFlareRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
var lastErr error
attempts := 0
for attempts <= ug.retry.MaxRetries {
// Add headers for this attempt
if ug.options.AddMissingHeaders {
for header, value := range ug.options.Headers {
if _, ok := r.Header[header]; !ok {
if header == "User-Agent" {
// Generate new random user agent for each attempt
r.Header.Set(header, GetRandomUserAgent())
} else {
r.Header.Set(header, value)
}
}
}
}
// Make the request
var resp *http.Response
var err error
// in case we don't have an inner transport layer from the round tripper
if ug.inner == nil {
resp, err = (&http.Transport{
TLSClientConfig: getCloudFlareTLSConfiguration(),
ForceAttemptHTTP2: false,
}).RoundTrip(r)
} else {
resp, err = ug.inner.RoundTrip(r)
}
if err == nil {
return resp, nil
}
// Retries are limited to timeouts by default; client-side timeouts surface
// as a net.Error whose Timeout() reports true (http.ErrHandlerTimeout is a
// server-side sentinel and would never match here)
var netErr net.Error
isTimeout := errors.As(err, &netErr) && netErr.Timeout()
if ug.retry.TimeoutOnly && !isTimeout {
return resp, err
}
lastErr = err
attempts++
// If we have more retries, wait before next attempt
if attempts <= ug.retry.MaxRetries {
time.Sleep(ug.retry.RetryDelay)
}
}
return nil, lastErr
}
// getCloudFlareTLSConfiguration returns an accepted client TLS configuration to not get detected by CloudFlare directly
// in case the configuration needs to be updated later on: https://wiki.mozilla.org/Security/Server_Side_TLS .
func getCloudFlareTLSConfiguration() *tls.Config {
return &tls.Config{
CurvePreferences: []tls.CurveID{tls.CurveP256, tls.CurveP384, tls.CurveP521, tls.X25519},
}
}
// GetDefaultOptions returns the options set by default
func GetDefaultOptions() Options {
return Options{
AddMissingHeaders: true,
Headers: map[string]string{
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8",
"Accept-Language": "en-US,en;q=0.5",
"User-Agent": GetRandomUserAgent(),
},
}
}
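// Wiring sketch (illustrative, not from the source): install the bypass
// round tripper on a standard http.Client.
//
//	client := &http.Client{
//		Transport: AddCloudFlareByPass(&http.Transport{}),
//		Timeout:   60 * time.Second,
//	}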

View File

@@ -0,0 +1,39 @@
package util
func SliceFrom[T any](slice []T, idx int) (ret []T, ok bool) {
if idx < 0 || idx >= len(slice) {
return []T{}, false
}
return slice[idx:], true
}
func SliceTo[T any](slice []T, idx int) (ret []T, ok bool) {
if idx < 0 || idx >= len(slice) {
return []T{}, false
}
return slice[:idx], true
}
func SliceStrFrom(slice string, idx int) (ret string, ok bool) {
if idx < 0 || idx >= len(slice) {
return "", false
}
return slice[idx:], true
}
func SliceStrTo(slice string, idx int) (ret string, ok bool) {
if idx < 0 || idx >= len(slice) {
return "", false
}
return slice[:idx], true
}
// Contains checks if a slice contains a specific item
func Contains[T comparable](slice []T, item T) bool {
for _, s := range slice {
if s == item {
return true
}
}
return false
}
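// Behavior sketch (illustrative): out-of-range indices report failure
// instead of panicking like a raw slice expression would.
//
//	SliceFrom([]int{1, 2, 3}, 1) // [2 3], true
//	SliceFrom([]int{1, 2, 3}, 5) // [], false
//	SliceStrTo("hello", 2)       // "he", true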

View File

@@ -0,0 +1,28 @@
package util
import (
"fmt"
"strings"
"github.com/kr/pretty"
)
func Spew(v interface{}) {
fmt.Printf("%# v\n", pretty.Formatter(v))
}
func SpewMany(v ...interface{}) {
fmt.Println("\nSpewing values:")
for _, val := range v {
Spew(val)
}
fmt.Println()
}
func SpewT(v interface{}) string {
return fmt.Sprintf("%# v\n", pretty.Formatter(v))
}
func InlineSpewT(v interface{}) string {
return strings.ReplaceAll(SpewT(v), "\n", "")
}

View File

@@ -0,0 +1,273 @@
package util
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"math/big"
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
"time"
"unicode"
"github.com/dustin/go-humanize"
)
// Bytes formats a byte count for humans, using SI units on macOS (matching Finder)
// and binary (IEC) units elsewhere.
func Bytes(size uint64) string {
switch runtime.GOOS {
case "darwin":
return humanize.Bytes(size)
default:
return humanize.IBytes(size)
}
}
// Decode base64-decodes s, returning an empty string on failure.
func Decode(s string) string {
decoded, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return ""
}
return string(decoded)
}
func GenerateCryptoID() string {
bytes := make([]byte, 16)
if _, err := rand.Read(bytes); err != nil {
panic(err)
}
return hex.EncodeToString(bytes)
}
func IsMostlyLatinString(str string) bool {
if len(str) == 0 {
return false
}
latinLength := 0
nonLatinLength := 0
for _, r := range str {
if isLatinRune(r) {
latinLength++
} else {
nonLatinLength++
}
}
return latinLength > nonLatinLength
}
func isLatinRune(r rune) bool {
return unicode.In(r, unicode.Latin)
}
// ToHumanReadableSpeed converts an integer representing bytes per second to a human-readable format using binary notation
func ToHumanReadableSpeed(bytesPerSecond int) string {
if bytesPerSecond <= 0 {
return `0 KiB/s`
}
const unit = 1024
if bytesPerSecond < unit {
return fmt.Sprintf("%d B/s", bytesPerSecond)
}
div, exp := int64(unit), 0
for n := int64(bytesPerSecond) / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB/s", float64(bytesPerSecond)/float64(div), "KMGTPE"[exp])
}
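// Worked example for ToHumanReadableSpeed (illustrative): for 3,250,000 B/s
// the loop settles on div = 1024*1024, so 3250000/1048576 ≈ 3.1 and the
// function returns "3.1 MiB/s".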
func StringSizeToBytes(str string) (int64, error) {
// Regular expression to extract size and unit
re := regexp.MustCompile(`(?i)^(\d+(\.\d+)?)\s*([KMGT]?i?B)$`)
match := re.FindStringSubmatch(strings.TrimSpace(str))
if match == nil {
return 0, fmt.Errorf("invalid size format: %s", str)
}
// Extract the numeric part and convert to float64
size, err := strconv.ParseFloat(match[1], 64)
if err != nil {
return 0, fmt.Errorf("failed to parse size: %s", err)
}
// Extract the unit and convert to lowercase
unit := strings.ToLower(match[3])
// Map units to their respective multipliers
unitMultipliers := map[string]int64{
"b": 1,
"bi": 1,
"kb": 1024,
"kib": 1024,
"mb": 1024 * 1024,
"mib": 1024 * 1024,
"gb": 1024 * 1024 * 1024,
"gib": 1024 * 1024 * 1024,
"tb": 1024 * 1024 * 1024 * 1024,
"tib": 1024 * 1024 * 1024 * 1024,
}
// Apply the multiplier based on the unit
multiplier, ok := unitMultipliers[unit]
if !ok {
return 0, fmt.Errorf("invalid unit: %s", unit)
}
// Calculate the total bytes
bytes := int64(size * float64(multiplier))
return bytes, nil
}
// FormatETA formats an ETA (in seconds) into a human-readable string
func FormatETA(etaInSeconds int) string {
const noETA = 8640000
if etaInSeconds == noETA {
return "No ETA"
}
etaDuration := time.Duration(etaInSeconds) * time.Second
hours := int(etaDuration.Hours())
minutes := int(etaDuration.Minutes()) % 60
seconds := int(etaDuration.Seconds()) % 60
switch {
case hours > 0:
return fmt.Sprintf("%d hours left", hours)
case minutes > 0:
return fmt.Sprintf("%d minutes left", minutes)
case seconds < 0:
return "No ETA"
default:
return fmt.Sprintf("%d seconds left", seconds)
}
}
func Pluralize(count int, singular, plural string) string {
if count == 1 {
return singular
}
return plural
}
// NormalizePath normalizes a path by converting it to lowercase and replacing backslashes with forward slashes
// Warning: Do not use the returned string for anything filesystem related, only for comparison
func NormalizePath(path string) (ret string) {
return strings.ToLower(filepath.ToSlash(path))
}
func Base64EncodeStr(str string) string {
return base64.StdEncoding.EncodeToString([]byte(str))
}
func Base64DecodeStr(str string) (string, error) {
decoded, err := base64.StdEncoding.DecodeString(str)
if err != nil {
return "", err
}
return string(decoded), nil
}
func IsBase64(s string) bool {
// 1. Check if string is empty
if len(s) == 0 {
return false
}
// 2. Check if length is valid (must be multiple of 4)
if len(s)%4 != 0 {
return false
}
// 3. Check for valid padding
padding := strings.Count(s, "=")
if padding > 2 {
return false
}
// 4. Check if padding is at the end only
if padding > 0 && !strings.HasSuffix(s, strings.Repeat("=", padding)) {
return false
}
// 5. Check if string contains only valid base64 characters
validChars := regexp.MustCompile("^[A-Za-z0-9+/]*=*$")
if !validChars.MatchString(s) {
return false
}
// 6. Try to decode - this is the final verification
_, err := base64.StdEncoding.DecodeString(s)
return err == nil
}
var snakecaseSplitRegex = regexp.MustCompile(`[\W_]+`)
func Snakecase(str string) string {
var result strings.Builder
// split at any non word character and underscore
words := snakecaseSplitRegex.Split(str, -1)
for _, word := range words {
if word == "" {
continue
}
if result.Len() > 0 {
result.WriteString("_")
}
for i, c := range word {
if unicode.IsUpper(c) && i > 0 &&
// is not a following uppercase character
!unicode.IsUpper(rune(word[i-1])) {
result.WriteString("_")
}
result.WriteRune(c)
}
}
return strings.ToLower(result.String())
}
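// Expected outputs for Snakecase, as an illustrative sketch:
//
//	Snakecase("HelloWorld")  // "hello_world"
//	Snakecase("APIResponse") // "apiresponse" (consecutive uppercase is not split)
//	Snakecase("foo-bar baz") // "foo_bar_baz"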
// RandomStringWithAlphabet generates a cryptographically random string
// with the specified length and alphabet.
//
// It panics if for some reason rand.Int returns a non-nil error.
func RandomStringWithAlphabet(length int, alphabet string) string {
b := make([]byte, length)
max := big.NewInt(int64(len(alphabet)))
for i := range b {
n, err := rand.Int(rand.Reader, max)
if err != nil {
panic(err)
}
b[i] = alphabet[n.Int64()]
}
return string(b)
}
func FileExt(str string) string {
lastDotIndex := strings.LastIndex(str, ".")
if lastDotIndex == -1 {
return ""
}
return str[lastDotIndex:]
}
func HashSHA256Hex(s string) string {
h := sha256.New()
h.Write([]byte(s))
return hex.EncodeToString(h.Sum(nil))
}

View File

@@ -0,0 +1,55 @@
package util
import "testing"
func TestSizeInBytes(t *testing.T) {
tests := []struct {
size string
bytes int64
}{
{"1.5 gb", 1610612736},
{"1.5 GB", 1610612736},
{"1.5 GiB", 1610612736},
{"385.5 mib", 404226048},
}
for _, test := range tests {
bytes, err := StringSizeToBytes(test.size)
if err != nil {
t.Errorf("Error converting size to bytes: %s", err)
}
if bytes != test.bytes {
t.Errorf("Expected %d bytes, got %d", test.bytes, bytes)
}
}
}
func TestIsBase64Encoded(t *testing.T) {
tests := []struct {
str string
isBase64 bool
}{
{"SGVsbG8gV29ybGQ=", true}, // "Hello World"
{"", false}, // Empty string
{"SGVsbG8gV29ybGQ", false}, // Invalid padding
{"SGVsbG8gV29ybGQ==", false}, // Invalid padding
{"SGVsbG8=V29ybGQ=", false}, // Padding in middle
{"SGVsbG8gV29ybGQ!!", false}, // Invalid characters
{"=SGVsbG8gV29ybGQ=", false}, // Padding at start
{"SGVsbG8gV29ybGQ===", false}, // Too much padding
{"A", false}, // Single character
{"AA==", true}, // Valid minimal string
{"YWJjZA==", true}, // "abcd"
}
for _, test := range tests {
if IsBase64(test.str) != test.isBase64 {
t.Errorf("Expected %t for %s, got %t", test.isBase64, test.str, IsBase64(test.str))
}
}
}

View File

@@ -0,0 +1,494 @@
package torrentutil
import (
"fmt"
"io"
"sync"
"time"
"github.com/anacrolix/torrent"
"github.com/rs/zerolog"
)
// +-----------------------+
// + anacrolix/torrent +
// +-----------------------+
const (
piecesForNow = int64(5)
piecesForHighBefore = int64(2)
piecesForNext = int64(30)
piecesForReadahead = int64(30)
)
// readerInfo tracks information about an active reader
type readerInfo struct {
id string
position int64
lastAccess time.Time
}
// priorityManager manages piece priorities for multiple readers on the same file
type priorityManager struct {
mu sync.RWMutex
readers map[string]*readerInfo
torrent *torrent.Torrent
file *torrent.File
logger *zerolog.Logger
}
// global map to track priority managers per torrent+file combination
var (
priorityManagers = make(map[string]*priorityManager)
priorityManagersMu sync.RWMutex
)
// getPriorityManager gets or creates a priority manager for a torrent+file combination
func getPriorityManager(t *torrent.Torrent, file *torrent.File, logger *zerolog.Logger) *priorityManager {
key := fmt.Sprintf("%s:%s", t.InfoHash().String(), file.Path())
priorityManagersMu.Lock()
defer priorityManagersMu.Unlock()
if pm, exists := priorityManagers[key]; exists {
return pm
}
pm := &priorityManager{
readers: make(map[string]*readerInfo),
torrent: t,
file: file,
logger: logger,
}
priorityManagers[key] = pm
// Start a cleanup goroutine for this manager so stale readers are
// reclaimed for every manager, not only the first one created
go pm.cleanupStaleReaders()
return pm
}
// registerReader registers a new reader with the priority manager
func (pm *priorityManager) registerReader(readerID string, position int64) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.readers[readerID] = &readerInfo{
id: readerID,
position: position,
lastAccess: time.Now(),
}
pm.updatePriorities()
}
// updateReaderPosition updates a reader's position and recalculates priorities
func (pm *priorityManager) updateReaderPosition(readerID string, position int64) {
pm.mu.Lock()
defer pm.mu.Unlock()
if reader, exists := pm.readers[readerID]; exists {
reader.position = position
reader.lastAccess = time.Now()
pm.updatePriorities()
}
}
// unregisterReader removes a reader from tracking
func (pm *priorityManager) unregisterReader(readerID string) {
pm.mu.Lock()
defer pm.mu.Unlock()
delete(pm.readers, readerID)
// If no more readers, clean up and recalculate priorities
if len(pm.readers) == 0 {
pm.resetAllPriorities()
} else {
pm.updatePriorities()
}
}
// updatePriorities recalculates piece priorities based on all active readers
func (pm *priorityManager) updatePriorities() {
if pm.torrent == nil || pm.file == nil || pm.torrent.Info() == nil {
return
}
t := pm.torrent
file := pm.file
pieceLength := t.Info().PieceLength
if pieceLength == 0 {
if pm.logger != nil {
pm.logger.Warn().Msg("torrentutil: piece length is zero, cannot prioritize")
}
return
}
numTorrentPieces := int64(t.NumPieces())
if numTorrentPieces == 0 {
if pm.logger != nil {
pm.logger.Warn().Msg("torrentutil: torrent has zero pieces, cannot prioritize")
}
return
}
// Calculate file piece range
fileFirstPieceIdx := file.Offset() / pieceLength
fileLastPieceIdx := (file.Offset() + file.Length() - 1) / pieceLength
// Collect all needed piece ranges from all active readers
neededPieces := make(map[int64]torrent.PiecePriority)
for _, reader := range pm.readers {
position := reader.position
// Remove 1MB from the position (for subtitle cluster)
position -= 1 * 1024 * 1024
if position < 0 {
position = 0
}
currentGlobalSeekPieceIdx := (file.Offset() + position) / pieceLength
// Pieces needed NOW (immediate)
for i := int64(0); i < piecesForNow; i++ {
idx := currentGlobalSeekPieceIdx + i
if idx >= fileFirstPieceIdx && idx <= fileLastPieceIdx && idx < numTorrentPieces {
if current, exists := neededPieces[idx]; !exists || current < torrent.PiecePriorityNow {
neededPieces[idx] = torrent.PiecePriorityNow
}
}
}
// Pieces needed HIGH (before current position for rewinds)
for i := int64(1); i <= piecesForHighBefore; i++ {
idx := currentGlobalSeekPieceIdx - i
if idx >= fileFirstPieceIdx && idx <= fileLastPieceIdx && idx >= 0 {
if current, exists := neededPieces[idx]; !exists || current < torrent.PiecePriorityHigh {
neededPieces[idx] = torrent.PiecePriorityHigh
}
}
}
// Pieces needed NEXT (immediate readahead)
nextStartIdx := currentGlobalSeekPieceIdx + piecesForNow
for i := int64(0); i < piecesForNext; i++ {
idx := nextStartIdx + i
if idx >= fileFirstPieceIdx && idx <= fileLastPieceIdx && idx < numTorrentPieces {
if current, exists := neededPieces[idx]; !exists || current < torrent.PiecePriorityNext {
neededPieces[idx] = torrent.PiecePriorityNext
}
}
}
// Pieces needed for READAHEAD (further readahead)
readaheadStartIdx := nextStartIdx + piecesForNext
for i := int64(0); i < piecesForReadahead; i++ {
idx := readaheadStartIdx + i
if idx >= fileFirstPieceIdx && idx <= fileLastPieceIdx && idx < numTorrentPieces {
if current, exists := neededPieces[idx]; !exists || current < torrent.PiecePriorityReadahead {
neededPieces[idx] = torrent.PiecePriorityReadahead
}
}
}
}
// Reset pieces that are no longer needed by any reader
for idx := fileFirstPieceIdx; idx <= fileLastPieceIdx; idx++ {
if idx < 0 || idx >= numTorrentPieces {
continue
}
piece := t.Piece(int(idx))
currentPriority := piece.State().Priority
if neededPriority, needed := neededPieces[idx]; needed {
// Set to the highest priority needed by any reader
if currentPriority != neededPriority {
piece.SetPriority(neededPriority)
}
} else {
// Only reset to normal if not completely unwanted and not already at highest priority
if currentPriority != torrent.PiecePriorityNone && currentPriority != torrent.PiecePriorityNow {
piece.SetPriority(torrent.PiecePriorityNormal)
}
}
}
if pm.logger != nil {
pm.logger.Debug().Msgf("torrentutil: Updated priorities for %d readers, %d pieces prioritized", len(pm.readers), len(neededPieces))
}
}
// resetAllPriorities resets all file pieces to normal priority
func (pm *priorityManager) resetAllPriorities() {
if pm.torrent == nil || pm.file == nil || pm.torrent.Info() == nil {
return
}
t := pm.torrent
file := pm.file
pieceLength := t.Info().PieceLength
if pieceLength == 0 {
return
}
numTorrentPieces := int64(t.NumPieces())
fileFirstPieceIdx := file.Offset() / pieceLength
fileLastPieceIdx := (file.Offset() + file.Length() - 1) / pieceLength
for idx := fileFirstPieceIdx; idx <= fileLastPieceIdx; idx++ {
if idx >= 0 && idx < numTorrentPieces {
piece := t.Piece(int(idx))
if piece.State().Priority != torrent.PiecePriorityNone {
piece.SetPriority(torrent.PiecePriorityNormal)
}
}
}
}
// cleanupStaleReaders periodically removes readers that haven't been accessed recently
func (pm *priorityManager) cleanupStaleReaders() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for range ticker.C {
pm.mu.Lock()
cutoff := time.Now().Add(-2 * time.Minute)
for id, reader := range pm.readers {
if reader.lastAccess.Before(cutoff) {
delete(pm.readers, id)
if pm.logger != nil {
pm.logger.Debug().Msgf("torrentutil: Cleaned up stale reader %s", id)
}
}
}
// Update priorities after cleanup
if len(pm.readers) > 0 {
pm.updatePriorities()
}
pm.mu.Unlock()
}
}
// ReadSeeker implements io.ReadSeekCloser for a torrent file being streamed.
// It allows dynamic prioritization of pieces when seeking, optimized for streaming
// and supports multiple concurrent readers on the same file.
type ReadSeeker struct {
id string
torrent *torrent.Torrent
file *torrent.File
reader torrent.Reader
priorityManager *priorityManager
logger *zerolog.Logger
}
var _ io.ReadSeekCloser = &ReadSeeker{}
func NewReadSeeker(t *torrent.Torrent, file *torrent.File, logger ...*zerolog.Logger) io.ReadSeekCloser {
tr := file.NewReader()
tr.SetResponsive()
// Read ahead 5MB for better streaming performance
// DEVNOTE: Not sure if dynamic prioritization overwrites this but whatever
tr.SetReadahead(5 * 1024 * 1024)
var loggerPtr *zerolog.Logger
if len(logger) > 0 {
loggerPtr = logger[0]
}
pm := getPriorityManager(t, file, loggerPtr)
rs := &ReadSeeker{
// UnixNano is unique enough per process; reading len(pm.readers) here would race with the manager's lock
id: fmt.Sprintf("reader_%d", time.Now().UnixNano()),
torrent: t,
file: file,
reader: tr,
priorityManager: pm,
logger: loggerPtr,
}
// Register this reader with the priority manager
pm.registerReader(rs.id, 0)
return rs
}
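// Streaming sketch for NewReadSeeker (illustrative; t and f come from an
// anacrolix/torrent client elsewhere):
//
//	rs := NewReadSeeker(t, f, &logger)
//	defer rs.Close() // unregisters the reader and releases its priorities
//	http.ServeContent(w, r, f.DisplayPath(), time.Time{}, rs)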
func (rs *ReadSeeker) Read(p []byte) (n int, err error) {
return rs.reader.Read(p)
}
func (rs *ReadSeeker) Seek(offset int64, whence int) (int64, error) {
newOffset, err := rs.reader.Seek(offset, whence)
if err != nil {
if rs.logger != nil {
rs.logger.Error().Err(err).Int64("offset", offset).Int("whence", whence).Msg("torrentutil: ReadSeeker seek error")
}
return newOffset, err
}
// Update this reader's position in the priority manager
rs.priorityManager.updateReaderPosition(rs.id, newOffset)
return newOffset, nil
}
// Close closes the underlying torrent file reader and unregisters from priority manager.
// This makes ReadSeeker implement io.ReadSeekCloser.
func (rs *ReadSeeker) Close() error {
// Unregister from priority manager
rs.priorityManager.unregisterReader(rs.id)
if rs.reader != nil {
return rs.reader.Close()
}
return nil
}
// PrioritizeDownloadPieces sets high priority for the first 3% of pieces and the last few pieces to ensure faster loading.
func PrioritizeDownloadPieces(t *torrent.Torrent, file *torrent.File, logger *zerolog.Logger) {
// Calculate file's pieces
firstPieceIdx := file.Offset() * int64(t.NumPieces()) / t.Length()
endPieceIdx := (file.Offset() + file.Length()) * int64(t.NumPieces()) / t.Length()
// Prioritize more pieces at the beginning for faster initial loading (3% for beginning)
numPiecesForStart := (endPieceIdx - firstPieceIdx + 1) * 3 / 100
if logger != nil {
logger.Debug().Msgf("torrentuil: Setting high priority for first 3%% - pieces %d to %d (total %d)",
firstPieceIdx, firstPieceIdx+numPiecesForStart, numPiecesForStart)
}
for idx := firstPieceIdx; idx <= firstPieceIdx+numPiecesForStart; idx++ {
t.Piece(int(idx)).SetPriority(torrent.PiecePriorityNow)
}
// Also prioritize the last few pieces
numPiecesForEnd := (endPieceIdx - firstPieceIdx + 1) * 1 / 100
if logger != nil {
logger.Debug().Msgf("torrentuil: Setting priority for last pieces %d to %d (total %d)",
endPieceIdx-numPiecesForEnd, endPieceIdx, numPiecesForEnd)
}
for idx := endPieceIdx - numPiecesForEnd; idx <= endPieceIdx; idx++ {
if idx >= 0 && int(idx) < t.NumPieces() {
t.Piece(int(idx)).SetPriority(torrent.PiecePriorityNow)
}
}
}
// PrioritizeRangeRequestPieces attempts to prioritize pieces needed for the range request.
func PrioritizeRangeRequestPieces(rangeHeader string, t *torrent.Torrent, file *torrent.File, logger *zerolog.Logger) {
// Parse the range header (format: bytes=START-END)
var start int64
_, _ = fmt.Sscanf(rangeHeader, "bytes=%d-", &start)
if start >= 0 {
// Calculate file's pieces range
fileOffset := file.Offset()
fileLength := file.Length()
// Calculate the total range of pieces for this file
firstFilePieceIdx := fileOffset * int64(t.NumPieces()) / t.Length()
endFilePieceIdx := (fileOffset + fileLength) * int64(t.NumPieces()) / t.Length()
// Calculate the piece index for this seek offset with small padding
// Subtract a small amount to ensure we don't miss the beginning of a needed piece
seekPosition := start
if seekPosition >= 1024*1024 { // If we're at least 1MB in, add some padding
seekPosition -= 1024 * 512 // Subtract 512KB to ensure we get the right piece
}
seekPieceIdx := (fileOffset + seekPosition) * int64(t.NumPieces()) / t.Length()
// Prioritize the next several pieces from this point
// This is especially important for seeking
numPiecesToPrioritize := int64(10) // Prioritize next 10 pieces, adjust as needed
if seekPieceIdx+numPiecesToPrioritize > endFilePieceIdx {
numPiecesToPrioritize = endFilePieceIdx - seekPieceIdx
}
if logger != nil {
logger.Debug().Msgf("torrentutil: Prioritizing range request pieces %d to %d",
seekPieceIdx, seekPieceIdx+numPiecesToPrioritize)
}
// Set normal priority for pieces far from our current position
// This allows background downloading while still prioritizing the seek point
for idx := firstFilePieceIdx; idx <= endFilePieceIdx; idx++ {
if idx >= 0 && int(idx) < t.NumPieces() {
// Don't touch the beginning pieces which should maintain their high priority
// for the next potential restart, and don't touch pieces near our seek point
if (idx > firstFilePieceIdx+100 && idx < seekPieceIdx-100) ||
idx > seekPieceIdx+numPiecesToPrioritize+100 {
// Set to normal priority - allow background downloading
t.Piece(int(idx)).SetPriority(torrent.PiecePriorityNormal)
}
}
}
// Now set the highest priority for the pieces we need right now
for idx := seekPieceIdx; idx < seekPieceIdx+numPiecesToPrioritize; idx++ {
if idx >= 0 && int(idx) < t.NumPieces() {
t.Piece(int(idx)).SetPriority(torrent.PiecePriorityNow)
}
}
// Also prioritize a small buffer before the seek point to handle small rewinds
// This is useful for MPV's default rewind behavior
bufferBeforeCount := int64(5) // 5 pieces buffer before seek point
if seekPieceIdx > firstFilePieceIdx+bufferBeforeCount {
for idx := seekPieceIdx - bufferBeforeCount; idx < seekPieceIdx; idx++ {
if idx >= 0 && int(idx) < t.NumPieces() {
t.Piece(int(idx)).SetPriority(torrent.PiecePriorityHigh)
}
}
}
// Also prioritize the next readahead segment after our immediate needs
// This helps prepare for continued playback
nextReadStart := seekPieceIdx + numPiecesToPrioritize
nextReadCount := int64(100) // 100 additional pieces for nextRead
if nextReadStart+nextReadCount > endFilePieceIdx {
nextReadCount = endFilePieceIdx - nextReadStart
}
if nextReadCount > 0 {
if logger != nil {
logger.Debug().Msgf("torrentutil: Setting next priority for pieces %d to %d",
nextReadStart, nextReadStart+nextReadCount)
}
for idx := nextReadStart; idx < nextReadStart+nextReadCount; idx++ {
if idx >= 0 && int(idx) < t.NumPieces() {
t.Piece(int(idx)).SetPriority(torrent.PiecePriorityNext)
}
}
}
// Finally, prioritize a further readahead window placed after the "next"
// segment, so it does not overwrite the priorities set just above
readAheadStart := nextReadStart + nextReadCount
readAheadCount := int64(100)
if readAheadStart+readAheadCount > endFilePieceIdx {
readAheadCount = endFilePieceIdx - readAheadStart
}
if readAheadCount > 0 {
if logger != nil {
logger.Debug().Msgf("torrentutil: Setting read ahead priority for pieces %d to %d",
readAheadStart, readAheadStart+readAheadCount)
}
for idx := readAheadStart; idx < readAheadStart+readAheadCount; idx++ {
if idx >= 0 && int(idx) < t.NumPieces() {
t.Piece(int(idx)).SetPriority(torrent.PiecePriorityReadahead)
}
}
}
}
}

View File

@@ -0,0 +1,77 @@
package util
import (
"bufio"
"encoding/json"
"math/rand"
"net/http"
"sync"
"time"
"github.com/rs/zerolog/log"
)
var (
userAgentList []string
uaMu sync.RWMutex
)
func init() {
go func() {
defer func() {
if r := recover(); r != nil {
log.Warn().Msgf("util: Failed to get online user agents: %v", r)
}
}()
agents, err := getOnlineUserAgents()
if err != nil {
log.Warn().Err(err).Msg("util: Failed to get online user agents")
return
}
uaMu.Lock()
userAgentList = agents
uaMu.Unlock()
}()
}
func getOnlineUserAgents() ([]string, error) {
link := "https://raw.githubusercontent.com/fake-useragent/fake-useragent/refs/heads/main/src/fake_useragent/data/browsers.jsonl"
client := &http.Client{
Timeout: 10 * time.Second,
}
response, err := client.Get(link)
if err != nil {
return nil, err
}
defer response.Body.Close()
var agents []string
type UserAgent struct {
UserAgent string `json:"useragent"`
}
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
var ua UserAgent
if err := json.Unmarshal([]byte(line), &ua); err != nil {
return nil, err
}
agents = append(agents, ua.UserAgent)
}
if err := scanner.Err(); err != nil {
return nil, err
}
return agents, nil
}
func GetRandomUserAgent() string {
uaMu.RLock()
defer uaMu.RUnlock()
if len(userAgentList) > 0 {
return userAgentList[rand.Intn(len(userAgentList))]
}
return UserAgentList[rand.Intn(len(UserAgentList))]
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,59 @@
package util
import (
"bufio"
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
)
func TestGetOnlineUserAgents(t *testing.T) {
if testing.Short() {
t.Skip("skipping network-dependent test in short mode")
}
userAgents, err := getOnlineUserAgents()
if err != nil {
t.Fatalf("Failed to get online user agents: %v", err)
}
t.Logf("Online user agents: %v", userAgents)
}
func TestTransformUserAgentJsonlToSliceFile(t *testing.T) {
jsonlFilePath := filepath.Join("data", "user_agents.jsonl")
jsonlFile, err := os.Open(jsonlFilePath)
if err != nil {
t.Fatalf("Failed to open JSONL file: %v", err)
}
defer jsonlFile.Close()
sliceFilePath := filepath.Join("user_agent_list.go")
sliceFile, err := os.Create(sliceFilePath)
if err != nil {
t.Fatalf("Failed to create slice file: %v", err)
}
defer sliceFile.Close()
sliceFile.WriteString("package util\n\nvar UserAgentList = []string{\n")
type UserAgent struct {
UserAgent string `json:"useragent"`
}
scanner := bufio.NewScanner(jsonlFile)
for scanner.Scan() {
line := scanner.Text()
var ua UserAgent
if err := json.Unmarshal([]byte(line), &ua); err != nil {
t.Fatalf("Failed to unmarshal line: %v", err)
}
sliceFile.WriteString(fmt.Sprintf("\t%q,\n", ua.UserAgent))
}
sliceFile.WriteString("}\n")
if err := scanner.Err(); err != nil {
t.Fatalf("Failed to read JSONL file: %v", err)
}
t.Logf("User agent list generated successfully: %s", sliceFilePath)
}
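
// The generated user_agent_list.go is expected to look like this
// (illustrative entry):
//
//    package util
//
//    var UserAgentList = []string{
//        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) ...",
//    }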

View File

@@ -0,0 +1,48 @@
package util
import "github.com/mileusna/useragent"
const (
PlatformAndroid = "android"
PlatformIOS = "ios"
PlatformLinux = "linux"
PlatformMac = "mac"
PlatformWindows = "windows"
PlatformChromeOS = "chromeos"
)
const (
DeviceDesktop = "desktop"
DeviceMobile = "mobile"
DeviceTablet = "tablet"
)
type ClientInfo struct {
Device string
Platform string
}
func GetClientInfo(userAgent string) ClientInfo {
ua := useragent.Parse(userAgent)
var device string
if ua.Mobile {
device = DeviceMobile
} else if ua.Tablet {
device = DeviceTablet
} else {
device = DeviceDesktop
}
// Normalize the parsed OS name to the platform constants above;
// fall back to the raw parser output (or "-") for unrecognized platforms
var platform string
switch strings.ToLower(ua.OS) {
case "windows":
platform = PlatformWindows
case "macos", "mac os x":
platform = PlatformMac
case "linux":
platform = PlatformLinux
case "android":
platform = PlatformAndroid
case "ios":
platform = PlatformIOS
case "chromeos", "chrome os", "cros":
platform = PlatformChromeOS
default:
platform = ua.OS
if platform == "" {
platform = "-"
}
}
return ClientInfo{
Device: device,
Platform: platform,
}
}
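
// Usage sketch (assumed user agent string, truncated for brevity):
//
//    info := GetClientInfo("Mozilla/5.0 (Windows NT 10.0; Win64; x64) ...")
//    // info.Device == DeviceDesktop, info.Platform == PlatformWindows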

View File

@@ -0,0 +1,78 @@
package util
import (
"strconv"
"strings"
"github.com/Masterminds/semver/v3"
)
// IsValidVersion reports whether version is a plain "major.minor.patch" string.
// Prerelease or build suffixes (e.g. "1.2.3-beta") are not accepted.
func IsValidVersion(version string) bool {
parts := strings.Split(version, ".")
if len(parts) != 3 {
return false
}
for _, part := range parts {
if _, err := strconv.Atoi(part); err != nil {
return false
}
}
return true
}
// CompareVersion compares two versions and returns the difference between them,
// along with whether the caller should update to the other version.
//
// 3: Current version is newer by major version.
// 2: Current version is newer by minor version.
// 1: Current version is newer by patch version (or by prerelease status).
// 0: Versions are equal (or either version failed to parse).
// -3: Current version is older by major version.
// -2: Current version is older by minor version.
// -1: Current version is older by patch version (or by prerelease status).
func CompareVersion(current string, other string) (comp int, shouldUpdate bool) {
currV, err := semver.NewVersion(current)
if err != nil {
return 0, false
}
otherV, err := semver.NewVersion(other)
if err != nil {
return 0, false
}
comp = currV.Compare(otherV)
if comp == 0 {
return 0, false
}
if currV.GreaterThan(otherV) {
shouldUpdate = false
if currV.Major() > otherV.Major() {
comp *= 3
} else if currV.Minor() > otherV.Minor() {
comp *= 2
} else if currV.Patch() > otherV.Patch() {
comp *= 1
}
} else if currV.LessThan(otherV) {
shouldUpdate = true
if currV.Major() < otherV.Major() {
comp *= 3
} else if currV.Minor() < otherV.Minor() {
comp *= 2
} else if currV.Patch() < otherV.Patch() {
comp *= 1
}
}
return comp, shouldUpdate
}
// VersionIsOlderThan reports whether version is strictly older than compare.
func VersionIsOlderThan(version string, compare string) bool {
comp, shouldUpdate := CompareVersion(version, compare)
// shouldUpdate is true only when version parses and is older than compare
return comp < 0 && shouldUpdate
}
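
// Usage sketch (assumed values):
//
//    comp, shouldUpdate := CompareVersion("1.2.3", "1.3.0") // comp == -2, shouldUpdate == true
//    older := VersionIsOlderThan("1.2.3", "1.3.0")          // older == true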

View File

@@ -0,0 +1,248 @@
package util
import (
"testing"

"github.com/Masterminds/semver/v3"
"github.com/stretchr/testify/require"
)
func TestCompareVersion(t *testing.T) {
testCases := []struct {
name string
otherVersion string
currVersion string
expectedOutput int
shouldUpdate bool
}{
{
name: "Current version is newer by major version",
currVersion: "2.0.0",
otherVersion: "1.0.0",
expectedOutput: 3,
shouldUpdate: false,
},
{
name: "Current version is older by major version",
currVersion: "2.0.0",
otherVersion: "3.0.0",
expectedOutput: -3,
shouldUpdate: true,
},
{
name: "Current version is older by minor version",
currVersion: "0.2.2",
otherVersion: "0.3.0",
expectedOutput: -2,
shouldUpdate: true,
},
{
name: "Current version is older by major version (0.2.2 -> 3.0.0)",
currVersion: "0.2.2",
otherVersion: "3.0.0",
expectedOutput: -3,
shouldUpdate: true,
},
{
name: "Current version is older by patch version (0.2.2 -> 0.2.3)",
currVersion: "0.2.2",
otherVersion: "0.2.3",
expectedOutput: -1,
shouldUpdate: true,
},
{
name: "Current version is newer by minor version",
currVersion: "1.2.0",
otherVersion: "1.1.0",
expectedOutput: 2,
shouldUpdate: false,
},
{
name: "Current version is older by minor version (1.2.0 -> 1.3.0)",
currVersion: "1.2.0",
otherVersion: "1.3.0",
expectedOutput: -2,
shouldUpdate: true,
},
{
name: "Current version is newer by patch version",
currVersion: "1.1.2",
otherVersion: "1.1.1",
expectedOutput: 1,
shouldUpdate: false,
},
{
name: "Current version is older by patch version",
currVersion: "1.1.2",
otherVersion: "1.1.3",
expectedOutput: -1,
shouldUpdate: true,
},
{
name: "Versions are equal",
currVersion: "1.1.1",
otherVersion: "1.1.1",
expectedOutput: 0,
shouldUpdate: false,
},
{
name: "Current version is newer by patch version (other has two components)",
currVersion: "1.1.1",
otherVersion: "1.1",
expectedOutput: 1,
shouldUpdate: false,
},
{
name: "Current version is newer by minor version + prerelease",
currVersion: "2.2.0-prerelease",
otherVersion: "2.1.0",
expectedOutput: 2,
shouldUpdate: false,
},
{
name: "Current version is newer (not prerelease)",
currVersion: "2.2.0",
otherVersion: "2.2.0-prerelease",
expectedOutput: 1,
shouldUpdate: false,
},
{
name: "Current version is older (is prerelease)",
currVersion: "2.2.0-prerelease",
otherVersion: "2.2.0",
expectedOutput: -1,
shouldUpdate: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
output, boolOutput := CompareVersion(tc.currVersion, tc.otherVersion)
if output != tc.expectedOutput || boolOutput != tc.shouldUpdate {
t.Errorf("Expected output to be %d and shouldUpdate to be %v, got output=%d and shouldUpdate=%v", tc.expectedOutput, tc.shouldUpdate, output, boolOutput)
}
})
}
}
func TestVersionIsOlderThan(t *testing.T) {
testCases := []struct {
name string
version string
compare string
isOlder bool
}{
{
name: "Version is older than compare",
version: "1.7.3",
compare: "2.0.0",
isOlder: true,
},
{
name: "Version is newer than compare",
version: "2.0.1",
compare: "2.0.0",
isOlder: false,
},
{
name: "Version is equal to compare",
version: "2.0.0",
compare: "2.0.0",
isOlder: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
output := VersionIsOlderThan(tc.version, tc.compare)
if output != tc.isOlder {
t.Errorf("Expected output to be %v, got %v", tc.isOlder, output)
}
})
}
}
func TestHasUpdated(t *testing.T) {
testCases := []struct {
name string
previousVersion string
currentVersion string
hasUpdated bool
}{
{
name: "previousVersion is older than currentVersion",
previousVersion: "1.7.3",
currentVersion: "2.0.0",
hasUpdated: true,
},
{
name: "previousVersion is newer than currentVersion",
previousVersion: "2.0.1",
currentVersion: "2.0.0",
hasUpdated: false,
},
{
name: "previousVersion is equal to currentVersion",
previousVersion: "2.0.0",
currentVersion: "2.0.0",
hasUpdated: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
hasUpdated := VersionIsOlderThan(tc.previousVersion, tc.currentVersion)
if hasUpdated != tc.hasUpdated {
t.Errorf("Expected output to be %v, got %v", tc.hasUpdated, hasUpdated)
}
})
}
}
func TestSemverConstraints(t *testing.T) {
testCases := []struct {
name string
version string
constraints string
expectedOutput bool
}{
{
name: "Version is within constraint",
version: "1.2.0",
constraints: ">= 1.2.0, <= 1.3.0",
expectedOutput: true,
},
{
name: "Patch version 2.0.1 satisfies < 2.1.0",
version: "2.0.1",
constraints: "< 2.1.0",
expectedOutput: true,
},
{
name: "Version 2.1.0 does not satisfy < 2.1.0",
version: "2.1.0",
constraints: "< 2.1.0",
expectedOutput: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c, err := semver.NewConstraint(tc.constraints)
require.NoError(t, err)
v, err := semver.NewVersion(tc.version)
require.NoError(t, err)
output := c.Check(v)
if output != tc.expectedOutput {
t.Errorf("Expected output to be %v, got %v for version %s and constraint %s", tc.expectedOutput, output, tc.version, tc.constraints)
}
})
}
}