summaryrefslogtreecommitdiffstats
path: root/client
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2024-03-13 12:59:03 +0100
committerMartin Polden <mpolden@mpolden.no>2024-03-14 10:45:59 +0100
commit0b29d67b18754c16eb95ccf1d0a5ec6573f1c90d (patch)
tree699ba963d7beed878b0df05629587fd4ffdb181d /client
parent25fa0f5fbdf67aca2ce0f7d460ce402474a28fc1 (diff)
Implement SSE decoder
Diffstat (limited to 'client')
-rw-r--r--client/go/internal/sse/sse.go89
-rw-r--r--client/go/internal/sse/sse_test.go94
2 files changed, 183 insertions, 0 deletions
diff --git a/client/go/internal/sse/sse.go b/client/go/internal/sse/sse.go
new file mode 100644
index 00000000000..a056e4a598a
--- /dev/null
+++ b/client/go/internal/sse/sse.go
@@ -0,0 +1,89 @@
+package sse
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "strings"
+)
+
+// Event represents a server-sent event. Name and ID are optional fields.
+type Event struct {
+ Name string
+ ID string
+ Data string
+}
+
+// Decoder reads and decodes a server-sent event from an input stream.
+type Decoder struct {
+ scanner *bufio.Scanner
+}
+
+// Decode reads and decodes the next event from the underlying reader.
+func (d *Decoder) Decode() (*Event, error) {
+ // https://www.rfc-editor.org/rfc/rfc8895.html#name-server-push-server-sent-eve
+ var (
+ event Event
+ data strings.Builder
+ lastRead string
+ gotName bool
+ gotID bool
+ gotData bool
+ decoding bool
+ )
+ for d.scanner.Scan() {
+ line := strings.TrimSpace(d.scanner.Text())
+ if line == "" {
+ if decoding {
+ break // Done with event
+ } else {
+ continue // Waiting for first non-empty line
+ }
+ }
+ lastRead = line
+ decoding = true
+ parts := strings.SplitN(line, ": ", 2)
+ if len(parts) < 2 || parts[0] == "" {
+ continue
+ }
+ switch parts[0] {
+ case "event":
+ if gotName {
+ return nil, fmt.Errorf("got more than one event line: last read %q", lastRead)
+ }
+ event.Name = parts[1]
+ gotName = true
+ case "id":
+ if gotID {
+ return nil, fmt.Errorf("got more than one id line: last read %q", lastRead)
+ }
+ event.ID = parts[1]
+ gotID = true
+ case "data":
+ if data.Len() > 0 {
+ data.WriteString(" ")
+ }
+ data.WriteString(parts[1])
+ gotData = true
+ default:
+ return nil, fmt.Errorf("invalid field name %q: last read %q", parts[0], lastRead)
+ }
+ }
+ if err := d.scanner.Err(); err != nil {
+ return nil, err
+ }
+ if !decoding {
+ return nil, io.EOF
+ }
+ if !event.IsEnd() && !gotData {
+ return nil, fmt.Errorf("no data line found for event: last read %q", lastRead)
+ }
+ event.Data = data.String()
+ return &event, nil
+}
+
+// NewDecoder creates a new Decoder that reads from r.
+func NewDecoder(r io.Reader) *Decoder { return &Decoder{scanner: bufio.NewScanner(r)} }
+
+// IsEnd returns whether this event indicates that the stream has ended.
+func (e Event) IsEnd() bool { return e.Name == "end" }
diff --git a/client/go/internal/sse/sse_test.go b/client/go/internal/sse/sse_test.go
new file mode 100644
index 00000000000..c81dc4995b6
--- /dev/null
+++ b/client/go/internal/sse/sse_test.go
@@ -0,0 +1,94 @@
+package sse
+
+import (
+ "errors"
+ "io"
+ "strings"
+ "testing"
+)
+
+func TestDecoder(t *testing.T) {
+ r := strings.NewReader(`
+event: foo
+id: 42
+data: data 1
+
+event: bar
+ignored
+: ignored
+data: data 2
+
+
+event: baz
+data: data 3
+data: data 4
+
+event: bax
+data: data 5
+
+data: data 6
+
+event: end
+`)
+ dec := NewDecoder(r)
+
+ assertDecode(&Event{Name: "foo", ID: "42", Data: "data 1"}, dec, t)
+ assertDecode(&Event{Name: "bar", Data: "data 2"}, dec, t)
+ assertDecode(&Event{Name: "baz", Data: "data 3 data 4"}, dec, t)
+ assertDecode(&Event{Name: "bax", Data: "data 5"}, dec, t)
+ assertDecode(&Event{Data: "data 6"}, dec, t)
+ assertDecode(&Event{Name: "end"}, dec, t)
+ assertDecodeErr(io.EOF, dec, t)
+}
+
+func TestDecoderInvalid(t *testing.T) {
+ r := strings.NewReader(`
+event: foo
+event: bar
+
+event: foo
+id: 42
+
+foo
+
+bad: field
+`)
+ dec := NewDecoder(r)
+ assertDecodeErrString(`got more than one event line: last read "event: bar"`, dec, t)
+ assertDecodeErrString(`no data line found for event: last read "id: 42"`, dec, t)
+ assertDecodeErrString(`no data line found for event: last read "foo"`, dec, t)
+ assertDecodeErrString(`invalid field name "bad": last read "bad: field"`, dec, t)
+}
+
+func assertDecode(want *Event, dec *Decoder, t *testing.T) {
+ t.Helper()
+ got, err := dec.Decode()
+ if err != nil {
+ t.Fatalf("got error %v, want %+v", err, want)
+ }
+ if got.Name != want.Name {
+ t.Errorf("got Name=%q, want %q", got.Name, want.Name)
+ }
+ if got.ID != want.ID {
+ t.Errorf("got ID=%q, want %q", got.ID, want.ID)
+ }
+ if got.Data != want.Data {
+ t.Errorf("got Data=%q, want %q", got.Data, want.Data)
+ }
+}
+
+func assertDecodeErrString(errMsg string, dec *Decoder, t *testing.T) {
+ t.Helper()
+ assertDecodeErr(errors.New(errMsg), dec, t)
+}
+
+func assertDecodeErr(wantErr error, dec *Decoder, t *testing.T) {
+ t.Helper()
+ _, err := dec.Decode()
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if err.Error() != wantErr.Error() {
+ t.Errorf("got error %q, want %q", err, wantErr)
+ }
+}