From cbb6d9ca53b8c7ca416afeee3ae0074c7e39a24c Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Wed, 3 Aug 2022 17:15:02 +0200 Subject: [PATCH] Update capabilities if no hello v2 token key is found in cache. This is necessary to detect updated Talk setups where the signaling server might have cached capabilities without the v2 token key but the clients are trying to connect with a hello v2 token. Fetch updated capabilities in such cases (but throttle to about one invalidation per minute). --- capabilities.go | 121 +++++++++++++++++++++++++---------------- capabilities_test.go | 125 +++++++++++++++++++++++++++++++++++++++---- hub.go | 12 ++++- hub_test.go | 59 +++++++++++++++++++- room_ping.go | 4 +- 5 files changed, 262 insertions(+), 59 deletions(-) diff --git a/capabilities.go b/capabilities.go index 491bd7a..12914bc 100644 --- a/capabilities.go +++ b/capabilities.go @@ -43,8 +43,14 @@ const ( // Cache received capabilities for one hour. CapabilitiesCacheDuration = time.Hour + + // Don't invalidate more than once per minute. + maxInvalidateInterval = time.Minute ) +// Can be overwritten by tests. +var getCapabilitiesNow = time.Now + type capabilitiesEntry struct { nextUpdate time.Time capabilities map[string]interface{} @@ -53,16 +59,18 @@ type capabilitiesEntry struct { type Capabilities struct { mu sync.RWMutex - version string - pool *HttpClientPool - entries map[string]*capabilitiesEntry + version string + pool *HttpClientPool + entries map[string]*capabilitiesEntry + nextInvalidate map[string]time.Time } func NewCapabilities(version string, pool *HttpClientPool) (*Capabilities, error) { result := &Capabilities{ - version: version, - pool: pool, - entries: make(map[string]*capabilitiesEntry), + version: version, + pool: pool, + entries: make(map[string]*capabilitiesEntry), + nextInvalidate: make(map[string]time.Time), } return result, nil @@ -86,7 +94,7 @@ func (c *Capabilities) getCapabilities(key string) (map[string]interface{}, bool c.mu.RLock() defer c.mu.RUnlock() - now := time.Now() + now := getCapabilitiesNow() if entry, found := c.entries[key]; found && entry.nextUpdate.After(now) { return entry.capabilities, true } @@ -95,7 +103,7 @@ func (c *Capabilities) getCapabilities(key string) (map[string]interface{}, bool } func (c *Capabilities) setCapabilities(key string, capabilities map[string]interface{}) { - now := time.Now() + now := getCapabilitiesNow() entry := &capabilitiesEntry{ nextUpdate: now.Add(CapabilitiesCacheDuration), capabilities: capabilities, @@ -106,11 +114,28 @@ func (c *Capabilities) setCapabilities(key string, capabilities map[string]inter c.entries[key] = entry } -func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[string]interface{}, error) { - key := u.String() +func (c *Capabilities) invalidateCapabilities(key string) { + c.mu.Lock() + defer c.mu.Unlock() + now := getCapabilitiesNow() + if entry, found := c.nextInvalidate[key]; found && entry.After(now) { + return + } + + delete(c.entries, key) + c.nextInvalidate[key] = now.Add(maxInvalidateInterval) +} + +func (c *Capabilities) getKeyForUrl(u *url.URL) string { + key := u.String() + return key +} + +func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[string]interface{}, bool, error) { + key := c.getKeyForUrl(u) if caps, found := c.getCapabilities(key); found { - return caps, nil + return caps, true, nil } capUrl := *u @@ -128,14 +153,14 @@ func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[st client, pool, err := c.pool.Get(ctx, &capUrl) if err != nil { log.Printf("Could not get client for host %s: %s", capUrl.Host, err) - return nil, err + return nil, false, err } defer pool.Put(client) req, err := http.NewRequestWithContext(ctx, "GET", capUrl.String(), nil) if err != nil { log.Printf("Could not create request to %s: %s", &capUrl, err) - return nil, err + return nil, false, err } req.Header.Set("Accept", "application/json") req.Header.Set("OCS-APIRequest", "true") @@ -143,56 +168,56 @@ func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[st resp, err := client.Do(req) if err != nil { - return nil, err + return nil, false, err } defer resp.Body.Close() ct := resp.Header.Get("Content-Type") if !strings.HasPrefix(ct, "application/json") { log.Printf("Received unsupported content-type from %s: %s (%s)", capUrl.String(), ct, resp.Status) - return nil, ErrUnsupportedContentType + return nil, false, ErrUnsupportedContentType } body, err := io.ReadAll(resp.Body) if err != nil { log.Printf("Could not read response body from %s: %s", capUrl.String(), err) - return nil, err + return nil, false, err } var ocs OcsResponse if err := json.Unmarshal(body, &ocs); err != nil { log.Printf("Could not decode OCS response %s from %s: %s", string(body), capUrl.String(), err) - return nil, err + return nil, false, err } else if ocs.Ocs == nil || ocs.Ocs.Data == nil { log.Printf("Incomplete OCS response %s from %s", string(body), u) - return nil, fmt.Errorf("incomplete OCS response") + return nil, false, fmt.Errorf("incomplete OCS response") } var response CapabilitiesResponse if err := json.Unmarshal(*ocs.Ocs.Data, &response); err != nil { log.Printf("Could not decode OCS response body %s from %s: %s", string(*ocs.Ocs.Data), capUrl.String(), err) - return nil, err + return nil, false, err } capaObj, found := response.Capabilities[AppNameSpreed] if !found || capaObj == nil { log.Printf("No capabilities received for app spreed from %s: %+v", capUrl.String(), response) - return nil, nil + return nil, false, nil } var capa map[string]interface{} if err := json.Unmarshal(*capaObj, &capa); err != nil { log.Printf("Unsupported capabilities received for app spreed from %s: %+v", capUrl.String(), response) - return nil, nil + return nil, false, nil } log.Printf("Received capabilities %+v from %s", capa, capUrl.String()) c.setCapabilities(key, capa) - return capa, nil + return capa, false, nil } func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, feature string) bool { - caps, err := c.loadCapabilities(ctx, u) + caps, _, err := c.loadCapabilities(ctx, u) if err != nil { log.Printf("Could not get capabilities for %s: %s", u, err) return false @@ -217,80 +242,86 @@ func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, fea return false } -func (c *Capabilities) getConfigGroup(ctx context.Context, u *url.URL, group string) (map[string]interface{}, bool) { - caps, err := c.loadCapabilities(ctx, u) +func (c *Capabilities) getConfigGroup(ctx context.Context, u *url.URL, group string) (map[string]interface{}, bool, bool) { + caps, cached, err := c.loadCapabilities(ctx, u) if err != nil { log.Printf("Could not get capabilities for %s: %s", u, err) - return nil, false + return nil, cached, false } configInterface := caps["config"] if configInterface == nil { - return nil, false + return nil, cached, false } config, ok := configInterface.(map[string]interface{}) if !ok { log.Printf("Invalid config mapping received from %s: %+v", u, configInterface) - return nil, false + return nil, cached, false } groupInterface := config[group] if groupInterface == nil { - return nil, false + return nil, cached, false } groupConfig, ok := groupInterface.(map[string]interface{}) if !ok { log.Printf("Invalid group mapping \"%s\" received from %s: %+v", group, u, groupInterface) - return nil, false + return nil, cached, false } - return groupConfig, true + return groupConfig, cached, true } -func (c *Capabilities) GetIntegerConfig(ctx context.Context, u *url.URL, group, key string) (int, bool) { - groupConfig, found := c.getConfigGroup(ctx, u, group) +func (c *Capabilities) GetIntegerConfig(ctx context.Context, u *url.URL, group, key string) (int, bool, bool) { + groupConfig, cached, found := c.getConfigGroup(ctx, u, group) if !found { - return 0, false + return 0, cached, false } value, found := groupConfig[key] if !found { - return 0, false + return 0, cached, false } switch value := value.(type) { case int: - return value, true + return value, cached, true case float32: - return int(value), true + return int(value), cached, true case float64: - return int(value), true + return int(value), cached, true default: log.Printf("Invalid config value for \"%s\" received from %s: %+v", key, u, value) } - return 0, false + return 0, cached, false } -func (c *Capabilities) GetStringConfig(ctx context.Context, u *url.URL, group, key string) (string, bool) { - groupConfig, found := c.getConfigGroup(ctx, u, group) +func (c *Capabilities) GetStringConfig(ctx context.Context, u *url.URL, group, key string) (string, bool, bool) { + groupConfig, cached, found := c.getConfigGroup(ctx, u, group) if !found { - return "", false + return "", cached, false } value, found := groupConfig[key] if !found { - return "", false + return "", cached, false } switch value := value.(type) { case string: - return value, true + return value, cached, true default: log.Printf("Invalid config value for \"%s\" received from %s: %+v", key, u, value) } - return "", false + return "", cached, false +} + +func (c *Capabilities) InvalidateCapabilities(u *url.URL) { + key := c.getKeyForUrl(u) + + c.invalidateCapabilities(key) } diff --git a/capabilities_test.go b/capabilities_test.go index 19eb087..22f653b 100644 --- a/capabilities_test.go +++ b/capabilities_test.go @@ -28,12 +28,14 @@ import ( "net/http/httptest" "net/url" "strings" + "sync/atomic" "testing" + "time" "github.com/gorilla/mux" ) -func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) { +func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*CapabilitiesResponse)) (*url.URL, *Capabilities) { pool, err := NewHttpClientPool(1, false) if err != nil { t.Fatal(err) @@ -84,6 +86,10 @@ func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) { }, } + if callback != nil { + callback(response) + } + data, err := json.Marshal(response) if err != nil { t.Errorf("Could not marshal %+v: %s", response, err) @@ -110,6 +116,19 @@ func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) { return u, capabilities } +func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) { + return NewCapabilitiesForTestWithCallback(t, nil) +} + +func SetCapabilitiesGetNow(t *testing.T, f func() time.Time) { + old := getCapabilitiesNow + t.Cleanup(func() { + getCapabilitiesNow = old + }) + + getCapabilitiesNow = f +} + func TestCapabilities(t *testing.T) { url, capabilities := NewCapabilitiesForTest(t) @@ -124,34 +143,122 @@ func TestCapabilities(t *testing.T) { } expectedString := "bar" - if value, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { t.Error("could not find value for \"foo\"") } else if value != expectedString { t.Errorf("expected value %s, got %s", expectedString, value) + } else if !cached { + t.Errorf("expected cached response") } - if value, found := capabilities.GetStringConfig(ctx, url, "signaling", "baz"); found { + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "baz"); found { t.Errorf("should not have found value for \"baz\", got %s", value) + } else if !cached { + t.Errorf("expected cached response") } - if value, found := capabilities.GetStringConfig(ctx, url, "signaling", "invalid"); found { + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "invalid"); found { t.Errorf("should not have found value for \"invalid\", got %s", value) + } else if !cached { + t.Errorf("expected cached response") } - if value, found := capabilities.GetStringConfig(ctx, url, "invalid", "foo"); found { + if value, cached, found := capabilities.GetStringConfig(ctx, url, "invalid", "foo"); found { t.Errorf("should not have found value for \"baz\", got %s", value) + } else if !cached { + t.Errorf("expected cached response") } expectedInt := 42 - if value, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "baz"); !found { + if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "baz"); !found { t.Error("could not find value for \"baz\"") } else if value != expectedInt { t.Errorf("expected value %d, got %d", expectedInt, value) + } else if !cached { + t.Errorf("expected cached response") } - if value, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "foo"); found { + if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "foo"); found { t.Errorf("should not have found value for \"foo\", got %d", value) + } else if !cached { + t.Errorf("expected cached response") } - if value, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "invalid"); found { + if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "invalid"); found { t.Errorf("should not have found value for \"invalid\", got %d", value) + } else if !cached { + t.Errorf("expected cached response") } - if value, found := capabilities.GetIntegerConfig(ctx, url, "invalid", "baz"); found { + if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "invalid", "baz"); found { t.Errorf("should not have found value for \"baz\", got %d", value) + } else if !cached { + t.Errorf("expected cached response") + } +} + +func TestInvalidateCapabilities(t *testing.T) { + var called uint32 + url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse) { + atomic.AddUint32(&called, 1) + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + expectedString := "bar" + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if cached { + t.Errorf("expected direct response") + } + + if value := atomic.LoadUint32(&called); value != 1 { + t.Errorf("expected called %d, got %d", 1, value) + } + + // Invalidating will cause the capabilities to be reloaded. + capabilities.InvalidateCapabilities(url) + + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if cached { + t.Errorf("expected direct response") + } + + if value := atomic.LoadUint32(&called); value != 2 { + t.Errorf("expected called %d, got %d", 2, value) + } + + // Invalidating is throttled to about once per minute. + capabilities.InvalidateCapabilities(url) + + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if !cached { + t.Errorf("expected cached response") + } + + if value := atomic.LoadUint32(&called); value != 2 { + t.Errorf("expected called %d, got %d", 2, value) + } + + // At a later time, invalidating can be done again. + SetCapabilitiesGetNow(t, func() time.Time { + return time.Now().Add(2 * time.Minute) + }) + + capabilities.InvalidateCapabilities(url) + + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if cached { + t.Errorf("expected direct response") + } + + if value := atomic.LoadUint32(&called); value != 3 { + t.Errorf("expected called %d, got %d", 3, value) } } diff --git a/hub.go b/hub.go index 2b982ca..c7dc1af 100644 --- a/hub.go +++ b/hub.go @@ -1063,9 +1063,17 @@ func (h *Hub) processHelloV2(client *Client, message *ClientMessage) (*Backend, ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout) defer cancel() - keyData, found := h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey) + keyData, cached, found := h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey) if !found { - return nil, fmt.Errorf("No key found for issuer") + if cached { + // The Nextcloud instance might just have enabled JWT but we probably use + // the cached capabilities without the public key. Make sure to re-fetch. + h.backend.capabilities.InvalidateCapabilities(url) + keyData, _, found = h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey) + } + if !found { + return nil, fmt.Errorf("No key found for issuer") + } } key, err := loadKeyFunc([]byte(keyData)) diff --git a/hub_test.go b/hub_test.go index 9e56a7e..abe6903 100644 --- a/hub_test.go +++ b/hub_test.go @@ -697,7 +697,11 @@ func registerBackendHandlerUrl(t *testing.T, router *mux.Router, url string) { if strings.Contains(t.Name(), "MultiRoom") { signaling[ConfigKeySessionPingLimit] = 2 } - if strings.Contains(t.Name(), "V2") { + useV2 := true + if os.Getenv("SKIP_V2_CAPABILITIES") != "" { + useV2 = false + } + if strings.Contains(t.Name(), "V2") && useV2 { key := getPublicAuthToken(t) public, err := x509.MarshalPKIXPublicKey(key) if err != nil { @@ -1060,6 +1064,59 @@ func TestClientHelloV2_ExpiresAtMissing(t *testing.T) { } } +func TestClientHelloV2_CachedCapabilities(t *testing.T) { + for _, algo := range testHelloV2Algorithms { + t.Run(algo, func(t *testing.T) { + hub, _, _, server := CreateHubForTest(t) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + // Simulate old-style Nextcloud without capabilities for Hello V2. + t.Setenv("SKIP_V2_CAPABILITIES", "1") + + client1 := NewTestClient(t, server, hub) + defer client1.CloseWithBye() + + if err := client1.SendHelloV1(testDefaultUserId + "1"); err != nil { + t.Fatal(err) + } + + hello1, err := client1.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } + if hello1.Hello.UserId != testDefaultUserId+"1" { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"1", hello1.Hello) + } + if hello1.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello1.Hello) + } + + // Simulate updated Nextcloud with capabilities for Hello V2. + t.Setenv("SKIP_V2_CAPABILITIES", "") + + client2 := NewTestClient(t, server, hub) + defer client2.CloseWithBye() + + if err := client2.SendHelloV2(testDefaultUserId + "2"); err != nil { + t.Fatal(err) + } + + hello2, err := client2.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } + if hello2.Hello.UserId != testDefaultUserId+"2" { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"2", hello2.Hello) + } + if hello2.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello2.Hello) + } + }) + } +} + func TestClientHelloWithSpaces(t *testing.T) { hub, _, _, server := CreateHubForTest(t) diff --git a/room_ping.go b/room_ping.go index 2e83fe9..48c301a 100644 --- a/room_ping.go +++ b/room_ping.go @@ -119,7 +119,7 @@ func (p *RoomPing) publishEntries(entries *pingEntries, timeout time.Duration) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - limit, found := p.capabilities.GetIntegerConfig(ctx, entries.url, ConfigGroupSignaling, ConfigKeySessionPingLimit) + limit, _, found := p.capabilities.GetIntegerConfig(ctx, entries.url, ConfigGroupSignaling, ConfigKeySessionPingLimit) if !found || limit <= 0 { // Limit disabled while waiting for the next iteration, fallback to sending // one request per room. @@ -188,7 +188,7 @@ func (p *RoomPing) sendPingsCombined(url *url.URL, entries []BackendPingEntry, l } func (p *RoomPing) SendPings(ctx context.Context, room *Room, url *url.URL, entries []BackendPingEntry) error { - limit, found := p.capabilities.GetIntegerConfig(ctx, url, ConfigGroupSignaling, ConfigKeySessionPingLimit) + limit, _, found := p.capabilities.GetIntegerConfig(ctx, url, ConfigGroupSignaling, ConfigKeySessionPingLimit) if !found || limit <= 0 { // Old-style Nextcloud or session limit not configured. Perform one request // per room. Don't queue to avoid sending all ping requests to old-style