commit 2751a839b97fab3abcde40a7cbf2bd548b866669 Author: William Perron Date: Thu Jul 4 10:03:33 2024 -0400 initial commit diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..b822dba --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module go.wperron.io/servertiming + +go 1.22.1 diff --git a/servertiming.go b/servertiming.go new file mode 100644 index 0000000..496cf18 --- /dev/null +++ b/servertiming.go @@ -0,0 +1,127 @@ +package servertiming + +import ( + "fmt" + "net/http" + "strconv" + "strings" + "time" +) + +type ServerTiming struct { + Name string + Dur time.Duration + Desc string + + // Used when formatting. The default behavior is to output time as an int + // representing milliseconds. Increasing the precision will add decimals to + // the output, down to nanoseconds. + DecimalPrecision int +} + +func (t ServerTiming) String() string { + sb := strings.Builder{} + sb.WriteString(t.Name) + + if t.Dur != 0 { + sb.WriteString(";dur=") + // precision is clamped between 0 and 6 inclusively. 0 represents + // milliseconds as an integer, 6 decimal positions represent nanoseconds + precision := min(max(0, t.DecimalPrecision), 6) + sb.WriteString(fmt.Sprintf("%.*f", precision, float64(t.Dur.Nanoseconds())/1_000_000)) + } + + if t.Desc != "" { + sb.WriteString(";desc=") + sb.WriteString(t.Desc) + } + + return sb.String() +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func FromString(s string) ServerTiming { + st := ServerTiming{} + + part, rest, more := strings.Cut(s, ";") + st.Name = strings.TrimSpace(part) + + for more { + part, rest, more = strings.Cut(strings.TrimSpace(rest), ";") + key, val, _ := strings.Cut(part, "=") + val = strings.TrimSpace(val) + switch key { + case "desc": + st.Desc = val + case "dur": + // From the spec: + // Since duration is a DOMHighResTimeStamp, it usually represents a + // duration in milliseconds. Since this is not enforcable in + // practice, duration can represent any unit of time, and having it + // represent a duration in milliseconds is a recommendation. + + // The happy path is an int, in which case milliseconds are assumed + if i, err := strconv.Atoi(val); err == nil { + st.Dur = time.Duration(i * int(time.Millisecond)) + break + } + + // Otherwise we try to parse as a float, and multiply by 1,000,000 + // to get nanoseconds and truncate the rest + if f, err := strconv.ParseFloat(val, 64); err == nil { + st.Dur = time.Duration(int(f*1_000_000) * int(time.Nanosecond)) + } + default: + // ignore any unknown token + } + } + + return st +} + +func Append(r *http.Response, t ServerTiming) { + r.Header.Add("Server-Timing", t.String()) +} + +func Trailer(r *http.Response, t ServerTiming) { + r.Trailer.Add("Server-Timing", t.String()) +} + +func Parse(r *http.Response) []ServerTiming { + res := []ServerTiming{} + + for k, v := range r.Header { + if strings.ToLower(k) == "server-timing" { + for _, s := range v { + for _, single := range strings.Split(s, ",") { + res = append(res, FromString(single)) + } + } + } + } + + for k, v := range r.Trailer { + if strings.ToLower(k) == "server-timing" { + for _, s := range v { + for _, single := range strings.Split(s, ",") { + res = append(res, FromString(single)) + } + } + } + } + + return res +} diff --git a/servertiming_test.go b/servertiming_test.go new file mode 100644 index 0000000..963e9a3 --- /dev/null +++ b/servertiming_test.go @@ -0,0 +1,293 @@ +package servertiming + +import ( + "net/http" + "reflect" + "testing" + "time" +) + +// Example taken from the W3C spec +// see: https://w3c.github.io/server-timing/#examples +// +// ``` +// > GET /resource HTTP/1.1 +// > Host: example.com +// +// +// < HTTP/1.1 200 OK +// < Server-Timing: miss, db;dur=53, app;dur=47.2 +// < Server-Timing: customView, dc;desc=atl +// < Server-Timing: cache;desc="Cache Read";dur=23.2 +// < Trailer: Server-Timing +// < (... snip response body ...) +// < Server-Timing: total;dur=123.4 +// ``` +// +// | Name | Duration | Description | +// | ---------- | -------- | ----------- | +// | miss | | | +// | db | 53 | | +// | app | 47.2 | | +// | customView | | | +// | dc | | atl | +// | cache | 23.2 | Cache Read | +// | total | 123.4 | | + +func TestServerTiming_String(t *testing.T) { + tests := []struct { + name string + st ServerTiming + want string + }{ + { + name: "just name", + st: ServerTiming{Name: "miss"}, + want: "miss", + }, + { + name: "name and dur", + st: ServerTiming{ + Name: "db", + Dur: 53 * time.Millisecond, + }, + want: "db;dur=53", + }, + { + name: "name and decimal dur", + st: ServerTiming{ + Name: "app", + Dur: 47_200 * time.Microsecond, + DecimalPrecision: 1, + }, + want: "app;dur=47.2", + }, + { + name: "name and nanosecond dur", + st: ServerTiming{ + Name: "app", + Dur: 47_200 * time.Microsecond, + DecimalPrecision: 6, + }, + want: "app;dur=47.200000", + }, + { + name: "name and dur, negative precision", + st: ServerTiming{ + Name: "app", + Dur: 47_200 * time.Microsecond, + DecimalPrecision: -1, + }, + want: "app;dur=47", + }, + { + name: "name and dur, out-of-bound precision", + st: ServerTiming{ + Name: "app", + Dur: 47_200 * time.Microsecond, + DecimalPrecision: 7, + }, + want: "app;dur=47.200000", + }, + { + name: "name and desc", + st: ServerTiming{ + Name: "dc", + Desc: "atl", + }, + want: "dc;desc=atl", + }, + { + name: "name, desc, and dur", + st: ServerTiming{ + Name: "cache", + Dur: 23 * time.Millisecond, + Desc: "Cache Read", + }, + want: "cache;dur=23;desc=Cache Read", + }, + { + name: "name, desc, and decimal dur", + st: ServerTiming{ + Name: "cache", + Dur: 23_200 * time.Microsecond, + Desc: "Cache Read", + DecimalPrecision: 1, + }, + want: "cache;dur=23.2;desc=Cache Read", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.st.String(); got != tt.want { + t.Errorf("ServerTiming.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestFromString(t *testing.T) { + tests := []struct { + name string + s string + want ServerTiming + }{ + { + name: "empty", + s: "", + want: ServerTiming{}, + }, + { + name: "name only", + s: "miss", + want: ServerTiming{Name: "miss"}, + }, + { + name: "name and dur", + s: "db;dur=53", + want: ServerTiming{ + Name: "db", + Dur: 53 * time.Millisecond, + }, + }, + { + name: "name, dur and desc", + s: "cache;dur=23;desc=Cache Read", + want: ServerTiming{ + Name: "cache", + Dur: 23 * time.Millisecond, + Desc: "Cache Read", + }, + }, + { + name: "name, desc", + s: "cache;desc=Cache Read;dur=23", + want: ServerTiming{ + Name: "cache", + Dur: 23 * time.Millisecond, + Desc: "Cache Read", + }, + }, + { + name: "name, dur and desc with padding", + s: "cache ; dur=23 ; desc=Cache Read ", + want: ServerTiming{ + Name: "cache", + Dur: 23 * time.Millisecond, + Desc: "Cache Read", + }, + }, + { + name: "name and decimal dur", + s: "cache;dur=23.2", + want: ServerTiming{ + Name: "cache", + Dur: 23_200 * time.Microsecond, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := FromString(tt.s); !reflect.DeepEqual(got, tt.want) { + t.Errorf("FromString() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAppend(t *testing.T) { + resp := &http.Response{ + Header: http.Header{}, + } + + tests := []struct { + name string + st ServerTiming + }{ + { + name: "name only", + st: ServerTiming{Name: "miss"}, + }, + { + name: "name and dur", + st: ServerTiming{ + Name: "db", + Dur: 53 * time.Millisecond, + }, + }, + { + name: "name, dur and desc", + st: ServerTiming{ + Name: "cache", + Dur: 23 * time.Millisecond, + Desc: "Cache Read", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Append(resp, tt.st) + }) + } + + vals := resp.Header.Values("Server-Timing") + if len(vals) != len(tests) { + t.Errorf("Expected %d values in the headers, got %d", len(tests), len(vals)) + } + + for i, v := range vals { + if v != tests[i].st.String() { + t.Errorf("Expected '%s', got %s", tests[i].st.String(), v) + } + } +} + +func TestTrailer(t *testing.T) { + type args struct { + r *http.Response + t ServerTiming + } + tests := []struct { + name string + args args + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Trailer(tt.args.r, tt.args.t) + }) + } +} + +func TestParse(t *testing.T) { + resp := &http.Response{ + Header: http.Header{}, + Trailer: http.Header{}, + } + resp.Header.Add("Server-Timing", "miss, db;dur=53, app;dur=47") + resp.Header.Add("Server-Timing", "customView, dc;desc=atl") + resp.Header.Add("Server-Timing", `cache;desc="Cache Read";dur=23`) + resp.Trailer.Add("Server-Timing", "total;dur=123") + + timings := Parse(resp) + if len(timings) != 7 { + t.Errorf("Expected 7 timings, got %d", len(timings)) + } +} + +func TestParseWithDecimal(t *testing.T) { + resp := &http.Response{ + Header: http.Header{}, + Trailer: http.Header{}, + } + resp.Header.Add("Server-Timing", "miss, db;dur=53, app;dur=47.2") + resp.Header.Add("Server-Timing", "customView, dc;desc=atl") + resp.Header.Add("Server-Timing", `cache;desc="Cache Read";dur=23.2`) + resp.Trailer.Add("Server-Timing", "total;dur=123.4") + + timings := Parse(resp) + if len(timings) != 7 { + t.Errorf("Expected 7 timings, got %d", len(timings)) + } +}